TensorFlow: how to save/restore a model?

迷失自我 2020-11-21 11:37

After you train a model in Tensorflow:

  1. How do you save the trained model?
  2. How do you later restore this saved model?
26 answers
  • 2020-11-21 12:04

    Following @Vishnuvardhan Janapati's answer, here is another way to save and reload a model with a custom layer/metric/loss under TensorFlow 2.0.0:

    import tensorflow as tf
    from tensorflow.keras.layers import Layer
    from tensorflow.keras.utils import get_custom_objects
    
    # custom loss (for example)  
    def custom_loss(y_true,y_pred):
      return tf.reduce_mean(y_true - y_pred)
    get_custom_objects().update({'custom_loss': custom_loss}) 
    
    # custom layer (for example)
    class CustomLayer(Layer):
      def __init__(self, ...):
          ...
      # define custom layer and all necessary custom operations inside custom layer
    
    get_custom_objects().update({'CustomLayer': CustomLayer})  
    

    Once you have executed this code and saved your model with tf.keras.models.save_model, model.save, or the ModelCheckpoint callback, you can reload it without passing the custom objects explicitly, as simply as

    new_model = tf.keras.models.load_model("./model.h5")
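
    Alternatively, if you prefer not to register custom objects globally, tf.keras.models.load_model also accepts a custom_objects argument. A minimal sketch, reusing the example custom_loss from the snippet above (the model path is illustrative):

    ```python
    import tensorflow as tf

    # the same example custom loss as above
    def custom_loss(y_true, y_pred):
        return tf.reduce_mean(y_true - y_pred)

    # pass the custom objects explicitly at load time instead of
    # registering them with get_custom_objects()
    new_model = tf.keras.models.load_model(
        "./model.h5",
        custom_objects={"custom_loss": custom_loss},
    )
    ```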
    
  • 2020-11-21 12:05

    Use tf.train.Saver to save a model. Remember that you need to specify the var_list if you want to reduce the model size; var_list can be tf.trainable_variables or tf.global_variables.
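
    A minimal sketch of that idea, using the TF1 graph API as in the other answers here (under TensorFlow 2 these calls live under tf.compat.v1; the variables and path are illustrative):

    ```python
    import tensorflow as tf

    # a hypothetical two-variable model
    w = tf.Variable(tf.zeros([2, 2]), name="w")
    b = tf.Variable(tf.zeros([2]), name="b")

    # save only the trainable variables to keep the checkpoint small
    saver = tf.train.Saver(var_list=tf.trainable_variables())

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver.save(sess, "/tmp/small_model")
    ```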

  • 2020-11-21 12:05

    For TensorFlow 2.0, it's very simple.

    import tensorflow as tf
    

    SAVE

    model.save("model_name")
    

    RESTORE

    model = tf.keras.models.load_model('model_name')
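
    A complete round trip might look like this (the model and path are illustrative; newer Keras releases expect a .keras or .h5 suffix on the path):

    ```python
    import numpy as np
    import tensorflow as tf

    # a tiny model for illustration; any Keras model works the same way
    inputs = tf.keras.Input(shape=(4,))
    outputs = tf.keras.layers.Dense(1)(inputs)
    model = tf.keras.Model(inputs, outputs)
    model.compile(optimizer="adam", loss="mse")

    x = np.random.rand(8, 4).astype("float32")

    # SAVE
    model.save("model_name.keras")

    # RESTORE
    restored = tf.keras.models.load_model("model_name.keras")

    # the restored model makes the same predictions
    assert np.allclose(model.predict(x), restored.predict(x))
    ```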
    
  • 2020-11-21 12:07

    For TensorFlow 2.0, it is as simple as

    # Save the model
    model.save('path_to_my_model.h5')
    

    To restore:

    new_model = tensorflow.keras.models.load_model('path_to_my_model.h5')
    
  • 2020-11-21 12:08

    In most cases, saving and restoring from disk using a tf.train.Saver is your best option:

    ... # build your model
    saver = tf.train.Saver()
    
    with tf.Session() as sess:
        ... # train the model
        saver.save(sess, "/tmp/my_great_model")
    
    with tf.Session() as sess:
        saver.restore(sess, "/tmp/my_great_model")
        ... # use the model
    

    You can also save/restore the graph structure itself (see the MetaGraph documentation for details). By default, the Saver saves the graph structure into a .meta file. You can call import_meta_graph() to restore it. It restores the graph structure and returns a Saver that you can use to restore the model's state:

    saver = tf.train.import_meta_graph("/tmp/my_great_model.meta")
    
    with tf.Session() as sess:
        saver.restore(sess, "/tmp/my_great_model")
        ... # use the model
    

    However, there are cases where you need something much faster. For example, if you implement early stopping, you want to save a checkpoint every time the model improves during training (as measured on the validation set); then, if there is no progress for some time, you roll back to the best model. Saving the model to disk every time it improves would slow training down tremendously. The trick is to save the variable states to memory and simply restore them later:

    ... # build your model
    
    # get a handle on the graph nodes we need to save/restore the model
    graph = tf.get_default_graph()
    gvars = graph.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
    assign_ops = [graph.get_operation_by_name(v.op.name + "/Assign") for v in gvars]
    init_values = [assign_op.inputs[1] for assign_op in assign_ops]
    
    with tf.Session() as sess:
        ... # train the model
    
        # when needed, save the model state to memory
        gvars_state = sess.run(gvars)
    
        # when needed, restore the model state
        feed_dict = {init_value: val
                     for init_value, val in zip(init_values, gvars_state)}
        sess.run(assign_ops, feed_dict=feed_dict)
    

    A quick explanation: when you create a variable X, TensorFlow automatically creates an assignment operation X/Assign to set the variable's initial value. Instead of creating placeholders and extra assignment ops (which would just make the graph messy), we just use these existing assignment ops. The first input of each assignment op is a reference to the variable it is supposed to initialize, and the second input (assign_op.inputs[1]) is the initial value. So in order to set any value we want (instead of the initial value), we need to use a feed_dict and replace the initial value. Yes, TensorFlow lets you feed a value for any op, not just for placeholders, so this works fine.
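
    That feed trick can be illustrated on a single variable. A minimal sketch using the same TF1 graph API as the snippet above (under TensorFlow 2 these calls live under tf.compat.v1):

    ```python
    import tensorflow as tf

    v = tf.Variable(0.0, name="v")

    # grab the existing Assign op and its initial-value input, as above
    graph = tf.get_default_graph()
    assign_op = graph.get_operation_by_name("v/Assign")
    init_value = assign_op.inputs[1]

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # feed 42.0 in place of the initial value to overwrite v in memory
        sess.run(assign_op, feed_dict={init_value: 42.0})
        print(sess.run(v))  # 42.0
    ```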

  • 2020-11-21 12:11

    You can save the variables in the network using

    saver = tf.train.Saver() 
    saver.save(sess, 'path of save/fileName.ckpt')
    

    To restore the network for reuse later or in another script, use:

    saver = tf.train.Saver()
    saver.restore(sess, tf.train.latest_checkpoint('path of save/'))
    sess.run(....) 
    

    Important points:

    1. The graph must be the same between the first and later runs: the network structure has to be rebuilt identically before restoring.
    2. tf.train.latest_checkpoint needs the path of the folder containing the saved files, not the path of an individual file.
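
    Putting both points together, a minimal end-to-end sketch using the same TF1 graph API (under TensorFlow 2 these calls live under tf.compat.v1; the path and variable are illustrative):

    ```python
    import os
    import tensorflow as tf

    save_dir = "/tmp/save"  # illustrative save folder
    os.makedirs(save_dir, exist_ok=True)

    # build the model (this graph must be rebuilt identically before restoring)
    w = tf.Variable([1.0, 2.0], name="w")
    saver = tf.train.Saver()

    # first run: train and save
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # ... train ...
        saver.save(sess, os.path.join(save_dir, "fileName.ckpt"))

    # later run: restore into the same graph and use the model
    with tf.Session() as sess:
        saver.restore(sess, tf.train.latest_checkpoint(save_dir))
        print(sess.run(w))  # [1. 2.]
    ```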