Tensorflow: how to save/restore a model?

迷失自我 2020-11-21 11:37

After you train a model in Tensorflow:

  1. How do you save the trained model?
  2. How do you later restore this saved model?
26 answers
  • 2020-11-21 12:00

    I am improving my answer to add more details for saving and restoring models.

    In (and after) TensorFlow version 0.11:

    Save the model:

    import tensorflow as tf
    
    #Prepare to feed input, i.e. feed_dict and placeholders
    w1 = tf.placeholder("float", name="w1")
    w2 = tf.placeholder("float", name="w2")
    b1= tf.Variable(2.0,name="bias")
    feed_dict ={w1:4,w2:8}
    
    #Define a test operation that we will restore
    w3 = tf.add(w1,w2)
    w4 = tf.multiply(w3,b1,name="op_to_restore")
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    
    #Create a saver object which will save all the variables
    saver = tf.train.Saver()
    
    #Run the operation by feeding input
    print(sess.run(w4, feed_dict))
    # Prints 24.0, which is (w1 + w2) * b1 = (4 + 8) * 2
    
    #Now, save the graph
    saver.save(sess, 'my_test_model',global_step=1000)
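    Because a global_step is passed, the saver appends "-1000" to the prefix, which is why the restore code below opens 'my_test_model-1000.meta'. The sketch below only computes the file names this call is expected to leave on disk. It assumes the V2 checkpoint format used by recent TF 1.x releases with a single data shard and the default write_meta_graph=True; expected_checkpoint_files is a hypothetical helper, not a TensorFlow API.

```python
def expected_checkpoint_files(prefix, global_step=None):
    """Hypothetical helper: names tf.train.Saver.save() is expected to write.

    Assumes the V2 checkpoint format with a single data shard.
    """
    # The saver appends "-<global_step>" to the prefix when a step is given.
    ckpt_prefix = f"{prefix}-{global_step}" if global_step is not None else prefix
    return [
        ckpt_prefix + ".meta",                 # serialized graph (MetaGraphDef)
        ckpt_prefix + ".index",                # index of the saved variables
        ckpt_prefix + ".data-00000-of-00001",  # the variable values themselves
        "checkpoint",                          # text file naming the latest checkpoint
    ]


for name in expected_checkpoint_files("my_test_model", global_step=1000):
    print(name)
```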
    

    Restore the model:

    import tensorflow as tf
    
    sess=tf.Session()    
    #First let's load meta graph and restore weights
    saver = tf.train.import_meta_graph('my_test_model-1000.meta')
    saver.restore(sess,tf.train.latest_checkpoint('./'))
    
    
    # Access saved Variables directly
    print(sess.run('bias:0'))
    # This will print 2, which is the value of bias that we saved
    
    
    # Now, access the saved placeholders and create a
    # feed_dict to feed new data
    
    graph = tf.get_default_graph()
    w1 = graph.get_tensor_by_name("w1:0")
    w2 = graph.get_tensor_by_name("w2:0")
    feed_dict ={w1:13.0,w2:17.0}
    
    #Now, access the op that you want to run. 
    op_to_restore = graph.get_tensor_by_name("op_to_restore:0")
    
    print(sess.run(op_to_restore, feed_dict))
    # This will print 60.0, which is calculated as (13 + 17) * 2
    

    This and some more advanced use-cases have been explained very well here.

    A quick complete tutorial to save and restore Tensorflow models
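    tf.train.latest_checkpoint('./') does not scan for the newest file; it reads the small text file named "checkpoint" that the saver writes and returns the prefix recorded under model_checkpoint_path, resolved against the directory. The sketch below is a simplified stand-in written for illustration (latest_checkpoint_sketch is hypothetical), assuming the file holds a relative prefix; unlike TensorFlow it does not verify that the checkpoint files exist.

```python
import os
import re
import tempfile


def latest_checkpoint_sketch(checkpoint_dir):
    """Simplified stand-in for tf.train.latest_checkpoint (hypothetical helper).

    Reads the text-format 'checkpoint' file in checkpoint_dir and returns the
    prefix recorded under model_checkpoint_path, or None if there is none.
    """
    state_file = os.path.join(checkpoint_dir, "checkpoint")
    if not os.path.exists(state_file):
        return None
    with open(state_file) as f:
        for line in f:
            match = re.match(r'\s*model_checkpoint_path:\s*"(.*)"', line)
            if match:
                path = match.group(1)
                # Relative prefixes are resolved against the checkpoint directory.
                if os.path.isabs(path):
                    return path
                return os.path.join(checkpoint_dir, path)
    return None


# Usage: recreate the state file that saver.save(..., global_step=1000) leaves behind.
with tempfile.TemporaryDirectory() as tmp:
    with open(os.path.join(tmp, "checkpoint"), "w") as f:
        f.write('model_checkpoint_path: "my_test_model-1000"\n')
        f.write('all_model_checkpoint_paths: "my_test_model-1000"\n')
    print(latest_checkpoint_sketch(tmp))
```

    With './' as the directory this yields './my_test_model-1000', which is the prefix saver.restore expects (no file extension).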

  • 2020-11-21 12:01

    tf.keras Model saving with TF2.0

    I see great answers for saving models using TF1.x. I want to provide a couple more pointers on saving tensorflow.keras models, which is a little complicated because there are many ways to save a model.

    Here I provide an example of saving a tensorflow.keras model to the model_path folder under the current directory. This works well with the most recent TensorFlow (TF 2.0). I will update this description if anything changes in the near future.

    Saving and loading entire model

    import tensorflow as tf
    from tensorflow import keras
    mnist = tf.keras.datasets.mnist
    
    #import data
    (x_train, y_train),(x_test, y_test) = mnist.load_data()
    x_train, x_test = x_train / 255.0, x_test / 255.0
    
    # create a model
    def create_model():
      model = tf.keras.models.Sequential([
        tf.keras.layers.Flatten(input_shape=(28, 28)),
        tf.keras.layers.Dense(512, activation=tf.nn.relu),
        tf.keras.layers.Dropout(0.2),
        tf.keras.layers.Dense(10, activation=tf.nn.softmax)
        ])
      # compile the model
      model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
      return model
    
    # Create a basic model instance
    model=create_model()
    
    model.fit(x_train, y_train, epochs=1)
    loss, acc = model.evaluate(x_test, y_test,verbose=1)
    print("Original model, accuracy: {:5.2f}%".format(100*acc))
    
    # Save entire model to a HDF5 file
    model.save('./model_path/my_model.h5')
    
    # Recreate the exact same model, including weights and optimizer.
    new_model = keras.models.load_model('./model_path/my_model.h5')
    loss, acc = new_model.evaluate(x_test, y_test)
    print("Restored model, accuracy: {:5.2f}%".format(100*acc))
    

    Saving and loading model Weights only

    If you only want to save the model weights and later load them to restore the model, then:

    model.fit(x_train, y_train, epochs=5)
    loss, acc = model.evaluate(x_test, y_test,verbose=1)
    print("Original model, accuracy: {:5.2f}%".format(100*acc))
    
    # Save the weights
    model.save_weights('./checkpoints/my_checkpoint')
    
    # Restore the weights
    model = create_model()
    model.load_weights('./checkpoints/my_checkpoint')
    
    loss,acc = model.evaluate(x_test, y_test)
    print("Restored model, accuracy: {:5.2f}%".format(100*acc))
    

    Saving and restoring using keras checkpoint callback

    import os
    
    # include the epoch in the file name (uses `str.format`)
    checkpoint_path = "training_2/cp-{epoch:04d}.ckpt"
    checkpoint_dir = os.path.dirname(checkpoint_path)
    
    cp_callback = tf.keras.callbacks.ModelCheckpoint(
        checkpoint_path, verbose=1, save_weights_only=True,
        # Save weights every 5 epochs (later TF 2.x releases deprecate
        # `period` in favor of `save_freq`).
        period=5)
    
    model = create_model()
    model.save_weights(checkpoint_path.format(epoch=0))
    model.fit(x_train, y_train,
              epochs=50, callbacks=[cp_callback],
              validation_data=(x_test, y_test),
              verbose=0)
    
    latest = tf.train.latest_checkpoint(checkpoint_dir)
    
    new_model = create_model()
    new_model.load_weights(latest)
    loss, acc = new_model.evaluate(x_test, y_test)
    print("Restored model, accuracy: {:5.2f}%".format(100*acc))
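    The "{epoch:04d}" placeholder is filled with str.format, so each checkpoint gets a zero-padded epoch number. Assuming ModelCheckpoint numbers epochs from 1 and, with period=5, saves after every fifth epoch, the 50-epoch run above should leave the prefixes computed below, and latest_checkpoint then returns the cp-0050 one. callback_checkpoint_names is a hypothetical helper for illustration, not part of Keras.

```python
def callback_checkpoint_names(template, epochs, period):
    """Hypothetical helper: checkpoint prefixes written by the example above.

    Assumes 1-based epoch numbers and a save every `period` epochs,
    plus the manual save_weights call at epoch 0.
    """
    names = [template.format(epoch=0)]  # model.save_weights(checkpoint_path.format(epoch=0))
    for epoch in range(1, epochs + 1):
        if epoch % period == 0:
            names.append(template.format(epoch=epoch))
    return names


names = callback_checkpoint_names("training_2/cp-{epoch:04d}.ckpt", epochs=50, period=5)
print(names[:3], "...", names[-1])
```

    Zero-padding keeps the names sorted correctly when listed alphabetically, which is convenient when browsing the directory by hand.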
    

    Saving a model with custom metrics

    import tensorflow as tf
    from tensorflow import keras
    mnist = tf.keras.datasets.mnist
    
    (x_train, y_train),(x_test, y_test) = mnist.load_data()
    x_train, x_test = x_train / 255.0, x_test / 255.0
    
    # Custom Loss1 (for example) 
    @tf.function() 
    def customLoss1(yTrue,yPred):
      return tf.reduce_mean(yTrue-yPred) 
    
    # Custom Loss2 (for example) 
    @tf.function() 
    def customLoss2(yTrue, yPred):
      return tf.reduce_mean(tf.square(tf.subtract(yTrue,yPred))) 
    
    def create_model():
      model = tf.keras.models.Sequential([
        tf.keras.layers.Flatten(input_shape=(28, 28)),
        tf.keras.layers.Dense(512, activation=tf.nn.relu),  
        tf.keras.layers.Dropout(0.2),
        tf.keras.layers.Dense(10, activation=tf.nn.softmax)
        ])
      model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy', customLoss1, customLoss2])
      return model
    
    # Create a basic model instance
    model=create_model()
    
    # Fit and evaluate model 
    model.fit(x_train, y_train, epochs=1)
    loss, acc,loss1, loss2 = model.evaluate(x_test, y_test,verbose=1)
    print("Original model, accuracy: {:5.2f}%".format(100*acc))
    
    model.save("./model.h5")
    
    new_model=tf.keras.models.load_model("./model.h5",custom_objects={'customLoss1':customLoss1,'customLoss2':customLoss2})
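    The saved file refers to the custom metrics roughly by name only, so at load time Keras has no way to find customLoss1 and customLoss2 unless custom_objects maps those names back to the Python functions. The toy sketch below illustrates that lookup; custom_loss1/custom_loss2 are plain-Python analogues of the metrics above, and resolve_metrics and BUILTIN_METRICS are hypothetical. Real Keras deserialization is more involved.

```python
def custom_loss1(y_true, y_pred):
    """Plain-Python analogue of customLoss1 above: mean of (y_true - y_pred)."""
    return sum(t - p for t, p in zip(y_true, y_pred)) / len(y_true)


def custom_loss2(y_true, y_pred):
    """Plain-Python analogue of customLoss2 above: mean squared error."""
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


# Names the loader already knows (an illustrative assumption, not Keras' real list).
BUILTIN_METRICS = {"accuracy"}


def resolve_metrics(names, custom_objects=None):
    """Hypothetical sketch: turn metric names stored in a saved file back into objects."""
    custom_objects = custom_objects or {}
    resolved = {}
    for name in names:
        if name in BUILTIN_METRICS:
            resolved[name] = name  # built-ins need no help
        elif name in custom_objects:
            resolved[name] = custom_objects[name]  # user-supplied mapping
        else:
            raise ValueError(f"Unknown metric {name!r}: pass it via custom_objects")
    return resolved


stored_names = ["accuracy", "customLoss1", "customLoss2"]
try:
    resolve_metrics(stored_names)
except ValueError as err:
    print("Without custom_objects:", err)

metrics = resolve_metrics(
    stored_names, {"customLoss1": custom_loss1, "customLoss2": custom_loss2})
print(metrics["customLoss2"]([1.0, 2.0, 3.0], [1.0, 2.0, 5.0]))
```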
    

    Saving keras model with custom ops

    When we have custom ops, as in the following case (tf.tile), we need to put them in a function and wrap it with a Lambda layer. Otherwise, the model cannot be saved.

    import numpy as np
    import tensorflow as tf
    from tensorflow.keras.layers import Input, Lambda
    from tensorflow.keras import Model
    
    def my_fun(a):
      out = tf.tile(a, (1, tf.shape(a)[0]))
      return out
    
    a = Input(shape=(10,))
    #out = tf.tile(a, (1, tf.shape(a)[0]))
    out = Lambda(lambda x : my_fun(x))(a)
    model = Model(a, out)
    
    x = np.zeros((50,10), dtype=np.float32)
    print(model(x).numpy())
    
    model.save('my_model.h5')
    
    #load the model
    new_model=tf.keras.models.load_model("my_model.h5")
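    tf.tile multiplies each dimension by the corresponding entry of its multiples argument. Here the multiple for the second axis is tf.shape(a)[0], the batch size, so the output width depends on how many rows are fed: for the (50, 10) input above, the multiples are (1, 50). The stdlib sketch below only does that shape arithmetic; tile_shape is a hypothetical helper.

```python
def tile_shape(shape, multiples):
    """Output shape of a tile operation: each dimension times its multiple."""
    if len(shape) != len(multiples):
        raise ValueError("multiples must have one entry per dimension")
    return tuple(d * m for d, m in zip(shape, multiples))


batch, features = 50, 10  # x = np.zeros((50, 10)) in the example
print(tile_shape((batch, features), (1, batch)))  # (50, 500)
```

    So print(model(x).numpy()) in the example prints a 50 x 500 array of zeros.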
    

    I think I have covered a few of the many ways of saving a tf.keras model. However, there are many other ways. Please comment below if your use case is not covered above. Thanks!

  • 2020-11-21 12:02

    There are two parts to the model: the model definition, saved by Supervisor as graph.pbtxt in the model directory, and the numerical values of tensors, saved into checkpoint files like model.ckpt-1003418.

    The model definition can be restored using tf.import_graph_def, and the weights are restored using Saver.

    However, Saver uses a special collection holding the list of variables attached to the model Graph, and this collection is not initialized by import_graph_def, so you can't use the two together at the moment (it's on our roadmap to fix). For now, you have to use Ryan Sepassi's approach: manually construct a graph with identical node names, and use Saver to load the weights into it.

    (Alternatively, you could hack it by using import_graph_def, creating variables manually, and calling tf.add_to_collection(tf.GraphKeys.VARIABLES, variable) for each variable, then using Saver.)
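    The problem described here is that importing a graph definition brings back the nodes but not the graph's collections, and a Saver built without an explicit variable list looks for variables in a collection. The toy model below is an analogy written in plain Python, not TensorFlow code: ToyGraph, toy_import_graph_def and toy_saver_var_list are hypothetical names that only mirror the behavior the answer describes, including the manual add_to_collection workaround.

```python
class ToyGraph:
    """Toy model of a graph: named nodes plus named collections (not the TF API)."""

    def __init__(self):
        self.nodes = {}
        self.collections = {}

    def add_to_collection(self, key, value):
        self.collections.setdefault(key, []).append(value)

    def get_collection(self, key):
        return list(self.collections.get(key, []))


def toy_import_graph_def(graph, node_names):
    """Copies node definitions only; collections are NOT recreated."""
    for name in node_names:
        graph.nodes[name] = {"name": name}


def toy_saver_var_list(graph):
    """A saver with no explicit var_list uses whatever is in the variables collection."""
    return graph.get_collection("variables")


g = ToyGraph()
toy_import_graph_def(g, ["weights", "bias", "logits"])
print(toy_saver_var_list(g))  # [] -> nothing for the saver to restore

# The workaround from the text: register each variable by hand.
for var_name in ["weights", "bias"]:
    g.add_to_collection("variables", g.nodes[var_name])
print([v["name"] for v in toy_saver_var_list(g)])  # ['weights', 'bias']
```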

  • 2020-11-21 12:02

    In TensorFlow 2.0, saving and loading a model is a lot easier, thanks to the Keras API, a high-level API for TensorFlow.

    To save a model (see the documentation for reference: https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/keras/models/save_model):

    # model is a tf.keras.Model instance; save_format is 'tf' (SavedModel) or 'h5'
    tf.keras.models.save_model(model, filepath, save_format='tf')
    

    To load a model:

    https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/keras/models/load_model

    model = tf.keras.models.load_model(filepath)
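    If save_format is left as None, TF 2.x picks the format from the path: an HDF5-style extension such as .h5 gives a single HDF5 file, and anything else gives a SavedModel directory (TF 1.x defaults to HDF5). The helper below is a hypothetical sketch of that rule, not the library's code, and the exact extension list is an assumption that has varied between releases.

```python
import os

# HDF5-style extensions (assumption; the exact list has varied between releases).
HDF5_EXTENSIONS = (".h5", ".hdf5")


def infer_save_format(filepath, save_format=None):
    """Hypothetical sketch of how save_format=None is resolved in TF 2.x."""
    if save_format is not None:
        return save_format  # an explicit choice always wins
    extension = os.path.splitext(filepath)[1]
    return "h5" if extension in HDF5_EXTENSIONS else "tf"


for path in ["model.h5", "weights.hdf5", "saved_model_dir"]:
    print(path, "->", infer_save_format(path))
```

    load_model accepts either result, so the same load call works for both formats.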
    
  • 2020-11-21 12:03

    Wherever you want to save the model,

    self.saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        ...
        self.saver.save(sess, filename)
    

    Make sure all your tf.Variables have names, because you may want to restore them later by name. Then, where you want to predict:

    # saver.save(sess, filename) writes the graph to filename + '.meta'
    saver = tf.train.import_meta_graph(filename + '.meta')
    with tf.Session() as sess:
        saver.restore(sess, filename)  # the same path prefix given to saver.save()
        print(sess.run('W1:0'))  # example: retrieve a variable by name
    

    Make sure the saver runs inside the corresponding session. Remember that if you use tf.train.latest_checkpoint('./'), only the latest checkpoint will be used.
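    To restore an older step instead of the latest one, pass that step's prefix to saver.restore directly. By default tf.train.Saver keeps only the five most recent checkpoints (max_to_keep=5). The sketch below picks a prefix from a directory listing by looking at the .index files; checkpoint_steps and the file list are hypothetical. Steps are compared as integers so that -1500 correctly beats -500, which a plain string sort would get wrong.

```python
import re


def checkpoint_steps(filenames, prefix):
    """Hypothetical helper: map global step -> checkpoint prefix, found via .index files."""
    pattern = re.compile(re.escape(prefix) + r"-(\d+)\.index$")
    steps = {}
    for fname in filenames:
        match = pattern.match(fname)
        if match:
            steps[int(match.group(1))] = f"{prefix}-{match.group(1)}"
    return dict(sorted(steps.items()))


files = ["model.ckpt-500.index", "model.ckpt-500.data-00000-of-00001",
         "model.ckpt-1000.index", "model.ckpt-1000.data-00000-of-00001",
         "model.ckpt-1500.index", "model.ckpt-1500.meta", "checkpoint"]
steps = checkpoint_steps(files, "model.ckpt")
latest = steps[max(steps)]  # what "use the latest checkpoint" amounts to
earlier = steps[500]        # an older prefix you could pass to saver.restore instead
print(latest, earlier)      # model.ckpt-1500 model.ckpt-500
```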

  • 2020-11-21 12:04

    I'm on version:

    tensorflow (1.13.1)
    tensorflow-gpu (1.13.1)
    

    A simple way is:

    Save:

    model.save("model.h5")
    

    Restore:

    model = tf.keras.models.load_model("model.h5")
    