TensorFlow: How to save/restore a model?

迷失自我 2020-11-21 11:37

After you train a model in TensorFlow:

  1. How do you save the trained model?
  2. How do you later restore this saved model?
26 answers
  •  长情又很酷
    2020-11-21 12:12

    For TensorFlow version < 0.11.0RC1:

    Saved checkpoints contain only the values of the Variables in your model, not the model/graph itself. This means you must rebuild the same graph (same variable names and shapes) before you restore the checkpoint.
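To make that distinction concrete, here is a toy, framework-free sketch in plain Python. It assumes a checkpoint is nothing more than a map from variable name to value; the JSON file format and the helpers `save_checkpoint`, `restore_checkpoint` and `shape_of` are hypothetical, and TensorFlow's real checkpoint format is different. Restoring succeeds only when the rebuilt "graph" declares variables with the same names and shapes as the saved ones.

```python
import json
import os
import tempfile

# Toy model of a checkpoint: only variable *values*, keyed by variable name.
# Nothing about the ops/graph that uses the variables is stored.
# (Illustrative only; not TensorFlow's actual file format.)


def save_checkpoint(variables, path):
    """Write {name: nested-list value} to a JSON file."""
    with open(path, "w") as f:
        json.dump(variables, f)


def shape_of(value):
    """Shape of a nested list, e.g. [[0.0]] -> (1, 1)."""
    shape = []
    while isinstance(value, list):
        shape.append(len(value))
        value = value[0] if value else None
    return tuple(shape)


def restore_checkpoint(variables, path):
    """Overwrite the values in `variables` with those stored at `path`.

    Only works if the rebuilt graph declares variables with the same names
    and shapes as the ones that were saved.
    """
    with open(path) as f:
        saved = json.load(f)
    for name, value in variables.items():
        if name not in saved:
            raise KeyError("variable %r not found in checkpoint" % name)
        if shape_of(saved[name]) != shape_of(value):
            raise ValueError("shape mismatch for %r" % name)
        variables[name] = saved[name]
    return variables


ckpt_path = os.path.join(tempfile.mkdtemp(), "model.ckpt")
save_checkpoint({"w": [[2.0]], "b": [[1.0]]}, ckpt_path)

# "Rebuild the graph": same variable names/shapes, freshly initialized values.
fresh = {"w": [[0.0]], "b": [[1.0]]}
restore_checkpoint(fresh, ckpt_path)
print(fresh)  # {'w': [[2.0]], 'b': [[1.0]]}
```

A graph whose `w` has a different shape, or that names its variables differently, cannot be restored from this checkpoint, which is why the example below rebuilds exactly the same graph before calling `saver.restore`.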

    Here's an example for a linear regression model. A training loop saves variable checkpoints, and an evaluation section restores the variables saved in a prior run and computes predictions. You can also restore the variables and continue training if you'd like.

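    import os
    import tensorflow as tf  # uses the pre-0.11 API (tf.initialize_all_variables, etc.)
    
    # Command-line flags referenced below; names match the original, defaults are arbitrary.
    flags = tf.app.flags
    flags.DEFINE_boolean('train', True, 'Train if True; otherwise restore and predict.')
    flags.DEFINE_integer('training_steps', 1000, 'Number of training steps.')
    flags.DEFINE_integer('checkpoint_steps', 100, 'Save a checkpoint every N steps.')
    flags.DEFINE_string('checkpoint_dir', '/tmp/linreg/', 'Directory for checkpoints.')
    FLAGS = flags.FLAGS
    
    x = tf.placeholder(tf.float32)
    y = tf.placeholder(tf.float32)
    
    w = tf.Variable(tf.zeros([1, 1], dtype=tf.float32))
    b = tf.Variable(tf.ones([1, 1], dtype=tf.float32))
    y_hat = tf.add(b, tf.matmul(x, w))
    
    # More setup for optimization, e.g. a squared-error loss and an optimizer:
    loss = tf.reduce_mean(tf.square(y - y_hat))
    train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss)
    
    saver = tf.train.Saver()  # defaults to saving all variables - in this case w and b
    
    with tf.Session() as sess:
        sess.run(tf.initialize_all_variables())
        if FLAGS.train:
            if not os.path.exists(FLAGS.checkpoint_dir):
                os.makedirs(FLAGS.checkpoint_dir)
            for i in range(FLAGS.training_steps):
                # ...training loop: feed a batch of data, e.g.
                # sess.run(train_op, feed_dict={x: batch_x, y: batch_y})
                if (i + 1) % FLAGS.checkpoint_steps == 0:
                    # Writes checkpoint files prefixed model.ckpt-<step> in checkpoint_dir.
                    saver.save(sess, os.path.join(FLAGS.checkpoint_dir, 'model.ckpt'),
                               global_step=i + 1)
        else:
            # Here's where you're restoring the variables w and b.
            # Note that the graph is exactly as it was when the variables were
            # saved in a prior training run.
            ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)
            else:
                raise IOError('No checkpoint found in %s' % FLAGS.checkpoint_dir)
    
            # Now you can run the model to get predictions
            batch_x = ...  # ...load some data (an [N, 1] float array)...
            predictions = sess.run(y_hat, feed_dict={x: batch_x})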
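The save-every-N-steps / restore-latest pattern above can also be sketched without TensorFlow. This is a loose analogy, not TensorFlow's implementation: it only mimics how `saver.save(..., global_step=step)` produces names like `model.ckpt-<step>` and how a `checkpoint` state file records the newest one for `tf.train.get_checkpoint_state`. The helpers `save`, `latest_checkpoint`, `restore` and `train`, and the JSON file contents, are hypothetical. It fits the same `y_hat = w * x + b` model with plain gradient descent and shows both resuming training and restoring for prediction.

```python
import json
import os
import tempfile


def save(params, checkpoint_dir, step):
    """Write params to '<dir>/model.ckpt-<step>' and record it as the latest."""
    path = os.path.join(checkpoint_dir, "model.ckpt-%d" % step)
    with open(path, "w") as f:
        json.dump(params, f)
    # Analogous to TF's "checkpoint" state file (format here is made up).
    with open(os.path.join(checkpoint_dir, "checkpoint"), "w") as f:
        json.dump({"model_checkpoint_path": path}, f)
    return path


def latest_checkpoint(checkpoint_dir):
    """Return the most recently saved path, or None if nothing was saved."""
    state = os.path.join(checkpoint_dir, "checkpoint")
    if not os.path.exists(state):
        return None
    with open(state) as f:
        return json.load(f)["model_checkpoint_path"]


def restore(path):
    """Load params and recover the step number from the file name."""
    with open(path) as f:
        params = json.load(f)
    step = int(os.path.basename(path).rsplit("-", 1)[1])
    return params, step


def train(data, checkpoint_dir, training_steps, checkpoint_steps, lr=0.05):
    """Fit y = w*x + b by gradient descent, resuming from the latest checkpoint."""
    params, start = {"w": 0.0, "b": 1.0}, 0  # same initial values as w, b above
    ckpt = latest_checkpoint(checkpoint_dir)
    if ckpt:
        params, start = restore(ckpt)
    n = len(data)
    for i in range(start, training_steps):
        errors = [(params["w"] * x + params["b"] - y, x) for x, y in data]
        grad_w = sum(2 * e * x for e, x in errors) / n
        grad_b = sum(2 * e for e, _ in errors) / n
        params["w"] -= lr * grad_w
        params["b"] -= lr * grad_b
        if (i + 1) % checkpoint_steps == 0:
            save(params, checkpoint_dir, i + 1)
    return params


data = [(x, 3.0 * x + 2.0) for x in [0.0, 1.0, 2.0, 3.0]]
ckpt_dir = tempfile.mkdtemp()

train(data, ckpt_dir, training_steps=500, checkpoint_steps=100)   # saves steps 100..500
train(data, ckpt_dir, training_steps=1000, checkpoint_steps=100)  # resumes at step 500

# Evaluation: restore the latest checkpoint and predict without training.
params, step = restore(latest_checkpoint(ckpt_dir))
print(step, round(params["w"] * 4.0 + params["b"], 3))
```

Unlike `tf.train.Saver`, which by default keeps only the five most recent checkpoints (`max_to_keep=5`), this sketch never deletes old checkpoint files.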
    

    See the TensorFlow documentation for Variables, which covers saving and restoring, and the documentation for tf.train.Saver.
