TensorFlow: How to save/restore a model?

迷失自我 2020-11-21 11:37

After you train a model in TensorFlow:

  1. How do you save the trained model?
  2. How do you later restore this saved model?
26 answers
  •  逝去的感伤
    2020-11-21 12:23

    Since TensorFlow 0.11.0RC1, you can save and restore a model directly by calling tf.train.export_meta_graph and tf.train.import_meta_graph, as described in https://www.tensorflow.org/programmers_guide/meta_graph.
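
    A minimal sketch of calling the two functions directly, without a Saver. It assumes TensorFlow 1.14+ or 2.x, reached through tf.compat.v1 with tf.compat.v1.disable_v2_behavior() so the 1.x graph semantics apply; the temporary directory and the graph-only.meta filename are hypothetical. It illustrates that the .meta file carries the graph structure and its collections (so the imported graph already contains w1 and the 'vars' collection), while the variable values themselves are not part of it and still have to be loaded from a checkpoint.

```python
import os
import tempfile

import tensorflow as tf

# Assumption: TensorFlow 1.14+ or 2.x, with 1.x graph/Session semantics re-enabled.
tf1 = tf.compat.v1
tf1.disable_v2_behavior()

# Hypothetical output path for this sketch.
meta_path = os.path.join(tempfile.mkdtemp(), "graph-only.meta")

# Export: serialize the default graph and its collections into a MetaGraphDef file.
with tf.Graph().as_default():
    w1 = tf1.Variable(tf1.truncated_normal(shape=[10]), name="w1")
    tf1.add_to_collection("vars", w1)
    tf1.train.export_meta_graph(filename=meta_path)

# Import: rebuild that structure in a new, empty graph (no variable values are loaded).
with tf.Graph().as_default() as graph:
    tf1.train.import_meta_graph(meta_path)
    imported = tf1.get_collection("vars")
    names = [t.name for t in imported]

print(names)
```

    To actually read w1's value in the second graph, a checkpoint written by a Saver would still be needed, which is what the save/restore steps of this answer provide.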

    Save the model

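    import tensorflow as tf  # TensorFlow 1.x API (in 2.x: tf.compat.v1 with disable_v2_behavior())

    w1 = tf.Variable(tf.truncated_normal(shape=[10]), name='w1')
    w2 = tf.Variable(tf.truncated_normal(shape=[20]), name='w2')
    # Register the variables in a named collection so they can be looked up after import.
    tf.add_to_collection('vars', w1)
    tf.add_to_collection('vars', w2)
    saver = tf.train.Saver()
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    saver.save(sess, 'my-model')
    # `save` calls `export_meta_graph` implicitly, so the graph is written to my-model.meta.
    # Variable values go to my-model.index and my-model.data-00000-of-00001, and a
    # `checkpoint` file in the same directory records the most recent checkpoint.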
    

    Restore the model

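    import tensorflow as tf

    sess = tf.Session()
    # Rebuild the graph (including the 'vars' collection) from the .meta file.
    new_saver = tf.train.import_meta_graph('my-model.meta')
    # Load the variable values from the most recent checkpoint in the current directory.
    new_saver.restore(sess, tf.train.latest_checkpoint('./'))
    all_vars = tf.get_collection('vars')
    for v in all_vars:
        v_ = sess.run(v)
        print(v_)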
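
    To check the save/restore round trip end to end, the same steps can be run in one script by building them in two independent tf.Graph objects. This is a sketch under stated assumptions: TensorFlow 1.14+ or 2.x, the tf.compat.v1 API with tf.compat.v1.disable_v2_behavior() so the 1.x Saver/Session and collection behavior applies, and a hypothetical temporary directory in place of './'. It compares the values captured just before saving with those read back through the restored 'vars' collection.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# Assumption: TensorFlow 1.14+ or 2.x; disable_v2_behavior() restores 1.x semantics
# (graph mode, reference variables) so the 'vars' collection round-trips as in 1.x.
tf1 = tf.compat.v1
tf1.disable_v2_behavior()

# Hypothetical location: a temp directory instead of the current directory.
ckpt_dir = tempfile.mkdtemp()
prefix = os.path.join(ckpt_dir, "my-model")

# Save: create and initialize the variables in their own graph.
with tf.Graph().as_default():
    w1 = tf1.Variable(tf1.truncated_normal(shape=[10]), name="w1")
    w2 = tf1.Variable(tf1.truncated_normal(shape=[20]), name="w2")
    tf1.add_to_collection("vars", w1)
    tf1.add_to_collection("vars", w2)
    saver = tf1.train.Saver()
    with tf1.Session() as sess:
        sess.run(tf1.global_variables_initializer())
        saved_values = sess.run([w1, w2])  # values captured before saving
        saver.save(sess, prefix)  # writes my-model.meta/.index/.data-* and `checkpoint`

# Restore: a fresh, empty graph rebuilt only from the files on disk.
with tf.Graph().as_default():
    new_saver = tf1.train.import_meta_graph(prefix + ".meta")
    with tf1.Session() as sess:
        new_saver.restore(sess, tf1.train.latest_checkpoint(ckpt_dir))
        restored_values = sess.run(tf1.get_collection("vars"))

# Compare each variable before saving and after restoring.
for before, after in zip(saved_values, restored_values):
    print(np.array_equal(before, after))
```

    Using two separate graphs makes sure nothing from the saving step leaks into the restoring step, so the restored values can only have come from the files written by saver.save.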
    
