How to save and restore a tf.estimator.Estimator model with export_savedmodel?

╄→尐↘猪︶ㄣ submitted on 2019-12-24 01:23:55

Question


I started using TensorFlow recently and I am trying to get used to tf.estimator.Estimator objects. I would like to do something that seems quite natural: after training my classifier, i.e. an instance of tf.estimator.Estimator (with the train method), I would like to save it to a file (whatever the extension) and reload it later to predict the labels for some new data. Since the official documentation recommends using the Estimator APIs, I would expect something this important to be implemented and documented.

I saw on another page that the method for this is export_savedmodel (see the official documentation), but I simply don't understand the documentation. There is no explanation of how to use this method. What is the serving_input_fn argument? I never encountered it in the Creating Custom Estimators tutorial or in any of the other tutorials I read. After some googling, I discovered that around a year ago estimators were defined using another class (tf.contrib.learn.Estimator), and it looks like tf.estimator.Estimator reuses some of the previous APIs. But I can't find a clear explanation of this in the documentation.

Could someone please give me a toy example? Or explain to me how to define or find this serving_input_fn?

And then how would I load the trained classifier again?

Thank you for your help!

Edit: I discovered that one doesn't necessarily need to use export_savedmodel to save the model. Saving is actually done automatically. If we later define a new estimator with the same model_dir argument, it will also automatically restore the previous estimator, as explained here.


Answer 1:


As you figured out, the estimator automatically saves and restores the model for you during training. export_savedmodel is useful if you want to deploy your model in the field (for example, to provide the best model to TensorFlow Serving).

Here is a simple example:

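    import tensorflow as tf

    def serving_input_fn():
        # Stand-in for the training input pipeline: a placeholder fed at serving time.
        inputs = {'features': tf.placeholder(tf.float32, [None, 128, 128, 3])}
        return tf.estimator.export.ServingInputReceiver(inputs, inputs)

    # est is the trained tf.estimator.Estimator; define serving_input_fn before this call.
    est.export_savedmodel(export_dir_base=FLAGS.export_dir,
                          serving_input_receiver_fn=serving_input_fn)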

Basically, serving_input_fn is responsible for replacing the dataset pipeline with a placeholder. At deployment time you feed data to this placeholder as the input to your model for inference or prediction.
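
For the "how would I load it again" part, here is a rough sketch rather than a definitive recipe. It relies on export_savedmodel writing each export into a timestamp-named subdirectory of export_dir_base, and it assumes TensorFlow 1.x, where tf.contrib.predictor.from_saved_model is available (tf.contrib was removed in TensorFlow 2). The helper names latest_export_dir and load_predictor are made up for this illustration.

```python
import os


def latest_export_dir(export_dir_base):
    """Return the newest timestamped export under export_dir_base, or None."""
    if not os.path.isdir(export_dir_base):
        return None
    # export_savedmodel names each export directory after a Unix timestamp,
    # so the numerically largest all-digit name is the most recent export.
    stamps = [
        name for name in os.listdir(export_dir_base)
        if name.isdigit() and os.path.isdir(os.path.join(export_dir_base, name))
    ]
    if not stamps:
        return None
    return os.path.join(export_dir_base, max(stamps, key=int))


def load_predictor(export_dir_base):
    """Load the newest export as a callable (assumption: TensorFlow 1.x)."""
    # Imported lazily so the directory helper above works without TensorFlow.
    from tensorflow.contrib import predictor

    export_dir = latest_export_dir(export_dir_base)
    if export_dir is None:
        raise IOError('no SavedModel export found in %r' % export_dir_base)
    return predictor.from_saved_model(export_dir)
```

With the serving_input_fn above, you would then call something like predict_fn = load_predictor(FLAGS.export_dir) followed by predict_fn({'features': batch}), where batch is a float32 array of shape [N, 128, 128, 3]. The result is a dict of output arrays whose keys depend on what your model_fn exports.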



Source: https://stackoverflow.com/questions/51330841/how-to-save-and-restore-a-tf-estimator-estimator-model-with-export-savedmodel
