I am trying to save the latest LSTM State from training to be reused during the prediction stage later. The problem I am encountering is that in the TF LSTM model the State …
The issue is that creating a new tf.Variable after the Saver was constructed means that the Saver has no knowledge of the new variable. The new variable still gets saved in the metagraph, but its value is not saved in the checkpoint:
import tensorflow as tf

with tf.Graph().as_default():
    var_a = tf.get_variable("a", shape=[])
    saver = tf.train.Saver()
    var_b = tf.get_variable("b", shape=[])
    print(saver._var_list)  # [<tf.Variable 'a:0' shape=() dtype=float32_ref>]
    initializer = tf.global_variables_initializer()
    with tf.Session() as session:
        session.run([initializer])
        saver.save(session, "/tmp/model", global_step=0)

with tf.Graph().as_default():
    new_saver = tf.train.import_meta_graph("/tmp/model-0.meta")
    print(saver._var_list)  # [<tf.Variable 'a:0' shape=() dtype=float32_ref>]
    with tf.Session() as session:
        new_saver.restore(session, "/tmp/model-0")  # Only var_a gets restored!
I've annotated the quick reproduction of your issue above with the variables that the Saver knows about.
Now, the solution is relatively easy: create the Variable before the Saver, then use tf.assign to update its value (make sure you actually run the op returned by tf.assign, otherwise the assignment never takes effect). The assigned value will then be saved in checkpoints and restored just like other variables.
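Applied to the LSTM-state use case, a minimal sketch of that approach could look like the following; the variable name saved_lstm_state, the sizes, and the new_state placeholder are illustrative stand-ins rather than anything from your code:

import numpy as np
import tensorflow as tf

with tf.Graph().as_default():
    batch_size, state_size = 4, 8  # illustrative sizes

    # Create the state variable *before* constructing the Saver,
    # so the Saver includes it in its var_list.
    saved_state = tf.get_variable(
        "saved_lstm_state", shape=[batch_size, state_size],
        initializer=tf.zeros_initializer(), trainable=False)
    saver = tf.train.Saver()

    # Placeholder standing in for the numpy state produced by training.
    new_state = tf.placeholder(tf.float32, [batch_size, state_size])
    update_state = tf.assign(saved_state, new_state)

    with tf.Session() as session:
        session.run(tf.global_variables_initializer())
        # Stand-in for the state array coming out of your training loop.
        final_state = np.zeros([batch_size, state_size], np.float32)
        # Running the op returned by tf.assign is what actually stores the value.
        session.run(update_state, feed_dict={new_state: final_state})
        saver.save(session, "/tmp/model", global_step=0)

At prediction time, restoring the checkpoint brings saved_lstm_state back with the stored value, and running the variable in a session gives you the state to seed the LSTM with.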
This could be handled better by the Saver as a special case when None is passed to its var_list constructor argument (i.e. it could pick up new variables automatically). Feel free to open a feature request on GitHub for this.
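Until that happens, the Saver only knows about the variables that exist at the moment it is constructed. As a quick sketch of that behaviour (not part of the original repro), building the Saver after both variables makes both end up in the checkpoint:

import tensorflow as tf

with tf.Graph().as_default():
    var_a = tf.get_variable("a", shape=[])
    var_b = tf.get_variable("b", shape=[])
    # Saver constructed after both variables, so its var_list covers both.
    saver = tf.train.Saver()
    print(saver._var_list)  # now lists both 'a:0' and 'b:0'
    with tf.Session() as session:
        session.run(tf.global_variables_initializer())
        saver.save(session, "/tmp/model_both", global_step=0)  # both get checkpointed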