Question:
I am using the TensorFlow Object Detection API to train a custom model based on ssdlite_mobilenet_v2_coco_2018_05_09 from the TensorFlow model zoo.
I trained the model successfully and tested it with a script provided in this tutorial.
Here is the problem: I need a detect.tflite file to run on my target machine (an embedded system). But when I convert my model to .tflite, it outputs almost nothing, and when it does detect something, the detection is wrong. To make the .tflite file, I first ran export_tflite_ssd_graph.py and then ran toco on its output, using this command based on the docs and some Google searches:
toco --graph_def_file=$OUTPUT_DIR/tflite_graph.pb --output_file=$OUTPUT_DIR/detect.tflite --input_shapes=1,300,300,3 --input_arrays=normalized_input_image_tensor --output_arrays='TFLite_Detection_PostProcess','TFLite_Detection_PostProcess:1','TFLite_Detection_PostProcess:2','TFLite_Detection_PostProcess:3' --allow_custom_ops
Also, the code I use to run detection with a .tflite file works properly: I tested it with the detect.tflite file of ssd_mobilenet_v3_small_coco.
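The detection code itself is not shown in the question. As a hedged sketch, assuming the usual output order of the TFLite_Detection_PostProcess op (boxes, classes, scores, number of detections, matching the four --output_arrays above) and boxes given as normalized (ymin, xmin, ymax, xmax), reading the results could look like this. The helper decode_detections and the 0.5 threshold are illustrative only, and the class indices may be offset from the IDs in your label map, so check them against it.

```python
from typing import List, Sequence, Tuple

# Hypothetical result type: (class index, score, (left, top, right, bottom) in pixels).
Detection = Tuple[int, float, Tuple[int, int, int, int]]


def decode_detections(
    boxes: Sequence[Sequence[float]],  # output 0: [N, 4] rows of (ymin, xmin, ymax, xmax) in [0, 1]
    classes: Sequence[float],          # output 1: [N] class indices, stored as floats
    scores: Sequence[float],           # output 2: [N] confidence scores
    num_detections: float,             # output 3: how many leading rows are valid
    image_width: int,
    image_height: int,
    score_threshold: float = 0.5,      # illustrative threshold, tune for your model
) -> List[Detection]:
    """Convert TFLite_Detection_PostProcess outputs (batch dimension already
    stripped) into pixel-space detections above a score threshold."""

    def clamp(v: float) -> float:
        # Defensive: keep normalized coordinates inside [0, 1].
        return min(max(v, 0.0), 1.0)

    results: List[Detection] = []
    for i in range(int(num_detections)):
        if scores[i] < score_threshold:
            continue
        ymin, xmin, ymax, xmax = (clamp(v) for v in boxes[i])
        pixel_box = (
            round(xmin * image_width),
            round(ymin * image_height),
            round(xmax * image_width),
            round(ymax * image_height),
        )
        results.append((int(classes[i]), float(scores[i]), pixel_box))
    return results
```

With tf.lite.Interpreter, the four arrays would typically come from interpreter.get_tensor(d["index"]) for each entry d of interpreter.get_output_details(), taking element [0] of each to drop the batch dimension.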
Answer 1:
The problem was the toco command. Some of the documents I followed were outdated and misled me: toco is deprecated, and I should have used the tflite_convert tool instead.
Here is the full command I used (run from your training directory):
tflite_convert --graph_def_file tflite_inference_graph/tflite_graph.pb --output_file=./detect.tflite --output_format=TFLITE --input_shapes=1,300,300,3 --input_arrays=normalized_input_image_tensor --output_arrays='TFLite_Detection_PostProcess','TFLite_Detection_PostProcess:1','TFLite_Detection_PostProcess:2','TFLite_Detection_PostProcess:3' --inference_type=QUANTIZED_UINT8 --mean_values=128 --std_dev_values=127 --change_concat_input_ranges=false --allow_custom_ops
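The --mean_values and --std_dev_values flags tell the converter how a raw uint8 input maps back to the real values the network was trained on; per the TensorFlow 1.x converter documentation, real_value = (quantized_value - mean_value) / std_dev_value. The sketch below evaluates that formula for the values used here and compares it with what I understand to be the float preprocessing of the SSD MobileNet feature extractors, (2/255) * x - 1. That comparison, and the helper names, are assumptions for illustration.

```python
# Sketch of the input mapping tflite_convert applies for a quantized uint8
# input: real_value = (quantized_value - mean_value) / std_dev_value.
MEAN_VALUE = 128.0     # --mean_values=128
STD_DEV_VALUE = 127.0  # --std_dev_values=127


def uint8_to_real(q: int, mean: float = MEAN_VALUE, std: float = STD_DEV_VALUE) -> float:
    """Real value the quantized model assumes a raw uint8 pixel q represents."""
    if not 0 <= q <= 255:
        raise ValueError("expected a uint8 pixel value")
    return (q - mean) / std


def float_preprocess(q: int) -> float:
    """Assumed float-side preprocessing of the SSD MobileNet feature extractor:
    scale pixels from [0, 255] to [-1, 1]."""
    return (2.0 / 255.0) * q - 1.0


# Compare the two mappings over every possible pixel value.
max_diff = max(abs(uint8_to_real(q) - float_preprocess(q)) for q in range(256))
print(f"uint8 0 -> {uint8_to_real(0):.4f}, uint8 255 -> {uint8_to_real(255):.4f}")
print(f"max difference vs. float preprocessing: {max_diff:.4f}")
```

Pixel 0 maps to about -1.008 and pixel 255 maps to exactly 1.0, so mean 128 / std 127 reproduces the [-1, 1] input range to within about 0.008. That is why raw uint8 camera frames can be fed straight into the quantized model.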
I trained from the ssdlite_mobilenet_v2_coco_2018_05_09 model and added the following at the end of my .config file:
graph_rewriter {
  quantization {
    delay: 400
    weight_bits: 8
    activation_bits: 8
  }
}
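The graph_rewriter block turns on quantization-aware training. As far as I know, delay is the number of training steps before quantization takes effect, and weight_bits / activation_bits set the precision of the simulated ("fake") quantization. The sketch below illustrates the idea with a simplified 8-bit affine scheme. The functions fake_quantize and maybe_fake_quantize are hypothetical, and TensorFlow's actual fake-quant ops differ in details such as how ranges are tracked and nudged.

```python
# Simplified illustration of 8-bit affine fake quantization with a start delay,
# in the spirit of quantization-aware training. NOT TensorFlow's implementation.


def fake_quantize(x: float, min_val: float, max_val: float, num_bits: int = 8) -> float:
    """Quantize x to num_bits on the range [min_val, max_val], then dequantize."""
    levels = 2 ** num_bits - 1            # 255 steps for 8 bits
    # Widen the range to include 0 so that zero is exactly representable.
    min_val, max_val = min(min_val, 0.0), max(max_val, 0.0)
    scale = (max_val - min_val) / levels
    if scale == 0.0:
        return 0.0                        # degenerate range: only 0 is representable
    zero_point = round(-min_val / scale)  # integer code that represents 0.0
    q = round(x / scale) + zero_point
    q = min(max(q, 0), levels)            # clip to the integer range [0, levels]
    return (q - zero_point) * scale


def maybe_fake_quantize(x: float, step: int, min_val: float, max_val: float,
                        delay: int = 400, num_bits: int = 8) -> float:
    """Pass x through unchanged for the first `delay` training steps,
    then apply fake quantization (mirrors the `delay: 400` setting)."""
    if step < delay:
        return x
    return fake_quantize(x, min_val, max_val, num_bits)
```

Because the network sees this rounding during the later training steps, it learns weights that survive the conversion to QUANTIZED_UINT8, which the converter then relies on.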
I also used this command to generate tflite_graph.pb in the tflite_inference_graph directory:
python export_tflite_ssd_graph.py --pipeline_config_path 2020-05-17_train_ssdlite_v2/ssd_mobilenet_v2_coco.config --trained_checkpoint_prefix 2020-05-17_train_ssdlite_v2/train/model.ckpt-1146 --output_directory 2020-05-17_train_ssdlite_v2/tflite_inference_graph --add_postprocessing_op=true
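The answer runs the two steps from different working directories: the export from the parent directory, and tflite_convert from inside the training directory. As a hedged sketch, both commands can be assembled from one set of paths in a small Python wrapper. The helper names export_command, convert_command and run_pipeline are hypothetical. The wrapper assumes that export_tflite_ssd_graph.py is in the current directory and that tflite_convert is on PATH. The flags themselves are copied from the commands in this answer.

```python
import shlex
import subprocess
from pathlib import Path
from typing import List

# Comma-separated output arrays; the shell strips the single quotes used in the
# original command, so no quotes are needed in an argument list.
OUTPUT_ARRAYS = ",".join(
    ["TFLite_Detection_PostProcess"]
    + [f"TFLite_Detection_PostProcess:{i}" for i in range(1, 4)]
)


def export_command(train_dir: str, config_name: str, ckpt_step: int) -> List[str]:
    """Step 1: export tflite_graph.pb with the postprocessing op attached."""
    d = Path(train_dir)
    return [
        "python", "export_tflite_ssd_graph.py",
        "--pipeline_config_path", str(d / config_name),
        "--trained_checkpoint_prefix", str(d / "train" / f"model.ckpt-{ckpt_step}"),
        "--output_directory", str(d / "tflite_inference_graph"),
        "--add_postprocessing_op=true",
    ]


def convert_command(train_dir: str) -> List[str]:
    """Step 2: convert the exported graph to a quantized detect.tflite."""
    d = Path(train_dir)
    return [
        "tflite_convert",
        "--graph_def_file", str(d / "tflite_inference_graph" / "tflite_graph.pb"),
        f"--output_file={d / 'detect.tflite'}",
        "--output_format=TFLITE",
        "--input_shapes=1,300,300,3",
        "--input_arrays=normalized_input_image_tensor",
        f"--output_arrays={OUTPUT_ARRAYS}",
        "--inference_type=QUANTIZED_UINT8",
        "--mean_values=128",
        "--std_dev_values=127",
        "--change_concat_input_ranges=false",
        "--allow_custom_ops",
    ]


def run_pipeline(train_dir: str, config_name: str, ckpt_step: int) -> None:
    """Run both steps in order, stopping if either command fails."""
    for cmd in (export_command(train_dir, config_name, ckpt_step), convert_command(train_dir)):
        print("+", shlex.join(cmd))
        subprocess.run(cmd, check=True)
```

For the run in this answer, that would be run_pipeline("2020-05-17_train_ssdlite_v2", "ssd_mobilenet_v2_coco.config", 1146), started from the directory that contains both export_tflite_ssd_graph.py and the training directory. detect.tflite then ends up in the training directory, as in the original tflite_convert call.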
Note: I wanted a quantized model for my embedded system. That is why I added graph_rewriter to the config file and --inference_type=QUANTIZED_UINT8 to my tflite_convert command.
Source: https://stackoverflow.com/questions/61749548/how-to-convert-tflite-graph-pb-to-detect-tflite-properly