TensorFlow: how to save/restore a model?

迷失自我 2020-11-21 11:37

After you train a model in TensorFlow:

  1. How do you save the trained model?
  2. How do you later restore this saved model?
26 answers
  • 2020-11-21 12:19

    Here is a simple example using the TensorFlow 2.0 SavedModel format (the recommended format, according to the docs) for an MNIST classifier built with the Keras functional API, without anything fancy:

    # Imports
    import numpy as np
    import tensorflow as tf
    from tensorflow.keras.layers import Input, Dense, Flatten
    from tensorflow.keras.models import Model
    import matplotlib.pyplot as plt
    
    # Load data
    mnist = tf.keras.datasets.mnist  # 28 x 28 grayscale digits
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    
    # L2-normalize along axis 1 (pixel values end up in [0, 1])
    x_train = tf.keras.utils.normalize(x_train, axis=1)
    x_test = tf.keras.utils.normalize(x_test, axis=1)
    
    # Create model (the input/output layer names become the signature's keys)
    inputs = Input(shape=(28, 28), dtype='float64', name='graph_input')
    x = Flatten()(inputs)
    x = Dense(128, activation='relu')(x)
    x = Dense(128, activation='relu')(x)
    outputs = Dense(10, activation='softmax', name='graph_output', dtype='float64')(x)
    model = Model(inputs=inputs, outputs=outputs)
    
    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    
    # Train
    model.fit(x_train, y_train, epochs=3)
    
    # Save model in SavedModel format (TensorFlow 2.0)
    export_path = 'model'
    tf.saved_model.save(model, export_path)
    
    # ... possibly another Python program
    
    # Reload the SavedModel; the returned object exposes .signatures
    loaded_model = tf.saved_model.load(export_path)
    
    # Get an image sample for testing, keeping the batch dimension: shape (1, 28, 28)
    index = 0
    img = x_test[index:index + 1]  # already normalized in a previous step
    
    # Predict using the signature definition (TensorFlow 2.0);
    # signature functions take keyword arguments named after the input layer
    predict = loaded_model.signatures["serving_default"]
    prediction = predict(graph_input=tf.constant(img))
    
    # Show results
    print(np.argmax(prediction['graph_output']))  # prints the class number
    plt.imshow(x_test[index], cmap=plt.cm.binary)  # displays the image
    plt.show()
    
    

    What is serving_default?

    It is the name of the signature def under the tag you selected (in this case, the default serve tag). You can list a model's tags and signatures with saved_model_cli, for example saved_model_cli show --dir model --all.
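    To see what serving_default refers to without the MNIST setup, here is a minimal sketch, assuming TensorFlow 2.x. It saves a toy tf.Module with an explicit serving signature and then lists the signatures of the reloaded object. The Scaler class, its weight w, and the input/output names x and y are hypothetical.

```python
import tempfile

import tensorflow as tf


class Scaler(tf.Module):
    """Hypothetical toy model: multiplies its input by a trainable weight."""

    def __init__(self):
        super().__init__()
        self.w = tf.Variable(2.0, name="w")

    # The input_signature fixes the input name ("x"), dtype and shape,
    # which is what the exported signature exposes.
    @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32, name="x")])
    def __call__(self, x):
        return {"y": x * self.w}


export_dir = tempfile.mkdtemp()
module = Scaler()
tf.saved_model.save(
    module,
    export_dir,
    signatures={"serving_default": module.__call__.get_concrete_function()},
)

loaded = tf.saved_model.load(export_dir)
print(list(loaded.signatures.keys()))      # ['serving_default']
infer = loaded.signatures["serving_default"]
print(infer.structured_input_signature)    # input names, dtypes and shapes
print(infer.structured_outputs)            # output names, dtypes and shapes

# Signature functions are called with keyword arguments named after the inputs.
result = infer(x=tf.constant([1.0, 2.0, 3.0]))
print(result["y"].numpy())                 # [2. 4. 6.]
```

    The same keyword-argument convention is why the MNIST example passes graph_input= and reads prediction['graph_output'].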

    Disclaimers

    This is just a basic example to get you up and running; it is by no means a complete answer, and I may update it in the future. I wanted to give a simple example of SavedModel in TF 2.0 because I had not seen one anywhere, not even one this simple.

    @Tom's answer is a SavedModel example, but unfortunately it will not work on TensorFlow 2.0 because of breaking changes.

    @Vishnuvardhan Janapati's answer targets TF 2.0, but it does not use the SavedModel format.

  • 2020-11-21 12:20

    If the model was saved by your own code, just create a restorer for all variables:

    restorer = tf.train.Saver(tf.global_variables())
    

    and use it to restore the variables in the current session:

    restorer.restore(self._sess, model_file)
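    A self-contained sketch of the same save/restore round trip, written against tf.compat.v1 so it also runs under TensorFlow 2.x (graph mode is entered with tf.Graph().as_default() rather than by disabling eager execution globally). Note that tf.all_variables() was renamed tf.global_variables(). The variable name v and the temporary checkpoint path are placeholders.

```python
import os
import tempfile

import tensorflow as tf

tf1 = tf.compat.v1
ckpt_prefix = os.path.join(tempfile.mkdtemp(), "model.ckpt")  # placeholder path

# Program 1: build a graph, initialize, and save every global variable.
with tf.Graph().as_default():
    v = tf1.get_variable("v", initializer=tf.constant([1.0, 2.0]))
    saver = tf1.train.Saver(tf1.global_variables())
    with tf1.Session() as sess:
        sess.run(tf1.global_variables_initializer())
        saver.save(sess, ckpt_prefix)

# Program 2: rebuild variables with the same names, then restore them.
with tf.Graph().as_default():
    v = tf1.get_variable("v", shape=[2])
    restorer = tf1.train.Saver(tf1.global_variables())
    with tf1.Session() as sess:
        restorer.restore(sess, ckpt_prefix)  # restored variables need no initializer
        restored = sess.run(v)

print(restored)  # [1. 2.]
```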
    

    For an external model, you need to specify a mapping from its variable names to your variable names. You can view the model's variable names with this command:

    python /path/to/tensorflow/tensorflow/python/tools/inspect_checkpoint.py --file_name=/path/to/pretrained_model/model.ckpt
    

    The inspect_checkpoint.py script is in the tensorflow/python/tools folder of the TensorFlow source tree.

    To specify the mapping, you can use my Tensorflow-Worklab, a set of classes and scripts for training and retraining different models. It includes an example of retraining ResNet models.
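    Both steps can also be done directly from Python. A minimal sketch with tf.compat.v1 (runs under TF 2.x): tf.train.list_variables lists the names stored in a checkpoint, and passing a dict to Saver(var_list=...) maps checkpoint names to variables in your graph. The names pretrained/w and my_model/w are hypothetical stand-ins for the external model's and your model's variables.

```python
import os
import tempfile

import tensorflow as tf

tf1 = tf.compat.v1
ckpt_prefix = os.path.join(tempfile.mkdtemp(), "pretrained.ckpt")

# Stand-in for the external model: one variable saved as "pretrained/w".
with tf.Graph().as_default():
    with tf1.variable_scope("pretrained"):
        tf1.get_variable("w", initializer=tf.constant([3.0, 4.0]))
    with tf1.Session() as sess:
        sess.run(tf1.global_variables_initializer())
        tf1.train.Saver().save(sess, ckpt_prefix)

# List the (name, shape) pairs stored in the checkpoint.
names = [name for name, _shape in tf.train.list_variables(ckpt_prefix)]
print(names)

# Your graph names the variable differently: map checkpoint name -> your variable.
with tf.Graph().as_default():
    with tf1.variable_scope("my_model"):
        my_w = tf1.get_variable("w", shape=[2])
    restorer = tf1.train.Saver(var_list={"pretrained/w": my_w})
    with tf1.Session() as sess:
        restorer.restore(sess, ckpt_prefix)
        restored_w = sess.run(my_w)

print(restored_w)  # [3. 4.]
```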

  • 2020-11-21 12:20

    All the answers here are great, but I want to add two things.

    First, to elaborate on @user7505159's answer, it can be important to prepend "./" to the file name you are restoring from.

    For example, you can save a graph with no "./" in the file name like so:

    # Some graph defined up here with specific names
    
    saver = tf.train.Saver()
    save_file = 'model.ckpt'
    
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver.save(sess, save_file)
    

    But in order to restore the graph, you may need to prepend a "./" to the file_name:

    # Same graph defined up here
    
    saver = tf.train.Saver()
    save_file = './' + 'model.ckpt' # String addition used for emphasis
    
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver.restore(sess, save_file)
    

    You will not always need the "./", but it can cause problems depending on your environment and version of TensorFlow.
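    If a bare file name gives you trouble, one way to avoid depending on how it is resolved is to hand the Saver an explicit absolute path. A minimal standard-library sketch; checkpoint_path is a hypothetical helper, and it assumes the checkpoint lives in the current working directory unless told otherwise.

```python
import os


def checkpoint_path(file_name, directory=None):
    """Hypothetical helper: turn a bare checkpoint name into an absolute path.

    Defaults to the current working directory, which is what "./" refers to.
    """
    directory = directory if directory is not None else os.getcwd()
    return os.path.abspath(os.path.join(directory, file_name))


# Both spellings resolve to the same absolute path.
print(checkpoint_path("model.ckpt") == checkpoint_path("./model.ckpt"))  # True
print(checkpoint_path("model.ckpt", "/tmp/run1"))  # /tmp/run1/model.ckpt on POSIX
```

    You would then pass checkpoint_path('model.ckpt') to both saver.save and saver.restore.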

    Second, running sess.run(tf.global_variables_initializer()) before restoring can be important.

    If you are receiving an error regarding uninitialized variables when trying to restore a saved session, make sure you include sess.run(tf.global_variables_initializer()) before the saver.restore(sess, save_file) line. It can save you a headache.
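    The order matters: the initializer has to run before saver.restore, not after it, because running it afterwards resets the restored variables to their initial values. A tf.compat.v1 sketch (runs under TF 2.x) with hypothetical variables w, which is in the checkpoint, and extra, which is not and therefore still needs the initializer:

```python
import os
import tempfile

import tensorflow as tf

tf1 = tf.compat.v1
ckpt_prefix = os.path.join(tempfile.mkdtemp(), "model.ckpt")

# Save a "trained" value of 42.0 for w.
with tf.Graph().as_default():
    w = tf1.get_variable("w", initializer=tf.constant(5.0))
    saver = tf1.train.Saver()
    with tf1.Session() as sess:
        sess.run(tf1.global_variables_initializer())
        sess.run(w.assign(42.0))
        saver.save(sess, ckpt_prefix)

with tf.Graph().as_default():
    w = tf1.get_variable("w", initializer=tf.constant(5.0))
    extra = tf1.get_variable("extra", initializer=tf.constant(1.0))  # not in the checkpoint
    saver = tf1.train.Saver(var_list=[w])  # restore only what the checkpoint contains
    with tf1.Session() as sess:
        sess.run(tf1.global_variables_initializer())  # 1) initialize everything
        saver.restore(sess, ckpt_prefix)              # 2) overwrite w with the saved value
        good = sess.run([w, extra])                   # [42.0, 1.0]

        sess.run(tf1.global_variables_initializer())  # wrong: re-initializing after restore
        clobbered = sess.run(w)                       # back to 5.0
```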

  • 2020-11-21 12:23

    Since TensorFlow 0.11.0RC1, you can save and restore your model directly by calling tf.train.export_meta_graph and tf.train.import_meta_graph, according to https://www.tensorflow.org/programmers_guide/meta_graph.

    Save the model

    w1 = tf.Variable(tf.truncated_normal(shape=[10]), name='w1')
    w2 = tf.Variable(tf.truncated_normal(shape=[20]), name='w2')
    tf.add_to_collection('vars', w1)
    tf.add_to_collection('vars', w2)
    saver = tf.train.Saver()
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    saver.save(sess, 'my-model')
    # The `save` method calls `export_meta_graph` implicitly,
    # so this also writes the graph file my-model.meta
    

    Restore the model

    sess = tf.Session()
    new_saver = tf.train.import_meta_graph('my-model.meta')
    new_saver.restore(sess, tf.train.latest_checkpoint('./'))
    all_vars = tf.get_collection('vars')
    for v in all_vars:
        v_ = sess.run(v)
        print(v_)
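    The same collection mechanism works for ops as well as variables, which is handy when the restoring program should run a prediction without rebuilding the model in Python. A sketch using tf.compat.v1 (runs under TF 2.x); the collection name predict and the tensors x, w and y are hypothetical.

```python
import os
import tempfile

import tensorflow as tf

tf1 = tf.compat.v1
ckpt_prefix = os.path.join(tempfile.mkdtemp(), "my-model")

# Save: register the output tensor in a collection before saving.
with tf.Graph().as_default():
    x = tf1.placeholder(tf.float32, shape=[None], name="x")
    w = tf1.get_variable("w", initializer=tf.constant(3.0))
    y = tf.multiply(x, w, name="y")
    tf1.add_to_collection("predict", y)
    with tf1.Session() as sess:
        sess.run(tf1.global_variables_initializer())
        tf1.train.Saver().save(sess, ckpt_prefix)  # also writes my-model.meta

# Restore in a fresh graph: no model-building code, only the .meta file.
with tf.Graph().as_default():
    with tf1.Session() as sess:
        new_saver = tf1.train.import_meta_graph(ckpt_prefix + ".meta")
        new_saver.restore(sess, ckpt_prefix)
        y = tf1.get_collection("predict")[0]
        x = sess.graph.get_tensor_by_name("x:0")
        out = sess.run(y, feed_dict={x: [1.0, 2.0]})

print(out)  # [3. 6.]
```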
    
  • 2020-11-21 12:23

    You can also do it this simpler way.

    Step 1: define all your variables

    W1 = tf.Variable(tf.truncated_normal([6, 6, 1, K], stddev=0.1), name="W1")
    B1 = tf.Variable(tf.constant(0.1, tf.float32, [K]), name="B1")
    
    # Similarly define W2, B2, W3, ...
    

    Step 2: create a Saver and use it to save the session

    model_saver = tf.train.Saver()
    
    # Train the model and save it in the end
    model_saver.save(session, "saved_models/CNN_New.ckpt")
    

    Step 3: restore the model

    with tf.Session(graph=graph_cnn) as session:
        model_saver.restore(session, "saved_models/CNN_New.ckpt")
        print("Model restored.") 
        print('Initialized')
    

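    Step 4: check your variable. This must run inside the same with block as step 3, because the session is closed once the block exits:

        W1_value = session.run(W1)
        print(W1_value)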
    

    When running in a different Python process, use:

    with tf.Session() as sess:
        # Rebuild the graph from the .meta file written by Saver.save
        saver = tf.train.import_meta_graph('saved_models/CNN_New.ckpt.meta')
    
        # Restore the latest checkpoint; do NOT run tf.global_variables_initializer()
        # afterwards, because that would overwrite the restored values
        saver.restore(sess, tf.train.latest_checkpoint('saved_models/'))
    
        # Get default graph (supply your custom graph if you have one)
        graph = tf.get_default_graph()
    
        # It will give a tensor object
        W1 = graph.get_tensor_by_name('W1:0')
    
        # To get the value (numpy array)
        W1_value = sess.run(W1)
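    In TensorFlow 2.x, where sessions and tf.train.Saver are no longer part of the main API, tf.train.Checkpoint plays the same role. A minimal sketch, assuming TF 2.x eager execution; the checkpoint key W1, the shape and the temporary directory are placeholders.

```python
import tempfile

import numpy as np
import tensorflow as tf

ckpt_dir = tempfile.mkdtemp()  # placeholder for e.g. "saved_models/"

# Save: track the variable under the key "W1" and write a numbered checkpoint.
W1 = tf.Variable(tf.random.truncated_normal([6, 6, 1, 4], stddev=0.1), name="W1")
manager = tf.train.CheckpointManager(tf.train.Checkpoint(W1=W1), ckpt_dir, max_to_keep=3)
manager.save()

# Restore (e.g. in another process): create a variable of the same shape,
# track it under the same key, and restore from the latest checkpoint.
W1_restored = tf.Variable(tf.zeros([6, 6, 1, 4]), name="W1")
tf.train.Checkpoint(W1=W1_restored).restore(tf.train.latest_checkpoint(ckpt_dir))

print(np.array_equal(W1.numpy(), W1_restored.numpy()))  # True
```

    Matching happens by the keyword key passed to tf.train.Checkpoint, not by the variable's name attribute, so renaming variables in code does not break restoring.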
    
  • 2020-11-21 12:23

    As Yaroslav said, you can hack restoring from a graph_def and checkpoint by importing the graph, manually creating variables, and then using a Saver.

    I implemented this for my personal use, so I thought I'd share the code here.

    Link: https://gist.github.com/nikitakit/6ef3b72be67b86cb7868

    (This is, of course, a hack, and there is no guarantee that models saved this way will remain readable in future versions of TensorFlow.)
