Saving and Restoring a Trained LSTM in TensorFlow

三世轮回 submitted on 2019-12-12 10:40:02

Question


I trained an LSTM classifier using a BasicLSTMCell. How can I save my model and restore it for use in later classifications?


Answer 1:


I was wondering this myself. As others have pointed out, the usual way to save a model in TensorFlow is to use tf.train.Saver(); however, I believe this saves only the values of tf.Variables. I'm not sure whether the BasicLSTMCell implementation contains tf.Variables that are saved automatically when you do this, or whether another step is needed, but if all else fails, the BasicLSTMCell can be saved to and loaded from a pickle file.




Answer 2:


We ran into the same issue and weren't sure whether the internal variables were saved. We found that you must create the Saver after the BasicLSTMCell is created/defined; otherwise its variables are not saved.
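A minimal sketch of the effect described above, assuming TensorFlow 1.15 or 2.x used through tf.compat.v1 in graph mode. The variable names v1 and v2 are hypothetical, with v2 standing in for weights that are created later (for example by a cell). Without an explicit var_list, a Saver collects the global variables that exist when it is constructed, so one built too early silently misses the later ones:

```python
import os
import tempfile

import tensorflow as tf

tf1 = tf.compat.v1
tf1.disable_eager_execution()  # TF 1.x-style graph mode


def saved_names(saver_before_v2):
    """Returns the variable names that end up in the checkpoint."""
    graph = tf1.Graph()
    with graph.as_default():
        tf1.Variable(tf1.zeros([2]), name="v1")
        # A Saver built here only knows about v1.
        early_saver = tf1.train.Saver() if saver_before_v2 else None
        tf1.Variable(tf1.ones([3]), name="v2")  # created later, e.g. by a cell
        # A Saver built here knows about both v1 and v2.
        saver = early_saver if early_saver is not None else tf1.train.Saver()
        init_op = tf1.global_variables_initializer()
        with tf1.Session() as sess:
            sess.run(init_op)
            path = saver.save(sess, os.path.join(tempfile.mkdtemp(), "model.ckpt"))
    # list_variables returns (name, shape) pairs stored in the checkpoint.
    return sorted(name for name, _ in tf.train.list_variables(path))


print(saved_names(saver_before_v2=True))
print(saved_names(saver_before_v2=False))
```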




Answer 3:


The easiest way to save and restore a model is to use a tf.train.Saver object. Its constructor adds save and restore ops to the graph for all of the variables in the graph, or for a specified list of them. The Saver object provides methods to run these ops, specifying paths for the checkpoint files to write to or read from.

Refer to:

https://www.tensorflow.org/versions/r0.11/how_tos/variables/index.html

Checkpoint Files

Variables are saved in binary files that, roughly, contain a map from variable names to tensor values.

When you create a Saver object, you can optionally choose names for the variables in the checkpoint files. By default, it uses the value of the Variable.name property for each variable.

To understand what variables are in a checkpoint, you can use the inspect_checkpoint library, and in particular, the print_tensors_in_checkpoint_file function.
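The naming and inspection ideas above can be sketched as follows, assuming TF 1.15 or 2.x through tf.compat.v1. Passing a dict to the Saver chooses the names stored in the file (the key "renamed_v1" is hypothetical). Instead of print_tensors_in_checkpoint_file, this sketch uses the public tf.train.list_variables and tf.train.load_checkpoint helpers, which to my knowledge are available in later TF 1.x releases and in 2.x:

```python
import os
import tempfile

import tensorflow as tf

tf1 = tf.compat.v1
tf1.disable_eager_execution()  # TF 1.x-style graph mode

graph = tf1.Graph()
with graph.as_default():
    v1 = tf1.Variable([1.0, 2.0], name="v1")
    v2 = tf1.Variable([[3.0]], name="v2")
    # Dict keys become the names stored in the checkpoint file;
    # v2 keeps its default name (Variable.name without the ":0").
    saver = tf1.train.Saver({"renamed_v1": v1, "v2": v2})
    init_op = tf1.global_variables_initializer()
    with tf1.Session() as sess:
        sess.run(init_op)
        ckpt = saver.save(sess, os.path.join(tempfile.mkdtemp(), "model.ckpt"))

# List the (name, shape) pairs, then read one stored value back.
entries = tf.train.list_variables(ckpt)
for name, shape in entries:
    print(name, shape)
reader = tf.train.load_checkpoint(ckpt)
print(reader.get_tensor("renamed_v1"))
```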

Saving Variables

Create a Saver with tf.train.Saver() to manage all variables in the model.

import tensorflow as tf

# Create some variables.
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")
...
# Add an op to initialize the variables.
# (tf.initialize_all_variables() in TF 0.11 and earlier.)
init_op = tf.global_variables_initializer()

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, initialize the variables, do some work, save the
# variables to disk.
with tf.Session() as sess:
  sess.run(init_op)
  # Do some work with the model.
  ...
  # Save the variables to disk.
  save_path = saver.save(sess, "/tmp/model.ckpt")
  print("Model saved in file: %s" % save_path)
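To make the save step concrete, here is a runnable variant, assuming TF 1.15 or 2.x via tf.compat.v1. The initial values and the assign op standing in for training are invented for illustration, and a temporary directory replaces /tmp. With the default (V2) checkpoint format, one save writes several files sharing the model.ckpt prefix, plus a small "checkpoint" bookkeeping file:

```python
import os
import tempfile

import tensorflow as tf

tf1 = tf.compat.v1
tf1.disable_eager_execution()  # TF 1.x-style graph mode

ckpt_dir = tempfile.mkdtemp()

graph = tf1.Graph()
with graph.as_default():
    v1 = tf1.Variable(tf1.zeros([3]), name="v1")
    v2 = tf1.Variable(tf1.ones([2, 2]), name="v2")
    train_step = v1.assign([1.0, 2.0, 3.0])  # stand-in for real training
    init_op = tf1.global_variables_initializer()
    saver = tf1.train.Saver()  # created after v1/v2, so it covers both

    with tf1.Session() as sess:
        sess.run(init_op)
        sess.run(train_step)
        save_path = saver.save(sess, os.path.join(ckpt_dir, "model.ckpt"))
        print("Model saved in file: %s" % save_path)

# Index, data shard(s), meta graph and the "checkpoint" state file.
files = sorted(os.listdir(ckpt_dir))
print(files)
```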

Restoring Variables

The same Saver object is used to restore variables. Note that when you restore variables from a file you do not have to initialize them beforehand.

import tensorflow as tf

# Create some variables.
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")
...
# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, use the saver to restore variables from disk, and
# do some work with the model.
with tf.Session() as sess:
  # Restore variables from disk.
  saver.restore(sess, "/tmp/model.ckpt")
  print("Model restored.")
  # Do some work with the model.
  ...
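A sketch of the full save-then-restore round trip under the same assumptions (TF 1.15 or 2.x via tf.compat.v1; the variable name and values are hypothetical). The second graph is rebuilt with the same variable name and restored without running any initializer, matching the note above that restored variables need no prior initialization:

```python
import os
import tempfile

import tensorflow as tf

tf1 = tf.compat.v1
tf1.disable_eager_execution()  # TF 1.x-style graph mode

ckpt_path = os.path.join(tempfile.mkdtemp(), "model.ckpt")


def build_graph():
    """Builds the same variable (same name) in a fresh graph."""
    graph = tf1.Graph()
    with graph.as_default():
        v1 = tf1.Variable(tf1.zeros([3]), name="v1")
        train_step = v1.assign_add([1.0, 2.0, 3.0])  # stand-in for training
        init_op = tf1.global_variables_initializer()
        saver = tf1.train.Saver()
    return graph, v1, train_step, init_op, saver


# "Training" process: initialize, update, save.
graph, v1, train_step, init_op, saver = build_graph()
with tf1.Session(graph=graph) as sess:
    sess.run(init_op)
    sess.run(train_step)
    saver.save(sess, ckpt_path)

# "Inference" process: fresh graph, no initializer; restore fills v1.
graph, v1, _, _, saver = build_graph()
with tf1.Session(graph=graph) as sess:
    saver.restore(sess, ckpt_path)
    restored = sess.run(v1)
print(restored)  # [1. 2. 3.]
```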



Answer 4:


Yes, there are weight and bias variables inside the LSTM cell (any cell with trainable parameters keeps them in variables somewhere). As already noted in other answers, using a Saver object appears to be the way to go: it saves your variables and your (meta)graph in a reasonably convenient way. You'll need the metagraph if you want to get the whole model back, not just some tf.Variables sitting there in isolation. The Saver does need to know every variable it has to save, so create it after building the graph.
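The metagraph route can be sketched like this, assuming TF 1.15 or 2.x via tf.compat.v1; the tensor names x, w and y are hypothetical. Saver.save writes model.ckpt.meta by default, and tf.train.import_meta_graph rebuilds the graph from that file, so the inference side does not need the model-building code, only the tensor names:

```python
import os
import tempfile

import tensorflow as tf

tf1 = tf.compat.v1
tf1.disable_eager_execution()  # TF 1.x-style graph mode

ckpt_path = os.path.join(tempfile.mkdtemp(), "model.ckpt")

# Training script: build the model and save it (this also writes the .meta file).
train_graph = tf1.Graph()
with train_graph.as_default():
    x = tf1.placeholder(tf.float32, shape=[None, 2], name="x")
    w = tf1.Variable([[2.0], [3.0]], name="w")
    y = tf1.matmul(x, w, name="y")
    init_op = tf1.global_variables_initializer()
    saver = tf1.train.Saver()
    with tf1.Session() as sess:
        sess.run(init_op)
        saver.save(sess, ckpt_path)

# Inference script: no model-building code, only the saved files.
infer_graph = tf1.Graph()
with infer_graph.as_default():
    saver = tf1.train.import_meta_graph(ckpt_path + ".meta")
    with tf1.Session() as sess:
        saver.restore(sess, ckpt_path)
        x = infer_graph.get_tensor_by_name("x:0")
        y = infer_graph.get_tensor_by_name("y:0")
        result = sess.run(y, feed_dict={x: [[1.0, 1.0]]})
print(result)  # [[5.]]
```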

A useful trick for any "are there variables?", "is it properly reusing weights?", or "how can I actually look at the weights in my LSTM, which aren't bound to any Python variable?" situation is this snippet:

for i in tf.global_variables():
    print(i)

for variables, and

for i in my_graph.get_operations():
    print(i)

for ops. If you want to view a tensor that isn't bound to a Python variable, use

my_graph.get_tensor_by_name('name_of_op:N')

where name_of_op is the name of the operation that produces the tensor, and N is the index of the output you want (an op can have several output tensors).

TensorBoard's graph view can help you find op names, which is useful because most graphs end up with a very large number of operations.
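Here is a small sketch of these lookups, assuming TF 1.15 or 2.x via tf.compat.v1. The scope my_cell, the variable kernel and the op logits are hypothetical names; the return value of build_model is deliberately discarded, mimicking weights created inside a cell without a Python handle being kept:

```python
import tensorflow as tf

tf1 = tf.compat.v1
tf1.disable_eager_execution()  # TF 1.x-style graph mode


def build_model(x):
    # The weights are created internally, much like a cell's kernel.
    with tf1.variable_scope("my_cell"):
        w = tf1.get_variable("kernel", shape=[2, 3],
                             initializer=tf1.ones_initializer())
    return tf1.matmul(x, w, name="logits")


my_graph = tf1.Graph()
with my_graph.as_default():
    x = tf1.placeholder(tf.float32, shape=[None, 2], name="x")
    build_model(x)  # return value deliberately discarded

    # List every global variable and every op in the graph.
    for v in tf1.global_variables():
        print(v.name, v.shape)
    for op in my_graph.get_operations():
        print(op.name, op.type)

    # Recover handles purely by name.
    kernel = [v for v in tf1.global_variables() if v.name == "my_cell/kernel:0"][0]
    logits = my_graph.get_tensor_by_name("logits:0")
    init_op = tf1.global_variables_initializer()

    with tf1.Session() as sess:
        sess.run(init_op)
        kernel_value = sess.run(kernel)
        logits_value = sess.run(logits, feed_dict={x: [[1.0, 2.0]]})

print(logits_value)  # [[3. 3. 3.]]
```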




Answer 5:


I've written example code for saving and restoring an LSTM; solving this took me a long time as well. See: https://github.com/MareArts/rnn_save_restore_test I hope this code helps.




Answer 6:


During training, you can instantiate a tf.train.Saver object and call save, passing the current session and the output checkpoint file (*.ckpt) path. You can call save whenever it seems appropriate (e.g. every few epochs, or when the validation error drops):

import tensorflow as tf

# Create some variables.
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")
...
# Add an op to initialize the variables.
# (tf.initialize_all_variables() in TF 0.11 and earlier.)
init_op = tf.global_variables_initializer()

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, initialize the variables, do some work, save the
# variables to disk.
with tf.Session() as sess:
  sess.run(init_op)
  # Do some work with the model.
  ...
  # Save the variables to disk.
  save_path = saver.save(sess, "/tmp/model.ckpt")
  print("Model saved in file: %s" % save_path)
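A sketch of periodic checkpointing, assuming TF 1.15 or 2.x via tf.compat.v1. The assign_add op stands in for one epoch of training, and the every-other-epoch rule could be replaced by a check on your own validation loss (not shown). Here global_step appends the epoch to the file prefix, max_to_keep limits how many checkpoints stay on disk, and tf.train.latest_checkpoint finds the newest one for inference:

```python
import os
import tempfile

import tensorflow as tf

tf1 = tf.compat.v1
tf1.disable_eager_execution()  # TF 1.x-style graph mode

ckpt_dir = tempfile.mkdtemp()
ckpt_prefix = os.path.join(ckpt_dir, "model.ckpt")

graph = tf1.Graph()
with graph.as_default():
    w = tf1.Variable(0.0, name="w")
    train_step = w.assign_add(1.0)  # stand-in for one epoch of training
    init_op = tf1.global_variables_initializer()
    saver = tf1.train.Saver(max_to_keep=2)  # keep the 2 newest checkpoints

    with tf1.Session() as sess:
        sess.run(init_op)
        for epoch in range(5):
            sess.run(train_step)
            if epoch % 2 == 0:  # save on epochs 0, 2 and 4
                saver.save(sess, ckpt_prefix, global_step=epoch)

latest = tf.train.latest_checkpoint(ckpt_dir)
state = tf.train.get_checkpoint_state(ckpt_dir)
kept = [os.path.basename(p) for p in state.all_model_checkpoint_paths]
print(latest)
print(kept)  # ['model.ckpt-2', 'model.ckpt-4']

# Inference: rebuild the variable and restore the newest checkpoint.
infer_graph = tf1.Graph()
with infer_graph.as_default():
    w = tf1.Variable(0.0, name="w")
    saver = tf1.train.Saver()
    with tf1.Session() as sess:
        saver.restore(sess, latest)
        w_value = sess.run(w)
print(w_value)  # 5.0
```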

During classification/inference, you instantiate another tf.train.Saver and call restore, passing the current session and the checkpoint file to restore. Call restore before you use the model for classification via session.run:

import tensorflow as tf

# Create some variables.
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")
...
# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, use the saver to restore variables from disk, and
# do some work with the model.
with tf.Session() as sess:
  # Restore variables from disk.
  saver.restore(sess, "/tmp/model.ckpt")
  print("Model restored.")
  # Do some work with the model.
  ...

Reference: https://www.tensorflow.org/versions/r0.11/how_tos/variables/index.html#saving-and-restoring



Source: https://stackoverflow.com/questions/40442098/saving-and-restoring-a-trained-lstm-in-tensor-flow
