How to restore my loss from a saved meta graph?

Anonymous (unverified), submitted on 2019-12-03 07:50:05

Question:

I have built a simple tensorflow model that is working fine. While training I save the meta_graph and also some parameters at different steps.

After that (in a new script) I want to restore the saved meta_graph and restore variables and operations.

Everything works fine, except that the loss defined by

with tf.name_scope('MSE'):
    error = tf.losses.mean_squared_error(Y, yhat, scope="error")

is not restored. Restoring it with the following line

mse_error = graph.get_tensor_by_name("MSE/error:0")

raises this error message:

"The name 'MSE/error:0' refers to a Tensor which does not exist. The operation, 'MSE/error', does not exist in the graph."

As I follow exactly the same procedure for other variables and ops, which are restored without any error, I don't know how to deal with this. The only difference is that tf.losses.mean_squared_error takes a scope argument and not a name argument.

So how do I restore the loss operation with the scope?

Here is the code I use to save and load the model.

Saving:

# define network ...
saver = tf.train.Saver(max_to_keep=10)

sess = tf.Session()
sess.run(tf.global_variables_initializer())

for i in range(NUM_EPOCHS):
    # do training ...; save the model every 1000 optimization steps
    if (i + 1) % 1000 == 0:
        saver.save(sess, "L:/model/mlp_model", global_step=(i + 1))

Restore:

# start a session
sess = tf.Session()
# load meta graph
saver = tf.train.import_meta_graph('L:\\model\\mlp_model-1000.meta')
# restore weights
saver.restore(sess, tf.train.latest_checkpoint('L:\\model\\'))

# access network nodes
graph = tf.get_default_graph()
X = graph.get_tensor_by_name("Input/X:0")
Y = graph.get_tensor_by_name("Input/Y:0")

# restore output-generating operation used for prediction
yhat_op = graph.get_tensor_by_name("OutputLayer/yhat:0")
mse_error = graph.get_tensor_by_name("MSE/error:0")  # this one doesn't work

Answer 1:

To get your training op back, the documentation suggests you add it to a collection before saving, so that you have something to point at after restoring your graph.

Saving:

saver = tf.train.Saver(max_to_keep=10)
# put op in collection
tf.add_to_collection('train_op', train_op)
...

Restore:

saver = tf.train.import_meta_graph('L:\\model\\mlp_model-1000.meta')
saver.restore(sess, tf.train.latest_checkpoint('L:\\model\\'))
# recover op through collection
train_op = tf.get_collection('train_op')[0]

Why did your attempt at recovering the tensor by name fail?

You can indeed get the tensor by its name -- the catch is that you need the correct name. Notice that the 'error' you pass to tf.losses.mean_squared_error is a scope name, not the name of the returned operation. This can be confusing, as other operations, such as tf.nn.l2_loss, accept a name argument instead.
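The difference is easy to see in graph mode. Below is a minimal sketch using the tf.compat.v1 API; the placeholders Y and yhat are stand-ins for the question's own tensors:

```python
import tensorflow.compat.v1 as tf  # graph-mode API, as in the question

tf.disable_eager_execution()

# stand-in placeholders for the question's Y and yhat
Y = tf.placeholder(tf.float32, shape=[None], name='Y')
yhat = tf.placeholder(tf.float32, shape=[None], name='yhat')

with tf.name_scope('MSE'):
    # `scope` only opens a name scope; the returned tensor gets
    # its own (undocumented) name somewhere inside that scope
    error = tf.losses.mean_squared_error(Y, yhat, scope='error')
    # `name` names the op itself, so this tensor is exactly 'MSE/l2:0'
    l2 = tf.nn.l2_loss(yhat, name='l2')

print(error.name)  # something under 'MSE/error/...' -- but not 'MSE/error:0'
print(l2.name)     # 'MSE/l2:0'
```

So the scope argument guarantees only the prefix of the loss tensor's name, while a name argument pins down the name exactly.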

In the end, the name of your error operation is MSE/error/value:0, which you can use to get it by name.

That is, until it breaks again in the future when you update TensorFlow. tf.losses.mean_squared_error gives you no guarantee about the name of its output, so it may very well change at some point.
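If you are unsure what your installed version actually named the op, you can enumerate the operations in the graph and filter by the scope prefix. A sketch, reusing the question's 'MSE' scope on a tiny stand-in graph:

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

# tiny graph standing in for the restored one
Y = tf.placeholder(tf.float32, shape=[None])
yhat = tf.placeholder(tf.float32, shape=[None])
with tf.name_scope('MSE'):
    error = tf.losses.mean_squared_error(Y, yhat, scope='error')

# list every op created under the MSE scope to find the loss op's real name
mse_ops = [op.name for op in tf.get_default_graph().get_operations()
           if op.name.startswith('MSE/')]
print(mse_ops)
print(error.name)  # the authoritative answer: ask the returned tensor itself
```

The same loop works on a graph restored with tf.train.import_meta_graph, where you cannot ask the original Python tensor for its name.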

I think this is what motivates the use of collections: the lack of guarantee on the names of the operators you don't control yourself.

Alternatively, if for some reason you really want to use names, you could rename your operator like this:

with tf.name_scope('MSE'):
    error = tf.losses.mean_squared_error(Y, yhat, scope='error')
    # let me stick my own name on it
    error = tf.identity(error, 'my_error')

Then you can rely on graph.get_tensor_by_name('MSE/my_error:0') safely.


