Restore variables that are a subset of a new model in TensorFlow?

鱼传尺愫 2021-02-03 10:38

I am doing an example of boosting (going from a 4-layer DNN to a 5-layer DNN) with TensorFlow. I am implementing it by saving the session and restoring it in TF, because there is a brief paragraph in TF tute

1 answer
  •  逝去的感伤
    2021-02-03 10:56

    It doesn't seem right to read the values for the boosting layer from the checkpoint in this case, and I don't think that's what you want to do. You're getting an error because, when restoring, the saver first collects the list of all variables in your model and then looks for each of them in the checkpoint, which doesn't contain the boosting variables.
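
To see where the error comes from, here is a minimal reproduction. It is a sketch assuming TensorFlow 2.x with its tf.compat.v1 graph/session API (the question uses TF 1.x, where the same calls live directly under tf), and the scope names layer1 and boosting are hypothetical stand-ins for your layers. A Saver built without a variable list covers every global variable, so restoring fails on the one the checkpoint doesn't have.

```python
import os
import tempfile

import tensorflow as tf

tf1 = tf.compat.v1  # TF1-style graph/session API shipped with TF 2.x
ckpt = os.path.join(tempfile.mkdtemp(), "model.ckpt")

# Old model: 'layer1' stands in for the trained layers (hypothetical name)
with tf.Graph().as_default():
    with tf1.variable_scope("layer1"):
        tf1.get_variable("weights", shape=[3, 4])
    with tf1.Session() as sess:
        sess.run(tf1.global_variables_initializer())
        tf1.train.Saver().save(sess, ckpt)

# New model: same layer plus a 'boosting' layer the checkpoint has never seen
restore_failed = False
with tf.Graph().as_default():
    with tf1.variable_scope("layer1"):
        tf1.get_variable("weights", shape=[3, 4])
    with tf1.variable_scope("boosting"):
        tf1.get_variable("weights", shape=[4, 1])
    saver = tf1.train.Saver()  # no var_list: it covers ALL global variables
    with tf1.Session() as sess:
        try:
            saver.restore(sess, ckpt)
        except tf.errors.NotFoundError as err:
            # The checkpoint has no 'boosting/weights' key
            restore_failed = True
            print("restore failed:", type(err).__name__)
```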

    You can restore only part of your model by passing the saver a subset of your model variables. For example, you can do it with the tf.slim library. First, get the list of variables in your model:

    import tensorflow.contrib.slim as slim  # TF 1.x only; tf.contrib was removed in TF 2
    variables = slim.get_variables_to_restore()
    

    Now variables is a list of variable objects, and each one has a name attribute (e.g. 'boosting/weights:0'). You can use it to restore only the layers outside the boosting scope, e.g.:

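    import tensorflow as tf  # TF 1.x API

    # Keep every variable whose top-level scope is not 'boosting'
    variables_to_restore = [v for v in variables if v.name.split('/')[0] != 'boosting']
    model_path = 'your/model/path'

    saver = tf.train.Saver(variables_to_restore)

    with tf.Session() as sess:
        # Initialize everything first (this covers the new 'boosting' variables),
        # then overwrite the old layers with the values from the checkpoint
        sess.run(tf.global_variables_initializer())
        saver.restore(sess, model_path)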
    
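Putting the pieces together, here is an end-to-end sketch under the same assumptions: TF 2.x's tf.compat.v1 instead of TF 1.x, plain tf1.global_variables() instead of slim, hypothetical scopes layer1 to layer4 plus boosting, and a temporary checkpoint path. It saves a 4-layer model, rebuilds it with an extra boosting layer, and restores only the non-boosting variables.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1
ckpt = os.path.join(tempfile.mkdtemp(), "model.ckpt")


def hidden_layers():
    # Hypothetical stand-in for the 4 trained layers
    for i in range(1, 5):
        with tf1.variable_scope("layer%d" % i):
            tf1.get_variable("weights", shape=[4, 4])


# Build and save the 4-layer model (training itself is omitted)
with tf.Graph().as_default():
    hidden_layers()
    with tf1.Session() as sess:
        sess.run(tf1.global_variables_initializer())
        saved = {v.name: sess.run(v) for v in tf1.global_variables()}
        tf1.train.Saver().save(sess, ckpt)

# Build the 5-layer model and restore everything except 'boosting'
with tf.Graph().as_default():
    hidden_layers()
    with tf1.variable_scope("boosting"):
        tf1.get_variable("weights", shape=[4, 1])
    variables = tf1.global_variables()
    variables_to_restore = [v for v in variables if v.name.split('/')[0] != 'boosting']
    restored_names = {v.name for v in variables_to_restore}
    boosting_names = [v.name for v in variables if v.name not in restored_names]

    saver = tf1.train.Saver(variables_to_restore)
    with tf1.Session() as sess:
        sess.run(tf1.global_variables_initializer())  # fresh values for 'boosting'
        saver.restore(sess, ckpt)  # overwrite the 4 old layers from disk
        restored = {v.name: sess.run(v) for v in variables_to_restore}

print(boosting_names)
```

Running the global initializer before the restore is the simplest way to make sure the new variables get values; saver.restore then overwrites only the variables in its own list.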

    This way your 4 existing layers are restored from the checkpoint. In theory you could also load the value of one of your checkpoint variables into the boosting layer by creating another saver whose variable list contains only the boosting variables and mapping (renaming) the chosen checkpoint variable onto them, but I really don't think that's what you need here.
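
If you did want that renaming trick, tf.train.Saver also accepts a dict that maps checkpoint names to variables in the current graph. A sketch, again assuming TF 2.x's tf.compat.v1 and a hypothetical layer4/weights in the old checkpoint with the same shape as the boosting weights:

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1
ckpt = os.path.join(tempfile.mkdtemp(), "model.ckpt")

# Old model: 'layer4/weights' is the hypothetical source variable
with tf.Graph().as_default():
    with tf1.variable_scope("layer4"):
        old = tf1.get_variable("weights", shape=[4, 4])
    with tf1.Session() as sess:
        sess.run(tf1.global_variables_initializer())
        old_value = sess.run(old)
        tf1.train.Saver().save(sess, ckpt)

# New model: load checkpoint key 'layer4/weights' into 'boosting/weights'
with tf.Graph().as_default():
    with tf1.variable_scope("boosting"):
        boost = tf1.get_variable("weights", shape=[4, 4])  # shape must match
    # Dict var_list: checkpoint name (no ':0' suffix) -> variable in this graph
    rename_saver = tf1.train.Saver({"layer4/weights": boost})
    with tf1.Session() as sess:
        rename_saver.restore(sess, ckpt)
        boost_value = sess.run(boost)
```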

    Since this is a new layer in your model and its variables don't exist anywhere in the checkpoint, just initialize them as part of your workflow instead of trying to import them. For example, you can pass this argument when calling the fully_connected function:

    weights_initializer=slim.variance_scaling_initializer()
    

    You'll need to check the details yourself, though, since I'm not sure what your imports are or which function you are using here.
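
A sketch of initializing only the new layer, assuming tf.compat.v1 and using tf1.variance_scaling_initializer as a stand-in for slim.variance_scaling_initializer (their default scale and truncation differ, so this is not an exact equivalent). report_uninitialized_variables shows that the boosting weights have no value until their own initializer runs:

```python
import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1

with tf.Graph().as_default():
    with tf1.variable_scope("boosting"):
        # He-style scaling (scale=2.0, fan_in); assumed stand-in for the slim initializer
        boost_w = tf1.get_variable(
            "weights", shape=[64, 1],
            initializer=tf1.variance_scaling_initializer(scale=2.0, mode="fan_in"))
    with tf1.Session() as sess:
        before = sess.run(tf1.report_uninitialized_variables())
        sess.run(tf1.variables_initializer([boost_w]))  # initialize only the new layer
        after = sess.run(tf1.report_uninitialized_variables())
        w = sess.run(boost_w)
```

In the full model you would run this targeted initializer alongside the partial restore, instead of the global initializer, if you prefer not to touch the restored variables at all.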

    In general I'd advise you to take a look at the slim library, which makes it easier to define a model and the scopes of its layers (instead of wrapping each layer in a with block, you can pass a scope argument when calling a function). With slim it would look something like this:

    boost = slim.fully_connected(input, number_of_outputs, activation_fn=None, scope='boosting', weights_initializer=slim.variance_scaling_initializer())
    
