Tensorflow: how to save/restore a model?

迷失自我 2020-11-21 11:37

After you train a model in TensorFlow:

  1. How do you save the trained model?
  2. How do you later restore this saved model?
26 answers
  •  闹比i (OP)
    2020-11-21 12:19

    Here is a simple example using the TensorFlow 2.0 SavedModel format (the recommended format, according to the docs) for a basic MNIST classifier built with the Keras functional API, without anything too fancy going on:

    # Imports
    import numpy as np
    import tensorflow as tf
    from tensorflow.keras.layers import Input, Dense, Flatten
    from tensorflow.keras.models import Model
    import matplotlib.pyplot as plt
    
    # Load data
    mnist = tf.keras.datasets.mnist # 28 x 28
    (x_train,y_train), (x_test, y_test) = mnist.load_data()
    
    # Scale pixels from [0, 255] to [0, 1]
    # (tf.keras.utils.normalize would L2-normalize each row instead)
    x_train = x_train / 255.0
    x_test = x_test / 255.0
    
    # Create model
    inputs = Input(shape=(28, 28), dtype='float64', name='graph_input')
    x = Flatten()(inputs)
    x = Dense(128, activation='relu')(x)
    x = Dense(128, activation='relu')(x)
    outputs = Dense(10, activation='softmax', name='graph_output', dtype='float64')(x)
    model = Model(inputs=inputs, outputs=outputs)
    
    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    
    # Train
    model.fit(x_train, y_train, epochs=3)
    
    # Save model in SavedModel format (Tensorflow 2.0)
    export_path = 'model'
    tf.saved_model.save(model, export_path)
    
    # ... possibly another python program 
    
    # Reload model (tf.saved_model.load returns an object that exposes .signatures)
    loaded_model = tf.saved_model.load(export_path)
    
    # Get image sample for testing
    index = 0
    img = x_test[index] # I normalized the image on a previous step
    
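    # Predict using the signature definition (TensorFlow 2.0).
    # The signature expects a batch, so add a batch dimension: (28, 28) -> (1, 28, 28)
    predict = loaded_model.signatures["serving_default"]
    prediction = predict(graph_input=tf.expand_dims(img, axis=0))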
    
    # Show results
    print(np.argmax(prediction['graph_output']))  # prints the predicted class index
    plt.imshow(x_test[index], cmap=plt.cm.binary)  # draws the test image
    plt.show()
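
    If you want to try the save/restore round trip without downloading MNIST or going through the Keras loader, here is a minimal sketch that uses a plain tf.Module instead of a Keras model (assuming TensorFlow 2.x). TinyClassifier and save_and_restore are hypothetical names, and the seeded random weights stand in for trained ones. It exports an explicit serving_default signature, reloads it with tf.saved_model.load, and checks that the restored signature gives the same outputs as the original object:

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


class TinyClassifier(tf.Module):
    """Hypothetical stand-in for a trained model: one dense layer with softmax."""

    def __init__(self):
        super().__init__()
        # Seeded random weights play the role of trained parameters.
        rng = np.random.default_rng(0)
        self.w = tf.Variable(rng.normal(size=(784, 10)).astype(np.float32), name="w")
        self.b = tf.Variable(np.zeros(10, dtype=np.float32), name="b")

    @tf.function(input_signature=[tf.TensorSpec([None, 28, 28], tf.float32, name="graph_input")])
    def __call__(self, graph_input):
        flat = tf.reshape(graph_input, [-1, 784])
        # Returning a dict makes "graph_output" the signature's output key.
        return {"graph_output": tf.nn.softmax(tf.matmul(flat, self.w) + self.b)}


def save_and_restore(export_dir):
    """Save a TinyClassifier, reload it, and return (original, restored) predictions."""
    model = TinyClassifier()
    tf.saved_model.save(
        model,
        export_dir,
        signatures={"serving_default": model.__call__.get_concrete_function()},
    )

    # Restoring could just as well happen in a separate Python program.
    loaded = tf.saved_model.load(export_dir)
    predict = loaded.signatures["serving_default"]

    batch = tf.random.uniform([2, 28, 28], seed=1)
    original = model(batch)["graph_output"].numpy()
    restored = predict(graph_input=batch)["graph_output"].numpy()
    return original, restored


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        original, restored = save_and_restore(os.path.join(tmp, "tiny_model"))
        print(np.allclose(original, restored))  # True
```

    Naming the Python argument and the TensorSpec identically (graph_input) keeps the signature's input keyword predictable, and the returned dict key (graph_output) mirrors the output layer name used in the Keras example.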
    
    

    What is serving_default?

    It's the name of the signature def for the tag you selected (in this case, the default serve tag). You can look up the tags and signatures of a saved model with saved_model_cli.
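
    The signatures that saved_model_cli prints (for example with saved_model_cli show --dir model --all) can also be read from Python. The sketch below assumes TensorFlow 2.x; Doubler and describe_signatures are hypothetical names, and the toy module is exported only so there is something to inspect. For the MNIST model above you would expect a serving_default entry whose outputs include graph_output.

```python
import os
import tempfile

import tensorflow as tf


class Doubler(tf.Module):
    """Hypothetical toy module; it exists only so there is a SavedModel to inspect."""

    @tf.function(input_signature=[tf.TensorSpec([None], tf.float32, name="values")])
    def double(self, values):
        return {"doubled": values * 2.0}


def describe_signatures(export_dir):
    """Map each signature key to (input names, output names) for a SavedModel."""
    loaded = tf.saved_model.load(export_dir)
    summary = {}
    for key, fn in loaded.signatures.items():
        # structured_input_signature is an (args, kwargs) pair; signature
        # functions are called with keyword arguments, so the names are in kwargs.
        _, kwargs = fn.structured_input_signature
        summary[key] = (sorted(kwargs), sorted(fn.structured_outputs))
    return summary


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "doubler")
        module = Doubler()
        tf.saved_model.save(
            module, path, signatures={"serving_default": module.double.get_concrete_function()}
        )
        print(describe_signatures(path))
```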

    Disclaimers

    This is just a basic example to get you up and running; it is by no means a complete answer, and I may update it in the future. I wanted to give a simple example of using SavedModel in TF 2.0 because I haven't seen one anywhere, not even one this simple.

    @Tom's answer is a SavedModel example, but unfortunately it will not work on TensorFlow 2.0 because of some breaking changes.

    @Vishnuvardhan Janapati's answer targets TF 2.0, but it doesn't use the SavedModel format.
