TensorFlow: how to save/restore a model?

迷失自我 2020-11-21 11:37

After you train a model in TensorFlow:

  1. How do you save the trained model?
  2. How do you later restore this saved model?
26 answers
  •  悲&欢浪女
    2020-11-21 12:20

    All the answers here are great, but I want to add two things.

    First, to elaborate on @user7505159's answer: prepending "./" to the file name you are restoring from can matter.

    For example, you can save a graph with no "./" in the file name like so:

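    import tensorflow as tf  # TensorFlow 1.x API

    # Some graph defined up here with specific names
    
    saver = tf.train.Saver()
    save_file = 'model.ckpt'
    
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver.save(sess, save_file)  # writes the model.ckpt.* files relative to the working directory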
    

    But to restore the graph, you may need to prepend "./" to the file name:

    import tensorflow as tf  # TensorFlow 1.x API

    # Same graph defined up here
    
    saver = tf.train.Saver()
    save_file = './' + 'model.ckpt'  # String addition used for emphasis
    
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())  # must run before restore, not after
        saver.restore(sess, save_file)
    

    You will not always need the "./", but omitting it can cause restore errors depending on your environment and TensorFlow version.
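    The answer does not say why the prefix matters. One plausible factor, offered here as an assumption rather than documented TensorFlow behavior, is that a bare file name has an empty directory component, which any code that splits a checkpoint prefix into directory and file name has to special-case. This plain-Python sketch (no TensorFlow; the helper describe is hypothetical) shows the difference and, as an alternative not taken from the answer, builds an absolute path that does not depend on the working directory.

```python
import os

def describe(path):
    """Hypothetical helper: split a checkpoint prefix into (directory, basename)."""
    return os.path.dirname(path), os.path.basename(path)

bare = "model.ckpt"
dotted = "./" + "model.ckpt"

print(describe(bare))    # ('', 'model.ckpt')  -> empty directory component
print(describe(dotted))  # ('.', 'model.ckpt') -> explicit current directory

# An absolute prefix is unambiguous regardless of how the path is later split.
absolute = os.path.abspath(bare)
print(os.path.isabs(absolute))  # True
```

    Passing an absolute prefix like this to both saver.save and saver.restore should sidestep the question of whether "./" is needed at all.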

    I also want to mention that running sess.run(tf.global_variables_initializer()) before restoring the model can matter.

    If you get an error about uninitialized variables when restoring a saved model, make sure sess.run(tf.global_variables_initializer()) runs before the saver.restore(sess, save_file) line. It can save you a headache. Keep that order, though: running the initializer after saver.restore(...) would overwrite the restored values with fresh initial values.
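    To make the ordering concrete, here is a toy model in plain Python. It is NOT TensorFlow code: the variable names, values, and helper functions are all hypothetical. It assumes, as one common cause of the uninitialized-variable error, that the checkpoint covers only some of the variables, for example because the Saver was created before other variables were added to the graph.

```python
# Toy model (NOT TensorFlow): a "session" is a dict of variable values,
# where None means "uninitialized".
INITIAL = {"w": 0.0, "b": 0.0, "step": 0}  # hypothetical initial values
CHECKPOINT = {"w": 3.5, "b": -1.0}         # hypothetical saved values; "step" was never saved

def fresh_session():
    return {name: None for name in INITIAL}

def run_initializer(session):
    """Like the global initializer: every variable gets its initial value."""
    session.update(INITIAL)

def restore(session, checkpoint):
    """Like a restore whose var list covers only the checkpointed variables."""
    for name, value in checkpoint.items():
        session[name] = value

# Initialize, then restore: saved values win and nothing is left uninitialized.
s1 = fresh_session()
run_initializer(s1)
restore(s1, CHECKPOINT)
print(s1)  # {'w': 3.5, 'b': -1.0, 'step': 0}

# Restore only: "step" is still uninitialized.
s2 = fresh_session()
restore(s2, CHECKPOINT)
print(s2)  # {'w': 3.5, 'b': -1.0, 'step': None}

# Restore, then initialize: the restored values are overwritten.
s3 = fresh_session()
restore(s3, CHECKPOINT)
run_initializer(s3)
print(s3)  # {'w': 0.0, 'b': 0.0, 'step': 0}
```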
