Tensorflow: how to save/restore a model?


After you train a model in Tensorflow:

  1. How do you save the trained model?
  2. How do you later restore this saved model?
26 Answers
  • 2020-11-21 12:12

    For TensorFlow version < 0.11.0RC1:

    The checkpoints that are saved contain values for the Variables in your model, not the model/graph itself, which means that the graph should be the same when you restore the checkpoint.

    Here's an example for a linear regression where there's a training loop that saves variable checkpoints and an evaluation section that will restore variables saved in a prior run and compute predictions. Of course, you can also restore variables and continue training if you'd like.

    x = tf.placeholder(tf.float32)
    y = tf.placeholder(tf.float32)
    
    w = tf.Variable(tf.zeros([1, 1], dtype=tf.float32))
    b = tf.Variable(tf.ones([1, 1], dtype=tf.float32))
    y_hat = tf.add(b, tf.matmul(x, w))
    
    ...more setup for optimization and what not...
    
    saver = tf.train.Saver()  # defaults to saving all variables - in this case w and b
    
    with tf.Session() as sess:
        sess.run(tf.initialize_all_variables())
        if FLAGS.train:
            for i in xrange(FLAGS.training_steps):
                ...training loop...
                if (i + 1) % FLAGS.checkpoint_steps == 0:
                    saver.save(sess, FLAGS.checkpoint_dir + 'model.ckpt',
                               global_step=i+1)
        else:
            # Here's where you're restoring the variables w and b.
            # Note that the graph is exactly as it was when the variables were
            # saved in a prior training run.
            ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)
            else:
                ...no checkpoint found...
    
            # Now you can run the model to get predictions
            batch_x = ...load some data...
            predictions = sess.run(y_hat, feed_dict={x: batch_x})
    

    Here are the docs for Variables, which cover saving and restoring. And here are the docs for the Saver.
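
    If you want to restore and then keep training (as mentioned above) rather than just predicting, the pattern is the same: rebuild the graph, restore the checkpoint into the session, and run your training op. A minimal sketch under the same setup; the loss and optimizer here only stand in for the elided optimization code above:

    loss = tf.reduce_mean(tf.square(y_hat - y))                # illustrative loss
    train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss)
    saver = tf.train.Saver()

    with tf.Session() as sess:
        sess.run(tf.initialize_all_variables())
        ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)    # overwrite the freshly initialized w and b
        for i in xrange(FLAGS.training_steps):
            batch_x, batch_y = ...load some data...
            sess.run(train_op, feed_dict={x: batch_x, y: batch_y})
            if (i + 1) % FLAGS.checkpoint_steps == 0:
                saver.save(sess, FLAGS.checkpoint_dir + 'model.ckpt', global_step=i + 1)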

  • 2020-11-21 12:12

    If you use tf.train.MonitoredTrainingSession as the session, you don't need any extra code to save and restore. Just pass a checkpoint directory name to MonitoredTrainingSession's constructor; it will use session hooks to handle checkpointing for you.
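
    A minimal sketch of that pattern (TF 1.x; the tiny variable and loss below are only illustrative stand-ins for a real model):

    import tensorflow as tf

    # MonitoredTrainingSession's checkpoint hook needs a global step
    global_step = tf.train.get_or_create_global_step()
    w = tf.Variable(0.0)
    loss = tf.square(w - 3.0)
    train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss, global_step=global_step)

    # checkpoint_dir makes the session restore the latest checkpoint on startup (if one exists)
    # and periodically save new ones via a CheckpointSaverHook
    with tf.train.MonitoredTrainingSession(
            checkpoint_dir='/tmp/my_model',
            hooks=[tf.train.StopAtStepHook(last_step=1000)],
            save_checkpoint_secs=60) as sess:
        while not sess.should_stop():
            sess.run(train_op)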

  • 2020-11-21 12:14

    As described in issue 6255, prefix the checkpoint name with './' so the path contains an explicit directory component:

    saver.restore(sess, './my_model_final.ckpt')

    instead of

    saver.restore(sess, 'my_model_final.ckpt')
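
    Put together, a minimal sketch of the corrected calls (the variable w is only there so the Saver has something to save):

    import tensorflow as tf

    w = tf.Variable(1.0, name='w')
    saver = tf.train.Saver()

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver.save(sess, './my_model_final.ckpt')     # explicit relative path

    with tf.Session() as sess:
        saver.restore(sess, './my_model_final.ckpt')  # session first, then the path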
    
  • 2020-11-21 12:15

    My environment: Python 3.6, TensorFlow 1.3.0

    Though there have been many solutions, most of them are based on tf.train.Saver. When we load a .ckpt saved by Saver, we have to either redefine the TensorFlow network or use awkward, hard-to-remember tensor names, e.g. 'placehold_0:0' or 'dense/Adam/Weight:0'. Here I recommend using tf.saved_model instead; a minimal example is given below, and you can learn more from Serving a TensorFlow Model:

    Save the model:

    import tensorflow as tf
    
    # define the tensorflow network and do some trains
    x = tf.placeholder("float", name="x")
    w = tf.Variable(2.0, name="w")
    b = tf.Variable(0.0, name="bias")
    
    h = tf.multiply(x, w)
    y = tf.add(h, b, name="y")
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    
    # save the model
    export_path =  './savedmodel'
    builder = tf.saved_model.builder.SavedModelBuilder(export_path)
    
    tensor_info_x = tf.saved_model.utils.build_tensor_info(x)
    tensor_info_y = tf.saved_model.utils.build_tensor_info(y)
    
    prediction_signature = (
      tf.saved_model.signature_def_utils.build_signature_def(
          inputs={'x_input': tensor_info_x},
          outputs={'y_output': tensor_info_y},
          method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME))
    
    builder.add_meta_graph_and_variables(
      sess, [tf.saved_model.tag_constants.SERVING],
      signature_def_map={
          tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
              prediction_signature 
      },
      )
    builder.save()
    

    Load the model:

    import tensorflow as tf
    sess=tf.Session() 
    signature_key = tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
    input_key = 'x_input'
    output_key = 'y_output'
    
    export_path =  './savedmodel'
    meta_graph_def = tf.saved_model.loader.load(
               sess,
              [tf.saved_model.tag_constants.SERVING],
              export_path)
    signature = meta_graph_def.signature_def
    
    x_tensor_name = signature[signature_key].inputs[input_key].name
    y_tensor_name = signature[signature_key].outputs[output_key].name
    
    x = sess.graph.get_tensor_by_name(x_tensor_name)
    y = sess.graph.get_tensor_by_name(y_tensor_name)
    
    y_out = sess.run(y, {x: 3.0})
    
  • 2020-11-21 12:18

    You can also check out the examples in TensorFlow/skflow, which offers save and restore methods that help you manage your models easily. It also has a parameter that controls how frequently you back up your model.

  • 2020-11-21 12:19

    Here's my simple solution for the two basic cases, which differ in whether you want to load the graph from a file or rebuild it in code at runtime.

    This answer holds for TensorFlow 0.12+ (including 1.0).

    Rebuilding the graph in code

    Saving

    graph = ... # build the graph
    saver = tf.train.Saver()  # create the saver after the graph
    with ... as sess:  # your session object
        saver.save(sess, 'my-model')
    

    Loading

    graph = ... # build the graph
    saver = tf.train.Saver()  # create the saver after the graph
    with ... as sess:  # your session object
        saver.restore(sess, tf.train.latest_checkpoint('./'))
        # now you can use the graph, continue training or whatever
    

    Loading also the graph from a file

    When using this technique, make sure all your layers/variables have explicitly set, unique names. Otherwise TensorFlow will make the names unique itself, and they'll thus be different from the names stored in the file. This isn't a problem with the previous technique, because the names are "mangled" the same way on both saving and loading.
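
    For instance, you might give variables explicit names when building the graph (the shapes and names here are purely illustrative):

    w = tf.Variable(tf.truncated_normal([784, 10]), name='hidden_weights')  # explicit, unique name
    b = tf.Variable(tf.zeros([10]), name='hidden_bias')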

    Saving

    graph = ... # build the graph
    
    for op in [ ... ]:  # operators you want to use after restoring the model
        tf.add_to_collection('ops_to_restore', op)
    
    saver = tf.train.Saver()  # create the saver after the graph
    with ... as sess:  # your session object
        saver.save(sess, 'my-model')
    

    Loading

    with ... as sess:  # your session object
        saver = tf.train.import_meta_graph('my-model.meta')
        saver.restore(sess, tf.train.latest_checkpoint('./'))
        ops = tf.get_collection('ops_to_restore')  # here are your operators in the same order in which you saved them to the collection
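
    For example, if the collection held an input placeholder and a prediction op (the names x and y_hat below are just illustrative), you can unpack and run them, continuing inside the with block above:

        x, y_hat = tf.get_collection('ops_to_restore')       # same order as when they were added
        print(sess.run(y_hat, feed_dict={x: [[1.0, 2.0]]}))  # run the restored graph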
    