After you train a model in TensorFlow, you can save it to disk and reload it later, possibly from another program.
Here is a simple example using the TensorFlow 2.0 SavedModel format (the recommended format, according to the docs) for a simple MNIST classifier, built with the Keras functional API and nothing fancy:
# Imports
import numpy as np # needed for np.argmax below
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, Flatten
from tensorflow.keras.models import Model
# Load data
mnist = tf.keras.datasets.mnist # 28 x 28
(x_train,y_train), (x_test, y_test) = mnist.load_data()
# Scale pixels: tf.keras.utils.normalize L2-normalizes along axis 1, so values end up in [0, 1]
x_train = tf.keras.utils.normalize(x_train,axis=1)
x_test = tf.keras.utils.normalize(x_test,axis=1)
# Create model
inputs = Input(shape=(28,28), dtype='float64', name='graph_input') # named `inputs` to avoid shadowing the built-in input()
x = Flatten()(inputs)
x = Dense(128, activation='relu')(x)
x = Dense(128, activation='relu')(x)
output = Dense(10, activation='softmax', name='graph_output', dtype='float64')(x)
model = Model(inputs=inputs, outputs=output)
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
# Train
model.fit(x_train, y_train, epochs=3)
# Save model in SavedModel format (Tensorflow 2.0)
export_path = 'model'
tf.saved_model.save(model, export_path)
# ... possibly another python program
# Reload model (tf.saved_model.load returns an object that exposes .signatures)
loaded_model = tf.saved_model.load(export_path)
# Get image sample for testing
index = 0
img = x_test[index:index + 1] # Already normalized above; the slice keeps the batch dimension, shape (1, 28, 28)
# Predict using the signature definition (Tensorflow 2.0)
predict = loaded_model.signatures["serving_default"]
prediction = predict(tf.constant(img))
# Show results
print(np.argmax(prediction['graph_output'])) # prints the class number
plt.imshow(x_test[index], cmap=plt.cm.binary) # draws the image
plt.show() # opens the figure window when run as a script
What is serving_default?
It's the name of the signature def under the tag you selected (in this case, the default serve tag). You can find the tags and signatures of a model using saved_model_cli.
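As a rough illustration of where the serving_default name comes from, here is a minimal sketch that saves a tiny tf.Module, reloads it, and lists its signatures. The Doubler class, its x input, and its doubled output are made-up names for illustration, not part of the MNIST example above, and the sketch assumes TensorFlow 2.x is installed.

```python
import tempfile

import tensorflow as tf


class Doubler(tf.Module):
    """Hypothetical toy module, used only to have something to export."""

    @tf.function(input_signature=[tf.TensorSpec([None], tf.float32, name="x")])
    def __call__(self, x):
        # The dict key becomes the output name in the exported signature
        return {"doubled": 2.0 * x}


export_dir = tempfile.mkdtemp()
module = Doubler()

# Passing a single function as `signatures` stores it under "serving_default"
tf.saved_model.save(module, export_dir, signatures=module.__call__)

loaded = tf.saved_model.load(export_dir)

# List every signature name stored in the SavedModel
signature_names = list(loaded.signatures.keys())
print(signature_names)

# Call the signature with keyword arguments; the result is a dict of tensors
infer = loaded.signatures["serving_default"]
result = infer(x=tf.constant([1.0, 2.0]))
print(result["doubled"].numpy())
```

From the command line, `saved_model_cli show --dir model --all` prints the same tag and signature information for the model directory used in the example above.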
Disclaimers
This is just a basic example to get you up and running, and is by no means a complete answer; maybe I can update it in the future. I just wanted to give a simple example using SavedModel in TF 2.0 because I haven't seen one, even this simple, anywhere.
@Tom's answer is a SavedModel example, but it will not work on Tensorflow 2.0, because unfortunately there are some breaking changes.
@Vishnuvardhan Janapati's answer says TF 2.0, but it's not for SavedModel format.
If it is an internally saved model, you just specify a restorer for all variables as
restorer = tf.train.Saver(tf.global_variables()) # tf.all_variables() is the older, deprecated name
and use it to restore variables in a current session:
restorer.restore(self._sess, model_file)
For an external model, you need to specify the mapping from its variable names to your variable names. You can view the model's variable names using the command
python /path/to/tensorflow/tensorflow/python/tools/inspect_checkpoint.py --file_name=/path/to/pretrained_model/model.ckpt
The inspect_checkpoint.py script can be found in './tensorflow/python/tools' folder of the Tensorflow source.
To specify the mapping, you can use my Tensorflow-Worklab, which contains a set of classes and scripts to train and retrain different models. It includes an example of retraining ResNet models.
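For a sense of what such a mapping looks like, here is a minimal sketch in plain Python. It is not taken from Tensorflow-Worklab; the helper build_restore_map and all scope and variable names are hypothetical. It assumes the external checkpoint and your graph differ only in the top-level scope name. In TF 1.x, tf.train.Saver accepts a dict as var_list whose keys are checkpoint names and whose values are variables, so in practice you would replace the string values below with the corresponding variable objects.

```python
def build_restore_map(model_var_names, checkpoint_scope, model_scope):
    """Map checkpoint variable names to the names used in the current graph.

    Assumption: both sides share the same structure below the top-level scope.
    """
    restore_map = {}
    prefix = model_scope + "/"
    for name in model_var_names:
        if not name.startswith(prefix):
            continue  # not part of the sub-network being restored
        suffix = name[len(prefix):]
        restore_map[checkpoint_scope + "/" + suffix] = name
    return restore_map


# Hypothetical variable names from the current graph
model_var_names = [
    "my_net/conv1/weights",
    "my_net/conv1/biases",
    "my_net_head/logits/weights",  # new layer, absent from the checkpoint
]

name_map = build_restore_map(model_var_names, "resnet_v1_50", "my_net")
for checkpoint_name, model_name in sorted(name_map.items()):
    print(checkpoint_name, "->", model_name)
```

Variables outside the chosen scope (the new head in this sketch) are left out of the map, so they keep whatever values their initializers give them.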
All the answers here are great, but I want to add two things.
First, to elaborate on @user7505159's answer, the "./" can be important to add to the beginning of the file name that you are restoring.
For example, you can save a graph with no "./" in the file name like so:
# Some graph defined up here with specific names
saver = tf.train.Saver()
save_file = 'model.ckpt'
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.save(sess, save_file)
But in order to restore the graph, you may need to prepend a "./" to the file_name:
# Same graph defined up here
saver = tf.train.Saver()
save_file = './' + 'model.ckpt' # String addition used for emphasis
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.restore(sess, save_file)
You will not always need the "./", but it can cause problems depending on your environment and version of TensorFlow.
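If you would rather not remember when the prefix is needed, one option is to add it only when the path has no directory part. This is a small sketch of that idea; the helper name as_explicit_relative is made up for illustration and nothing in TensorFlow requires it.

```python
import os


def as_explicit_relative(path):
    """Prefix a bare file name with './'; leave other paths untouched."""
    if os.path.isabs(path) or os.path.dirname(path):
        return path  # already has a directory component
    return "./" + path


save_file = as_explicit_relative("model.ckpt")
print(save_file)
```

The returned string can then be passed to saver.restore(sess, save_file) as in the snippet above.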
I also want to mention that calling sess.run(tf.global_variables_initializer()) before restoring the session can be important.
If you receive an error about uninitialized variables when trying to restore a saved session, make sure you call sess.run(tf.global_variables_initializer()) before the saver.restore(sess, save_file) line, never after it, since running the initializer after the restore would overwrite the restored values. It can save you a headache.
In (and after) TensorFlow version 0.11.0RC1, you can save and restore your model directly by calling tf.train.export_meta_graph and tf.train.import_meta_graph, according to https://www.tensorflow.org/programmers_guide/meta_graph.
w1 = tf.Variable(tf.truncated_normal(shape=[10]), name='w1')
w2 = tf.Variable(tf.truncated_normal(shape=[20]), name='w2')
tf.add_to_collection('vars', w1)
tf.add_to_collection('vars', w2)
saver = tf.train.Saver()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver.save(sess, 'my-model')
# The `save` method calls `export_meta_graph` implicitly,
# so you also get the saved graph file my-model.meta.

# Restore (possibly in another program)
sess = tf.Session()
new_saver = tf.train.import_meta_graph('my-model.meta')
new_saver.restore(sess, tf.train.latest_checkpoint('./'))
all_vars = tf.get_collection('vars')
for v in all_vars:
    v_ = sess.run(v)
    print(v_)
You can also take this easier way. First, give every variable a name:
W1 = tf.Variable(tf.truncated_normal([6, 6, 1, K], stddev=0.1), name="W1")
B1 = tf.Variable(tf.constant(0.1, tf.float32, [K]), name="B1")
# Similarly, W2, B2, W3, ...
Then create a Saver and save the session:
model_saver = tf.train.Saver()
# Train the model and save it in the end
model_saver.save(session, "saved_models/CNN_New.ckpt")
To restore the model:
with tf.Session(graph=graph_cnn) as session:
    model_saver.restore(session, "saved_models/CNN_New.ckpt")
    print("Model restored.")
    print('Initialized')
    # Evaluate the restored variable to get its value as a numpy array
    W1_value = session.run(W1)
    print(W1_value)
When running in a different Python instance, use
with tf.Session() as sess:
    # Initialize the variables first; running the initializer after the
    # restore would overwrite the restored values
    sess.run(tf.global_variables_initializer())
    # Restore latest checkpoint (assumes the graph and `saver` were re-created as above)
    saver.restore(sess, tf.train.latest_checkpoint('saved_models/'))
    # Get default graph (supply your custom graph if you have one)
    graph = tf.get_default_graph()
    # This gives the tensor object
    W1 = graph.get_tensor_by_name('W1:0')
    # To get the value (numpy array)
    W1_value = sess.run(W1)
As Yaroslav said, you can hack restoring from a graph_def and checkpoint by importing the graph, manually creating variables, and then using a Saver.
I implemented this for my personal use, so I thought I'd share the code here.
Link: https://gist.github.com/nikitakit/6ef3b72be67b86cb7868
(This is, of course, a hack, and there is no guarantee that models saved this way will remain readable in future versions of TensorFlow.)