TF LSTM: Save State from training session for prediction session later

Asked by 梦如初夏 on 2021-01-16 16:09

I am trying to save the latest LSTM state from training so it can be reused during the prediction stage later. The problem I am encountering is that in the TF LSTM model the state variable I create to hold it does not end up in the checkpoint, so it cannot be restored in the prediction session.

1 Answer
  • 2021-01-16 16:31

    The issue is that creating a new tf.Variable after the Saver was constructed means the Saver has no knowledge of the new variable. The variable still gets saved in the metagraph, but not in the checkpoint:

    import tensorflow as tf
    with tf.Graph().as_default():
      var_a = tf.get_variable("a", shape=[])
      saver = tf.train.Saver()
      # var_b is created after the Saver was constructed, so it is not tracked:
      var_b = tf.get_variable("b", shape=[])
      print(saver._var_list)  # [<tf.Variable 'a:0' shape=() dtype=float32_ref>]
      initializer = tf.global_variables_initializer()
      with tf.Session() as session:
        session.run([initializer])
        saver.save(session, "/tmp/model", global_step=0)
    with tf.Graph().as_default():
      new_saver = tf.train.import_meta_graph("/tmp/model-0.meta")
      print(saver._var_list)  # Still only [<tf.Variable 'a:0' shape=() dtype=float32_ref>]
      with tf.Session() as session:
        new_saver.restore(session, "/tmp/model-0")  # Only var_a gets restored!
    

    I've annotated the quick reproduction of your issue above with comments showing which variables the Saver knows about.
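
    One way to confirm that "b" never made it into the checkpoint (my own addition, using the standard tf.train.list_variables checkpoint utility) is to list what the checkpoint file actually contains:

    import tensorflow as tf
    # Only "a" appears in the checkpoint; "b" exists only in the metagraph.
    print(tf.train.list_variables("/tmp/model-0"))  # [('a', [])]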

    Now, the solution is relatively easy: create the Variable before the Saver, then use tf.assign to update its value (and make sure you actually run the op returned by tf.assign). The assigned value will then be saved in checkpoints and restored just like any other variable.
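
    As a minimal sketch of that pattern (the variable name lstm_state, the [2, 128] shape, and the dummy final state below are my own illustrative choices, not taken from your model):

    import numpy as np
    import tensorflow as tf

    with tf.Graph().as_default():
      # Create the state variable BEFORE the Saver so the Saver tracks it.
      # [2, 128] stands in for a stacked (c, h) LSTM state.
      state_var = tf.get_variable("lstm_state", shape=[2, 128], trainable=False,
                                  initializer=tf.zeros_initializer())
      saver = tf.train.Saver()

      # An op that writes a freshly computed state into the variable.
      new_state = tf.placeholder(tf.float32, shape=[2, 128])
      update_state = tf.assign(state_var, new_state)

      with tf.Session() as session:
        session.run(tf.global_variables_initializer())
        # ... training would produce the real final state; a dummy is used here ...
        final_state = np.ones([2, 128], dtype=np.float32)
        # Running the op returned by tf.assign is what actually stores the value.
        session.run(update_state, feed_dict={new_state: final_state})
        saver.save(session, "/tmp/model", global_step=0)

    # Later, in the prediction session, restoring the checkpoint brings the
    # saved state back along with every other variable.
    with tf.Graph().as_default():
      pred_saver = tf.train.import_meta_graph("/tmp/model-0.meta")
      with tf.Session() as session:
        pred_saver.restore(session, "/tmp/model-0")
        state = session.run(tf.get_default_graph().get_tensor_by_name("lstm_state:0"))
        print(state[0, :3])  # [1. 1. 1.]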

    This could be handled better by the Saver as a special case when None is passed to its var_list constructor argument (i.e. it could pick up new variables automatically). Feel free to open a feature request on GitHub for this.
