Question
I have exported a SavedModel and now I wish to load it back in and make a prediction. It was trained with the following features and labels:
F1 : FLOAT32
F2 : FLOAT32
F3 : FLOAT32
L1 : FLOAT32
So say I want to feed in the values 20.9, 1.8, 0.9 and get a single FLOAT32 prediction. How do I accomplish this? I have managed to load the model successfully, but I am not sure how to access it to make the prediction call.
with tf.Session(graph=tf.Graph()) as sess:
    tf.saved_model.loader.load(
        sess,
        [tf.saved_model.tag_constants.SERVING],
        "/job/export/Servo/1503723455"
    )
    # How can I predict from here?
    # I want to do something like prediction = model.predict([20.9, 1.8, 0.9])
This question is not a duplicate of the question posted here. This question focuses on a minimal example of performing inference on a SavedModel of any model class (not just tf.estimator) and on the syntax for specifying input and output node names.
Answer 1:
Assuming you want predictions in Python, SavedModelPredictor is probably the easiest way to load a SavedModel and get predictions. Suppose you save your model like so:
import tensorflow as tf

# Build the graph (build_graph stands for your own model-building code)
f1 = tf.placeholder(shape=[], dtype=tf.float32)
f2 = tf.placeholder(shape=[], dtype=tf.float32)
f3 = tf.placeholder(shape=[], dtype=tf.float32)
l1 = tf.placeholder(shape=[], dtype=tf.float32)
output = build_graph(f1, f2, f3, l1)

# Save the model (sess is the session that holds the trained variables)
inputs = {'F1': f1, 'F2': f2, 'F3': f3, 'L1': l1}
outputs = {'output': output}
tf.saved_model.simple_save(sess, export_dir, inputs, outputs)
(The inputs can be any shape and don't even have to be placeholders or root nodes in the graph.)
Then, in the Python program that will use the SavedModel, we can get predictions like so:
from tensorflow.contrib import predictor
predict_fn = predictor.from_saved_model(export_dir)
predictions = predict_fn(
{"F1": 1.0, "F2": 2.0, "F3": 3.0, "L1": 4.0})
print(predictions)
This answer shows how to get predictions in Java, C++, and Python (although that question focuses on Estimators, the answer applies regardless of how the SavedModel was created).
Answer 2:
For anyone who needs a working example of saving a trained canned model and serving it without TensorFlow Serving, I have documented one here: https://github.com/tettusud/tensorflow-examples/tree/master/estimators
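- Create a predictor from the exported model:
predict_fn = tf.contrib.predictor.from_saved_model(exported_model_path)
- Prepare the input:
example = tf.train.Example(features=tf.train.Features(feature={'x': tf.train.Feature(float_list=tf.train.FloatList(value=[6.4, 3.2, 4.5, 1.5]))}))
- Feed the serialized Example under the receiver key 'inputs' (the key used in receiver_tensors):
predictions = predict_fn({'inputs': [example.SerializeToString()]})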
Here, x is the name of the feature declared in the serving input receiver function at export time. For example:
import tensorflow as tf

feature_spec = {'x': tf.FixedLenFeature([4], tf.float32)}

def serving_input_receiver_fn():
    serialized_tf_example = tf.placeholder(dtype=tf.string,
                                           shape=[None],
                                           name='input_tensors')
    receiver_tensors = {'inputs': serialized_tf_example}
    features = tf.parse_example(serialized_tf_example, feature_spec)
    return tf.estimator.export.ServingInputReceiver(features, receiver_tensors)
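Two different names are in play here: the predictor is fed under the receiver key `inputs` (a batch of serialized Examples), while the model sees a features dict keyed by the feature name `x`. The pure-Python sketch below only imitates, for illustration, what `tf.parse_example` does with a `FixedLenFeature([4], tf.float32)` spec. Examples are stood in for by plain dicts, `parse_fixed_len_batch` is a hypothetical helper, and `FixedLenFeature` default values are not modeled.

```python
def parse_fixed_len_batch(examples, feature_spec):
    """Illustrative stand-in for tf.parse_example with FixedLenFeature specs.

    examples: list of plain dicts, feature name -> list of floats
        (standing in for decoded tf.train.Example protos).
    feature_spec: feature name -> required number of values.
    Returns feature name -> batch of rows, one row per example.
    Default values for missing features are NOT modeled here.
    """
    features = {}
    for name, length in feature_spec.items():
        rows = []
        for index, example in enumerate(examples):
            values = example.get(name)
            if values is None or len(values) != length:
                raise ValueError(
                    f"example {index}: feature {name!r} needs {length} values")
            rows.append([float(v) for v in values])
        features[name] = rows
    return features


# The predictor is fed under the receiver key 'inputs' ...
receiver_tensors = {"inputs": [{"x": [6.4, 3.2, 4.5, 1.5]}]}
# ... while the model_fn sees features keyed by the feature name 'x'.
features = parse_fixed_len_batch(receiver_tensors["inputs"], {"x": 4})
print(features)  # {'x': [[6.4, 3.2, 4.5, 1.5]]}
```

A length mismatch (say, three values for `x`) raises `ValueError`, roughly mirroring how a fixed-length spec rejects malformed input.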
Answer 3:
Once the graph is loaded, it is available in the current session and you can feed input data through it to obtain predictions. Each use case is different, but the addition to your code will look something like this:
with tf.Session(graph=tf.Graph()) as sess:
    tf.saved_model.loader.load(
        sess,
        [tf.saved_model.tag_constants.SERVING],
        "/job/export/Servo/1503723455"
    )

    prediction = sess.run(
        'prefix/predictions/Identity:0',
        feed_dict={
            'Placeholder:0': [20.9],
            'Placeholder_1:0': [1.8],
            'Placeholder_2:0': [0.9]
        }
    )

    print(prediction)
Here, you need to know the names of your prediction inputs. If you did not give them names in your serving_fn, they default to Placeholder, Placeholder_1, Placeholder_2, and so on, in the order the placeholders were created. The :0 suffix selects the first output tensor of that op.
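A minimal pure-Python sketch of that default naming, assuming the placeholders were created in a fresh graph, in feature order, outside any name scope, and that no other op already uses the name `Placeholder`. The helpers `default_placeholder_names` and `build_feed_dict` are hypothetical; they only reproduce the naming pattern so the `feed_dict` in the snippet above can be built programmatically.

```python
def default_placeholder_names(num_inputs, base="Placeholder"):
    """Tensor names TF 1.x gives unnamed placeholders created in order.

    The first op keeps the bare name; later ones get _1, _2, ... appended.
    The ':0' suffix selects the op's first (and only) output tensor.
    """
    names = []
    for i in range(num_inputs):
        op_name = base if i == 0 else f"{base}_{i}"
        names.append(f"{op_name}:0")
    return names


def build_feed_dict(values, base="Placeholder"):
    """Pair feature values (in placeholder creation order) with default names.

    Each value is wrapped in a list, i.e. a batch of one, as in the answer.
    """
    names = default_placeholder_names(len(values), base)
    return {name: [value] for name, value in zip(names, values)}


print(build_feed_dict([20.9, 1.8, 0.9]))
# {'Placeholder:0': [20.9], 'Placeholder_1:0': [1.8], 'Placeholder_2:0': [0.9]}
```

If the placeholders were created under a name scope, or in a different order, the real names will differ, which is why reading them from the model itself is safer.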
The first argument of sess.run (the fetches) is the name of the tensor that holds the prediction. This will vary based on your use case.
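Instead of guessing names, you can read them from the SavedModel's signature. In TF 1.x, `tf.saved_model.loader.load` returns the `MetaGraphDef`, and `meta_graph_def.signature_def[key].inputs` (the key is typically `'serving_default'`) maps each input key to a `TensorInfo` whose `.name` is the tensor name; `.outputs` does the same for the fetch names. `saved_model_cli show --dir <export_dir> --all` prints the same information. This sketch assumes you have already turned that into a plain dict of key to tensor name; `feed_from_signature` and the example mapping are hypothetical.

```python
# In TF 1.x (shown as comments, not executed here):
#   meta_graph_def = tf.saved_model.loader.load(
#       sess, [tf.saved_model.tag_constants.SERVING], export_dir)
#   sig = meta_graph_def.signature_def['serving_default']
#   signature_inputs = {k: v.name for k, v in sig.inputs.items()}


def feed_from_signature(signature_inputs, features):
    """Build a sess.run feed_dict from signature keys instead of raw names.

    signature_inputs: input key -> tensor name, e.g. read from a SignatureDef.
    features: input key -> a single value; each value is wrapped in a list
        (a batch of one), matching the feed_dict in the answer above.
    """
    missing = sorted(set(signature_inputs) - set(features))
    unknown = sorted(set(features) - set(signature_inputs))
    if missing or unknown:
        raise KeyError(f"missing inputs: {missing}, unknown inputs: {unknown}")
    return {signature_inputs[key]: [value] for key, value in features.items()}


# Hypothetical mapping; in practice read it from the SavedModel signature.
signature_inputs = {"F1": "Placeholder:0", "F2": "Placeholder_1:0", "F3": "Placeholder_2:0"}
feed = feed_from_signature(signature_inputs, {"F1": 20.9, "F2": 1.8, "F3": 0.9})
print(feed)
```

Keying the feed by signature names rather than positions means a reordering of placeholders at export time cannot silently swap features.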
Answer 4:
The constructor of tf.estimator.DNNClassifier has an argument called warm_start_from. You can pass it the SavedModel directory, and the estimator will warm-start its weights from that model.
Source: https://stackoverflow.com/questions/45900653/tensorflow-how-to-predict-from-a-savedmodel