How to load and use a saved model in TensorFlow?

无人及你 2021-01-31 19:22

I have found two ways to save a model in TensorFlow: tf.train.Saver() and SavedModelBuilder. However, I can't find documentation on using the model …

4 answers
  •  醉酒成梦
    2021-01-31 19:54

    Here's a code snippet that saves a model with simple_save and then restores it to run predictions.

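    # Save the model (TF 1.x). X_train, Y_train and isTraining are the graph's
    # input tensors (placeholders); predClass is the prediction tensor.
    tf.saved_model.simple_save(sess, export_dir=saveModelPath,
                               inputs={"inputImageBatch": X_train,
                                       "inputClassBatch": Y_train,
                                       "isTrainingBool": isTraining},
                               outputs={"predictedClassBatch": predClass})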
    

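    Note that simple_save fills in some defaults for you: it tags the saved MetaGraph with tag_constants.SERVING and stores a single prediction signature under signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY ("serving_default"). The restore code relies on both. You can see these defaults in the source: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/simple_save.py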
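Since the question also mentions SavedModelBuilder, here is a minimal sketch of roughly what simple_save does, written out with an explicit builder, followed by a reload using the same tag and signature key as the restore code in this answer. Assumptions: TensorFlow 1.15 or 2.x is installed and the TF 1.x API is reached through `tensorflow.compat.v1`; the helper `save_like_simple_save` and the toy model (`x`, `w`, `y`) are hypothetical and only for illustration. This is not a verbatim copy of simple_save's source.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # graph mode, matching the TF 1.x code in this answer

SERVING = tf.saved_model.tag_constants.SERVING
DEFAULT_KEY = tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY


def save_like_simple_save(sess, export_dir, inputs, outputs):
    """Hypothetical helper approximating simple_save with an explicit builder:
    one MetaGraph tagged SERVING, one predict signature under the default key."""
    signature = tf.saved_model.signature_def_utils.predict_signature_def(
        inputs=inputs, outputs=outputs)
    builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
    builder.add_meta_graph_and_variables(
        sess,
        tags=[SERVING],
        signature_def_map={DEFAULT_KEY: signature},
        clear_devices=True)
    builder.save()


# Hypothetical toy model: y = x * w, with a single variable w = 3.0.
export_dir = os.path.join(tempfile.mkdtemp(), "model")  # builder needs a dir that does not exist yet
with tf.Graph().as_default() as g, tf.Session(graph=g) as sess:
    x = tf.placeholder(tf.float32, shape=[None], name="x")
    w = tf.Variable(3.0, name="w")
    y = tf.multiply(x, w, name="y")
    sess.run(tf.global_variables_initializer())
    save_like_simple_save(sess, export_dir, inputs={"x": x}, outputs={"y": y})

# Reload with the SERVING tag and look up the default serving signature.
with tf.Graph().as_default() as g, tf.Session(graph=g) as sess:
    meta_graph = tf.saved_model.loader.load(sess, [SERVING], export_dir)
    signature = meta_graph.signature_def[DEFAULT_KEY]
    # Session.run accepts tensor names such as "y:0" for fetches and feed keys.
    prediction = sess.run(signature.outputs["y"].name,
                          feed_dict={signature.inputs["x"].name: [1.0, 2.0]})
    print(prediction)
```

Writing the builder call out makes the two defaults visible: if you save with a different tag or signature key, the loader has to be given that same tag and key.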

    Now, to restore the model and look up its tensors through the inputs/outputs dicts:

    import tensorflow as tf  # TF 1.x API; on TF 2.x use tensorflow.compat.v1
    from tensorflow.python.saved_model import tag_constants
    from tensorflow.python.saved_model import signature_constants
    
    with tf.Session(graph=tf.Graph()) as sess:
      # simple_save tagged the MetaGraph with SERVING, so load it with that tag.
      model = tf.saved_model.loader.load(export_dir=saveModelPath, sess=sess, tags=[tag_constants.SERVING])
      # simple_save stored its only signature under the default serving key.
      signature = model.signature_def[signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
    
      # Map each dict key used at save time back to a tensor in the loaded graph.
      inputImage = sess.graph.get_tensor_by_name(signature.inputs['inputImageBatch'].name)
      inputLabel = sess.graph.get_tensor_by_name(signature.inputs['inputClassBatch'].name)
      isTraining = sess.graph.get_tensor_by_name(signature.inputs['isTrainingBool'].name)
      outputPrediction = sess.graph.get_tensor_by_name(signature.outputs['predictedClassBatch'].name)
    
      # Labels are not needed for prediction, so only the images and the flag are fed.
      # sampleImages is your own batch of input images.
      outPred = sess.run(outputPrediction, feed_dict={inputImage: sampleImages, isTraining: False})
    
      print("predicted classes:", outPred)
    

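    Note: the signature_def has to be looked up under the default serving key because that is where simple_save stored it. The signature then maps each key from the inputs/outputs dicts (e.g. 'inputImageBatch') to the name of the corresponding tensor in the graph, which is what get_tensor_by_name needs.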
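Because the signature already records every input and output, the repeated lookups can be folded into one generic function. The sketch below is hypothetical (the name `load_signature` and the toy model are illustrative, not part of TensorFlow or the original answer) and assumes TensorFlow 1.15 or 2.x accessed through `tensorflow.compat.v1`. It uses `tf.saved_model.utils.get_tensor_from_tensor_info` to turn each signature entry into a tensor.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # graph mode, as in the answer's TF 1.x code

SERVING = tf.saved_model.tag_constants.SERVING
DEFAULT_KEY = tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY


def load_signature(sess, export_dir, signature_key=DEFAULT_KEY):
    """Hypothetical helper: load a SavedModel into `sess` and resolve every
    input/output of one signature to a tensor, keyed by the save-time names."""
    meta_graph = tf.saved_model.loader.load(sess, [SERVING], export_dir)
    signature = meta_graph.signature_def[signature_key]
    inputs = {key: tf.saved_model.utils.get_tensor_from_tensor_info(info, graph=sess.graph)
              for key, info in signature.inputs.items()}
    outputs = {key: tf.saved_model.utils.get_tensor_from_tensor_info(info, graph=sess.graph)
               for key, info in signature.outputs.items()}
    return inputs, outputs


# Export a hypothetical toy model (y = x * w, w = 3.0) with simple_save.
export_dir = os.path.join(tempfile.mkdtemp(), "model")
with tf.Graph().as_default() as g, tf.Session(graph=g) as sess:
    x = tf.placeholder(tf.float32, shape=[None], name="x")
    w = tf.Variable(3.0, name="w")
    y = tf.multiply(x, w, name="y")
    sess.run(tf.global_variables_initializer())
    tf.saved_model.simple_save(sess, export_dir, inputs={"x": x}, outputs={"y": y})

# Reload and run, feeding and fetching by signature key instead of tensor name.
with tf.Graph().as_default() as g, tf.Session(graph=g) as sess:
    inputs, outputs = load_signature(sess, export_dir)
    result = sess.run(outputs["y"], feed_dict={inputs["x"]: [1.0, 2.0]})
    print(result)
```

With the answer's model, `inputs["inputImageBatch"]` and `outputs["predictedClassBatch"]` would play the roles of `inputImage` and `outputPrediction`, so new inputs or outputs need no extra lookup code.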
