TensorFlow: how to save/restore a model?

迷失自我 2020-11-21 11:37

After you train a model in TensorFlow:

  1. How do you save the trained model?
  2. How do you later restore this saved model?
26 answers
  •  一整个雨季
    2020-11-21 12:03

    Wherever you want to save the model:

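    import tensorflow.compat.v1 as tf  # TF1-style API; also available in TF 2.x
    tf.disable_v2_behavior()

    # ... build the model here, giving every tf.Variable a name (e.g. name='W1')
    saver = tf.train.Saver()  # create it after the variables it should save
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        ...  # train the model
        saver.save(sess, filename)  # writes filename.meta, .index, .data-* and 'checkpoint'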
    
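
    To see what saver.save() actually writes, here is a minimal self-contained sketch. It assumes a TensorFlow version that ships tf.compat.v1 (1.14+ or 2.x); the temporary directory, the file name model.ckpt and the variable W1 are hypothetical placeholders. With the default checkpoint format you would typically find a small checkpoint bookkeeping file plus model.ckpt.meta (the graph), model.ckpt.index and one or more model.ckpt.data-* shards (the variable values).

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # graph mode + tf.Session, as in the answer

# Hypothetical checkpoint location and variable, for illustration only.
ckpt_dir = tempfile.mkdtemp()
filename = os.path.join(ckpt_dir, "model.ckpt")
W1 = tf.Variable([[1.0, 2.0], [3.0, 4.0]], name="W1")

saver = tf.train.Saver()  # created after the variables it should save
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    save_path = saver.save(sess, filename)  # returns the checkpoint prefix

# List the files saver.save() wrote into the directory.
written = sorted(os.listdir(ckpt_dir))
print(save_path)
print(written)
```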

    Make sure all your tf.Variables have names, because you may want to restore them later by name. Then, wherever you want to predict:

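    import tensorflow.compat.v1 as tf
    tf.disable_v2_behavior()

    # filename: the same prefix that was passed to saver.save()
    saver = tf.train.import_meta_graph(filename + '.meta')  # rebuild the saved graph
    with tf.Session() as sess:
        saver.restore(sess, filename)  # load the saved variable values
        print(sess.run('W1:0'))  # example: retrieve a variable by its name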
    
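
    The save and restore snippets can be combined into one runnable round trip, sketched here under the same assumptions (tf.compat.v1, a hypothetical temporary path and a variable named W1). tf.reset_default_graph() stands in for starting a fresh prediction script, so import_meta_graph has to rebuild the graph from the .meta file before restore loads the values; no initializer is run on the restored side.

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

filename = os.path.join(tempfile.mkdtemp(), "model.ckpt")  # hypothetical path

# 1) Training side: build, initialize and save a named variable.
W1 = tf.Variable([[1.0, 2.0], [3.0, 4.0]], name="W1")
saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saved_value = sess.run(W1)
    saver.save(sess, filename)

# 2) Prediction side: start from an empty graph, as a new process would.
tf.reset_default_graph()
saver = tf.train.import_meta_graph(filename + ".meta")  # rebuilds the graph
with tf.Session() as sess:
    saver.restore(sess, filename)  # loads the saved values into the new graph
    restored_value = sess.run("W1:0")  # fetch the variable by tensor name

print(restored_value)
```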

    Make sure the saver runs inside the corresponding session. Also note that if you restore from tf.train.latest_checkpoint('./'), only the most recent checkpoint recorded in that directory will be used.
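
    To illustrate the latest_checkpoint remark, this sketch saves three checkpoints using the global_step argument of Saver.save (which appends -<step> to the prefix) and then restores whichever one tf.train.latest_checkpoint reports. The directory, the step variable and the loop are hypothetical; the point is only that the newest checkpoint is the one picked up.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

ckpt_dir = tempfile.mkdtemp()  # hypothetical checkpoint directory
prefix = os.path.join(ckpt_dir, "model.ckpt")

step_var = tf.Variable(0, name="step")
increment = tf.assign_add(step_var, 1)
saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for step in range(1, 4):
        sess.run(increment)
        # global_step appends "-<step>": model.ckpt-1, model.ckpt-2, model.ckpt-3
        saver.save(sess, prefix, global_step=step)

# latest_checkpoint reads the "checkpoint" file in ckpt_dir and returns
# the prefix of the most recently saved checkpoint.
latest = tf.train.latest_checkpoint(ckpt_dir)
print(latest)

# Restore into a fresh graph that defines a variable with the same name.
tf.reset_default_graph()
step_var = tf.Variable(0, name="step")
saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess, latest)
    restored_step = sess.run(step_var)
print(restored_step)
```

    If you need an older checkpoint, pass its explicit prefix (for example one ending in model.ckpt-1) to saver.restore() instead of the result of latest_checkpoint.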
