Transfer learning with tf.estimator.Estimator framework

独厮守ぢ 2021-02-04 05:04

I'm trying to do transfer learning of an Inception-ResNet-v2 model pretrained on ImageNet, using my own dataset and classes. My original codebase was a modification of a

2 answers
  •  粉色の甜心
    2021-02-04 05:37

    Thanks to @KathyWu's comment, I got on the right track and found the problem.

    Indeed, the way I was computing the scopes included the InceptionResnetV2/ scope, which triggered loading every variable "under" that scope (i.e., every variable in the network). Replacing it with the correct dictionary, however, was not trivial.
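    The scope problem can be illustrated without TensorFlow. This is a minimal sketch, assuming the layers to be re-trained for the new classes live under `InceptionResnetV2/Logits` and `InceptionResnetV2/AuxLogits` (hypothetical names chosen to match the usual TF-slim layout; check your own graph). Instead of mapping the whole `InceptionResnetV2/` scope, it keeps only the names that fall outside the excluded scopes. With real variables you would apply the same test to each `v.name`. The helper names `is_excluded` and `names_to_restore` are illustrative only.

```python
from typing import Iterable, List, Sequence

# Hypothetical exclusion list: scopes of the classification heads that will be
# re-trained for the new classes, so they must NOT be restored from ImageNet.
EXCLUDE_SCOPES = ("InceptionResnetV2/Logits", "InceptionResnetV2/AuxLogits")


def is_excluded(name: str, exclude_scopes: Sequence[str]) -> bool:
    """True if `name` is an excluded scope itself or lives under one of them."""
    # Appending "/" avoids false matches such as "InceptionResnetV2/LogitsFoo".
    return any(name == s or name.startswith(s + "/") for s in exclude_scopes)


def names_to_restore(all_names: Iterable[str],
                     exclude_scopes: Sequence[str] = EXCLUDE_SCOPES) -> List[str]:
    """Keep the names outside every excluded scope, preserving their order."""
    return [n for n in all_names if not is_excluded(n, exclude_scopes)]


# Illustrative graph variable names (not read from a real model).
graph_names = [
    "InceptionResnetV2/Conv2d_1a_3x3/weights:0",
    "InceptionResnetV2/Logits/Logits/weights:0",
    "InceptionResnetV2/AuxLogits/Logits/weights:0",
]
print(names_to_restore(graph_names))
# → ['InceptionResnetV2/Conv2d_1a_3x3/weights:0']
```

    Only the backbone weights survive the filter; the freshly initialized heads are left alone.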

    Of the assignment-map forms that init_from_checkpoint accepts, the one I needed was 'scope_variable_name': variable, but the key cannot simply be the variable's name attribute.

    A variable's name attribute looks like 'some_scope/variable_name:0'. The :0 suffix (the tensor's output index) is not part of the name stored in the checkpoint, so using scopes = {v.name:v.name for v in variables_to_restore} raises a "Variable not found" error.
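    The mismatch can be reproduced in plain Python. This is a sketch under stated assumptions: `FakeVariable` is a hypothetical stand-in for `tf.Variable` (only `.name` matters here), and `ckpt_names` is an illustrative list. In real code that list could come from `[name for name, _ in tf.train.list_variables(ckpt_path)]`. The helpers `build_assignment_map` and `unmatched_keys` are likewise made up for illustration; checking the keys this way before calling init_from_checkpoint shows which ones would fail to load.

```python
from dataclasses import dataclass
from typing import Dict, Iterable, List


@dataclass(frozen=True)
class FakeVariable:
    """Hypothetical stand-in for tf.Variable: only the `name` attribute is used."""
    name: str


def build_assignment_map(variables: Iterable[FakeVariable]) -> Dict[str, FakeVariable]:
    """Map checkpoint name -> variable, dropping the ':<output index>' suffix."""
    return {v.name.split(":")[0]: v for v in variables}


def unmatched_keys(assignment_map: Dict[str, object],
                   ckpt_names: Iterable[str]) -> List[str]:
    """Return the map keys that do not exist in the checkpoint."""
    available = set(ckpt_names)
    return [k for k in assignment_map if k not in available]


# Names as stored in the checkpoint (illustrative; note there is no ':0').
ckpt_names = ["InceptionResnetV2/Conv2d_1a_3x3/weights",
              "InceptionResnetV2/Conv2d_1a_3x3/BatchNorm/beta"]
# Graph variables report the same names plus the ':0' tensor index.
variables = [FakeVariable(n + ":0") for n in ckpt_names]

naive = {v.name: v for v in variables}
print(unmatched_keys(naive, ckpt_names))  # both keys end in ':0', so neither matches
print(unmatched_keys(build_assignment_map(variables), ckpt_names))  # → []
```

    Splitting on ':' is enough here because TensorFlow uses the colon only to separate the op name from the output index.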

    The trick to make it work was stripping the tensor index from the name:

    # Strip the ':0' output index so each key matches the name stored in the checkpoint.
    tf.train.init_from_checkpoint('inception_resnet_v2_2016_08_30.ckpt',
                                  {v.name.split(':')[0]: v for v in variables_to_restore})
    
