Question
I started using TensorFlow recently and am trying to get used to tf.estimator.Estimator objects. I would like to do something that seems quite natural: after training my classifier, i.e. an instance of tf.estimator.Estimator (with the train method), I would like to save it to a file (whatever the extension) and reload it later to predict the labels for some new data. Since the official documentation recommends using the Estimator APIs, I would expect something this important to be implemented and documented.
I saw on another page that the method for this is export_savedmodel (see the official documentation), but I simply don't understand the documentation. There is no explanation of how to use this method. What is the argument serving_input_fn? I never encountered it in the Creating Custom Estimators tutorial or in any of the other tutorials I read. After some googling, I discovered that around a year ago estimators were defined using another class (tf.contrib.learn.Estimator), and it looks like tf.estimator.Estimator reuses some of the previous APIs. But I can't find a clear explanation of this in the documentation.
Could someone please give me a toy example, or explain how to define or find this serving_input_fn?
And then how would I load the trained classifier again?
Thank you for your help!
Edit: I discovered that one doesn't necessarily need to use export_savedmodel to save the model. It is actually done automatically. If we later define a new estimator with the same model_dir argument, it will also automatically restore the previous estimator, as explained here.
Answer 1:
As you figured out, the estimator automatically saves and restores the model for you during training. export_savedmodel is useful if you want to deploy your model to the field (for example, providing the best model to TensorFlow Serving).
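For the automatic save-and-restore path, here is a minimal sketch, assuming TensorFlow 1.x and a canned tf.estimator.DNNClassifier. The feature name "features", the input shape, the layer size, and the number of classes are illustrative assumptions, not taken from the question. The key point is that the second estimator is built with the same model_dir, so predict picks up the latest checkpoint written by train.

```python
def make_classifier(model_dir):
    """Build a classifier bound to model_dir (all hyperparameters are assumptions)."""
    import tensorflow as tf  # TF 1.x assumed; imported lazily on purpose

    feature_columns = [tf.feature_column.numeric_column("features", shape=[4])]
    return tf.estimator.DNNClassifier(
        feature_columns=feature_columns,
        hidden_units=[16],
        n_classes=3,
        model_dir=model_dir,  # checkpoints are written to and restored from here
    )


def predict_class_ids(model_dir, new_data):
    """Rebuild the estimator from model_dir and predict on a float32 NumPy array."""
    import tensorflow as tf  # TF 1.x assumed

    classifier = make_classifier(model_dir)  # same model_dir as during training
    input_fn = tf.estimator.inputs.numpy_input_fn(
        x={"features": new_data}, shuffle=False)
    # predict() restores the latest checkpoint found in model_dir.
    return [int(p["class_ids"][0]) for p in classifier.predict(input_fn=input_fn)]
```

The model definition has to be rebuilt identically in the new process, since the checkpoint only stores variable values. That requirement is what export_savedmodel removes.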
Here is a simple example of export_savedmodel:
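    import tensorflow as tf

    def serving_input_fn():
        # Placeholder that stands in for the training input pipeline:
        # a batch of 128x128 images with 3 channels.
        inputs = {'features': tf.placeholder(tf.float32, [None, 128, 128, 3])}
        return tf.estimator.export.ServingInputReceiver(inputs, inputs)

    # est is the trained estimator; define serving_input_fn before this call.
    est.export_savedmodel(export_dir_base=FLAGS.export_dir, serving_input_receiver_fn=serving_input_fn)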
Basically, serving_input_fn is responsible for replacing the dataset pipeline with a placeholder. In deployment, you feed data to this placeholder as the input to your model for inference or prediction.
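The question also asks how to load the exported model again. One possible sketch, assuming TensorFlow 1.x, where tf.contrib.predictor is available: export_savedmodel writes each export into a timestamped subdirectory of export_dir_base, so the helper below first picks the newest one. The names latest_export_dir and load_predictor are hypothetical helpers introduced here for illustration.

```python
import os


def latest_export_dir(export_dir_base):
    """Return the newest timestamped export directory under export_dir_base."""
    candidates = [
        name for name in os.listdir(export_dir_base)
        # Keep only purely numeric (timestamp) directory names; anything
        # else, such as a leftover temporary directory, is skipped.
        if name.isdigit() and os.path.isdir(os.path.join(export_dir_base, name))
    ]
    if not candidates:
        raise ValueError("no exported model found in %r" % export_dir_base)
    return os.path.join(export_dir_base, max(candidates, key=int))


def load_predictor(export_dir_base):
    """Load the newest SavedModel as a callable (TF 1.x only)."""
    from tensorflow.contrib import predictor  # lazy import; assumes TF 1.x

    return predictor.from_saved_model(latest_export_dir(export_dir_base))


# Usage sketch (not executed here):
#   predict_fn = load_predictor(FLAGS.export_dir)
#   outputs = predict_fn({'features': batch})  # batch: float32, [N, 128, 128, 3]
```

The input key 'features' matches the receiver tensor defined in serving_input_fn above. The keys of the returned dictionary depend on the export outputs of the model, so inspect them for your own estimator.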
Source: https://stackoverflow.com/questions/51330841/how-to-save-and-restore-a-tf-estimator-estimator-model-with-export-savedmodel