TensorFlow Object-Detection API - How does fine-tuning a model work?

感情败类 2020-12-29 16:50

This is a more general question about the TensorFlow Object-Detection API.

I am using this API; more concretely, I am fine-tuning a model on my dataset. According to

  • 2020-12-29 17:15

    Whether you train from scratch or from a checkpoint, model_main.py is the main program; besides it, all you need is a correct pipeline config file.

    Fine-tuning can therefore be split into two steps: restoring weights and updating weights. Both steps can be configured through the train proto file, which corresponds to train_config in the pipeline config file.

    train_config: {
       batch_size: 24
       optimizer { }
       fine_tune_checkpoint: "PATH_TO_BE_CONFIGURED/model.ckpt"
       fine_tune_checkpoint_type:  "detection"
       # Note: The below line limits the training process to 200K steps, which we
       # empirically found to be sufficient enough to train the pets dataset. This
       # effectively bypasses the learning rate schedule (the learning rate will
       # never decay). Remove the below line to train indefinitely.
       num_steps: 200000
       data_augmentation_options {}
     }
    

    Step 1: restoring weights

    In this step, you choose which variables are restored by setting fine_tune_checkpoint_type; the options are detection and classification. Setting it to detection restores essentially all variables from the checkpoint, while setting it to classification restores only the variables in the feature_extractor scope, i.e., the layers of the backbone network (VGG, ResNet, MobileNet, and so on), which the API calls the feature extractor.
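    The selection above can be sketched in plain Python. This is a simplified illustration, not the API's own code: build_restore_map is a hypothetical helper, the scope name FeatureExtractor and the variable names are illustrative assumptions loosely modeled on an SSD MobileNet graph, and we assume a classification checkpoint stores the backbone variables without the detection model's scope prefix.

```python
def build_restore_map(graph_var_names, checkpoint_type,
                      feature_extractor_scope="FeatureExtractor"):
    """Map checkpoint variable names to graph variable names (simplified sketch)."""
    if checkpoint_type == "detection":
        # Detection checkpoint: names are expected to match the graph one-to-one.
        return {name: name for name in graph_var_names}
    if checkpoint_type == "classification":
        # Classification checkpoint: backbone variables only, assumed to be
        # stored without the detection model's feature-extractor scope.
        prefix = feature_extractor_scope + "/"
        return {name[len(prefix):]: name
                for name in graph_var_names if name.startswith(prefix)}
    raise ValueError(f"unsupported fine_tune_checkpoint_type: {checkpoint_type!r}")


# Illustrative (hypothetical) variable names.
GRAPH_VARS = [
    "FeatureExtractor/MobilenetV1/Conv2d_0/weights",
    "FeatureExtractor/MobilenetV1/Conv2d_1_pointwise/weights",
    "BoxPredictor_0/BoxEncodingPredictor/weights",
    "BoxPredictor_0/ClassPredictor/weights",
]

print(len(build_restore_map(GRAPH_VARS, "detection")))  # 4
print(build_restore_map(GRAPH_VARS, "classification"))  # only the two backbone variables
```

    With detection every graph variable is a candidate; with classification only the backbone is, and its names are looked up without the scope prefix.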

    Previously this was controlled by from_detection_checkpoint and load_all_detection_checkpoint_vars, but these two fields are now deprecated.

    Also note that once fine_tune_checkpoint_type is set, the actual restore step checks whether each variable in the graph exists in the checkpoint (with a matching shape); any variable that does not is initialized with its regular initializer instead.

    For example, suppose you want to fine-tune an ssd_mobilenet_v1_custom_data model and you have downloaded the ssd_mobilenet_v1_coco checkpoint. With fine_tune_checkpoint_type: detection, every variable in the graph that is also in the checkpoint is restored, including the box predictor (last layer) weights, as long as their shapes match (the class-prediction weights depend on num_classes, so they differ unless your dataset has the same number of classes as COCO). With fine_tune_checkpoint_type: classification, only the MobileNet layers are restored. If instead you use a checkpoint from a different model, say faster_rcnn_resnet_xxx, the variables in your graph are not found in the checkpoint, so they are not restored and the log shows warnings such as Variable XXX is not available in checkpoint.
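    The availability check described above can be sketched as follows. split_by_checkpoint is a hypothetical helper, and the variable names and shapes are illustrative assumptions, not values read from a real checkpoint. A graph variable is restored only if the checkpoint has a variable with the same name and shape; everything else keeps its regular initialization. Only the first warning mirrors the log line quoted above; the shape-mismatch message is this sketch's own wording.

```python
import logging

logging.basicConfig(level=logging.WARNING, format="%(levelname)s: %(message)s")


def split_by_checkpoint(graph_shapes, ckpt_shapes):
    """Return (to_restore, to_initialize) lists of graph variable names.

    Both arguments map variable name -> shape tuple. A variable is restored
    only if the checkpoint has the same name with the same shape.
    """
    to_restore, to_initialize = [], []
    for name, shape in sorted(graph_shapes.items()):
        if name not in ckpt_shapes:
            logging.warning("Variable [%s] is not available in checkpoint", name)
            to_initialize.append(name)
        elif tuple(ckpt_shapes[name]) != tuple(shape):
            # This sketch's own message, not the library's log text.
            logging.warning("Variable [%s]: checkpoint shape %s != graph shape %s",
                            name, ckpt_shapes[name], shape)
            to_initialize.append(name)
        else:
            to_restore.append(name)
    return to_restore, to_initialize


# Illustrative shapes: the custom model has fewer classes than the COCO
# checkpoint, so only the class-prediction weights differ in shape.
custom_graph = {
    "FeatureExtractor/MobilenetV1/Conv2d_0/weights": (3, 3, 3, 32),
    "BoxPredictor_0/BoxEncodingPredictor/weights": (1, 1, 512, 12),
    "BoxPredictor_0/ClassPredictor/weights": (1, 1, 512, 9),
}
coco_ckpt = dict(custom_graph)
coco_ckpt["BoxPredictor_0/ClassPredictor/weights"] = (1, 1, 512, 273)

# A checkpoint from a different architecture shares no variable names.
faster_rcnn_ckpt = {"FirstStageFeatureExtractor/resnet_v1_101/conv1/weights": (7, 7, 3, 64)}

restored, initialized = split_by_checkpoint(custom_graph, coco_ckpt)
print(initialized)  # ['BoxPredictor_0/ClassPredictor/weights']

restored, initialized = split_by_checkpoint(custom_graph, faster_rcnn_ckpt)
print(restored)  # []
```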

    Step 2: updating weights

    With the weights restored, you continue training (fine-tuning) on your own dataset; normally this is all you need.

    But if you want to experiment, for example by freezing some layers during training, you can customize training with freeze_variables. Say you want to freeze all MobileNet weights and update only the box predictor: set freeze_variables: ["feature_extractor"], and variables whose names contain feature_extractor won't be updated. For details, see another answer I wrote on this.
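    A rough sketch of how a freeze pattern splits variables into trained and frozen sets. Assumptions: split_trainable is a hypothetical helper, patterns are treated as regular expressions applied with re.search (the API's exact matching rule may differ), and the variable names are illustrative. Matching in this sketch is case-sensitive, so a pattern has to match the scope name as it actually appears in your graph; the variable names printed in your training log are the place to check.

```python
import re


def split_trainable(var_names, freeze_patterns):
    """Return (trainable, frozen) lists; a name is frozen if any pattern
    is found in it via re.search (case-sensitive)."""
    trainable, frozen = [], []
    for name in var_names:
        if any(re.search(pattern, name) for pattern in freeze_patterns):
            frozen.append(name)
        else:
            trainable.append(name)
    return trainable, frozen


# Illustrative (hypothetical) variable names.
names = [
    "FeatureExtractor/MobilenetV1/Conv2d_0/weights",
    "BoxPredictor_0/BoxEncodingPredictor/weights",
    "BoxPredictor_0/ClassPredictor/weights",
]

trainable, frozen = split_trainable(names, ["FeatureExtractor"])
print(frozen)  # ['FeatureExtractor/MobilenetV1/Conv2d_0/weights']
# Only the `trainable` list would be handed to the optimizer.
```

    In this sketch the lowercase pattern feature_extractor would freeze nothing, because the illustrative scope is spelled FeatureExtractor.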

    So to fine-tune a model on your custom dataset, you should prepare a custom config file. You can start with the sample config files and then modify some fields to suit your needs.
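    As one way to modify those fields programmatically, here is a minimal sketch that rewrites every `field: value` line with a given name in a pipeline config's text using a regular expression. set_field is a hypothetical helper and the checkpoint path is a placeholder. The API also ships its own config utilities (e.g. object_detection.utils.config_util.get_configs_from_pipeline_file), which parse the config properly and are the safer choice for larger edits.

```python
import re


def set_field(config_text, field, value):
    """Set every `field: value` line named `field` in pipeline-config text.

    Strings are written quoted (backslashes and quotes escaped), booleans as
    true/false, numbers as-is. Raises KeyError if the field does not appear.
    """
    if isinstance(value, bool):
        rendered = "true" if value else "false"
    elif isinstance(value, str):
        rendered = '"' + value.replace("\\", "\\\\").replace('"', '\\"') + '"'
    else:
        rendered = str(value)
    pattern = re.compile(rf"^([ \t]*{re.escape(field)}[ \t]*:[ \t]*).*$", re.MULTILINE)
    # A function replacement keeps backslashes in `rendered` literal.
    new_text, count = pattern.subn(lambda m: m.group(1) + rendered, config_text)
    if count == 0:
        raise KeyError(field)
    return new_text


SAMPLE = """train_config: {
  batch_size: 24
  fine_tune_checkpoint: "PATH_TO_BE_CONFIGURED/model.ckpt"
  fine_tune_checkpoint_type:  "detection"
  num_steps: 200000
}
"""

cfg = set_field(SAMPLE, "fine_tune_checkpoint", "/tmp/ssd_mobilenet_v1_coco/model.ckpt")
cfg = set_field(cfg, "num_steps", 50000)
print(cfg)
```

    Because it replaces the rest of the matched line, a trailing comment on that line is dropped, and because it touches every line with that name, it suits fields like fine_tune_checkpoint that occur once.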
