I have found 2 ways to save a model in TensorFlow: tf.train.Saver() and SavedModelBuilder. However, I can't find documentation on using the mo
Here's a code snippet that saves a model with simple_save, then loads it back and runs a prediction.
# Save the model:
tf.saved_model.simple_save(
    sess,
    export_dir=saveModelPath,
    inputs={"inputImageBatch": X_train,
            "inputClassBatch": Y_train,
            "isTrainingBool": isTraining},
    outputs={"predictedClassBatch": predClass})
Note that simple_save applies certain defaults, such as the SERVING tag and the default serving signature key (this can be seen at: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/simple_save.py)
Now, to restore and use the inputs/outputs dict:
import tensorflow as tf
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import tag_constants

with tf.Session() as sess:
    # simple_save tags the model with SERVING by default.
    model = tf.saved_model.loader.load(
        export_dir=saveModelPath, sess=sess, tags=[tag_constants.SERVING])
    # The inputs/outputs dicts are stored under the default serving signature key.
    signature = model.signature_def[signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
    graph = tf.get_default_graph()

    # Look up each tensor by the name recorded under its dict key.
    inputImage = graph.get_tensor_by_name(signature.inputs['inputImageBatch'].name)
    inputLabel = graph.get_tensor_by_name(signature.inputs['inputClassBatch'].name)
    isTraining = graph.get_tensor_by_name(signature.inputs['isTrainingBool'].name)
    outputPrediction = graph.get_tensor_by_name(signature.outputs['predictedClassBatch'].name)

    # Only the image batch and the training flag are fed; inputLabel is not used here.
    outPred = sess.run(outputPrediction,
                       feed_dict={inputImage: sampleImages, isTraining: False})
    print("predicted classes:", outPred)
Note: the default signature_def is needed to map the keys given in the inputs and outputs dicts back to the actual tensor names in the restored graph.
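The lookups above follow the same pattern for every key, so they could be wrapped in a small helper. This is only a minimal sketch: the name `tensors_from_signature` is hypothetical (it is not part of TensorFlow), and it assumes that `signature_def` and its `inputs`/`outputs` maps can be iterated like dicts. The string "serving_default" is the value of signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY.

```python
def tensors_from_signature(meta_graph, graph, signature_key="serving_default"):
    """Map the dict keys used at save time to tensors in the restored graph.

    meta_graph: the object returned by tf.saved_model.loader.load(...)
    graph: the graph the model was loaded into, e.g. sess.graph
    signature_key: the key the signature was saved under
    (hypothetical helper, not a TensorFlow API)
    """
    signature = meta_graph.signature_def[signature_key]
    # Each entry records the real tensor name (e.g. "Placeholder:0") under .name
    inputs = {key: graph.get_tensor_by_name(info.name)
              for key, info in signature.inputs.items()}
    outputs = {key: graph.get_tensor_by_name(info.name)
               for key, info in signature.outputs.items()}
    return inputs, outputs


# Possible usage inside the `with tf.Session() as sess:` block, after loading `model`:
#   inputs, outputs = tensors_from_signature(model, sess.graph)
#   outPred = sess.run(outputs["predictedClassBatch"],
#                      feed_dict={inputs["inputImageBatch"]: sampleImages,
#                                 inputs["isTrainingBool"]: False})
```

The helper only reads from the objects it is given, so the session and graph handling stay exactly as in the snippet above.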