Transfer learning with tf.estimator.Estimator framework

Asked by 独厮守ぢ on 2021-02-04 05:04 · 2 answers · 512 views

I'm trying to do transfer learning of an Inception-ResNet v2 model pretrained on ImageNet, using my own dataset and classes. My original codebase was a modification of a

2 Answers
  • 2021-02-04 05:37

    Thanks to @KathyWu's comment, I got on the right track and found the problem.

    Indeed, the way I was computing the scopes included the InceptionResnetV2/ scope, which triggers the loading of all variables "under" that scope (i.e., every variable in the network). Replacing this with the correct dictionary, however, was not trivial.

    Of the scope modes that init_from_checkpoint accepts, the one I needed was 'scope_variable_name': variable, but without using the actual variable.name attribute.

    A variable.name looks like 'some_scope/variable_name:0'. The :0 suffix is not part of the checkpointed variable's name, so using scopes = {v.name: v.name for v in variables_to_restore} raises a "Variable not found" error.

    The trick to make it work was stripping the tensor index from the name:

    tf.train.init_from_checkpoint('inception_resnet_v2_2016_08_30.ckpt', 
                                  {v.name.split(':')[0]: v for v in variables_to_restore})
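
    The name-stripping trick can be sanity-checked without TensorFlow at all. The sketch below uses a stand-in Variable class and made-up variable names (both are assumptions, not from the original code) to show how graph variable names map to checkpoint keys:

    ```python
    class FakeVariable:
        """Stand-in for tf.Variable, whose .name has the form 'scope/name:0'."""
        def __init__(self, name):
            self.name = name

    def checkpoint_assignment_map(variables):
        # Strip the ':0' tensor-index suffix; checkpoint keys do not carry it.
        return {v.name.split(':')[0]: v for v in variables}

    variables = [
        FakeVariable('InceptionResnetV2/Conv2d_1a_3x3/weights:0'),
        FakeVariable('InceptionResnetV2/Conv2d_1a_3x3/BatchNorm/beta:0'),
    ]
    print(sorted(checkpoint_assignment_map(variables)))
    # ['InceptionResnetV2/Conv2d_1a_3x3/BatchNorm/beta',
    #  'InceptionResnetV2/Conv2d_1a_3x3/weights']
    ```

    The resulting dictionary is exactly the assignment_map shape that tf.train.init_from_checkpoint expects: checkpoint key on the left, graph variable on the right.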
    
  • 2021-02-04 05:54

    I found that {s + '/': s + '/' for s in scopes} didn't work, because variables_to_restore included variables such as global_step, so the scopes ended up covering the global scope, which matches everything. You need to print variables_to_restore, find the global_step entry, and add it to the exclude list.
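
    In other words, bookkeeping variables have to be filtered out before the assignment map is built. A minimal sketch (the Var stand-in, the variable names, and the exclude list are assumptions for illustration):

    ```python
    from collections import namedtuple

    Var = namedtuple('Var', ['name'])  # stand-in for tf.Variable

    def filter_and_map(all_variables, exclude=('global_step',)):
        # Drop bookkeeping variables such as global_step before building the
        # assignment map; otherwise their (empty) scope matches everything.
        keep = [v for v in all_variables
                if v.name.split(':')[0] not in exclude]
        return {v.name.split(':')[0]: v for v in keep}

    graph_vars = [Var('global_step:0'),
                  Var('InceptionResnetV2/Conv2d_1a_3x3/weights:0')]
    print(list(filter_and_map(graph_vars)))
    # ['InceptionResnetV2/Conv2d_1a_3x3/weights']
    ```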
