Saving a TF2 Keras model with custom signature defs

时光取名叫无心 2021-01-11 21:41

I have a Keras (Sequential) model that I could save with custom signature defs in TensorFlow 1.13 as follows:

from tensorflow.saved_model.utils import buil         


        
1 Answer
  •  走了就别回头了
    2021-01-11 22:10

    The solution is to wrap the model in a tf.Module that has one tf.function per signature definition:

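    import tensorflow as tf

    class MyModule(tf.Module):
      """Wraps a Keras model; each tf.function below becomes one signature def."""

      def __init__(self, model, other_variable):
        super().__init__()
        self.model = model  # tracked by tf.Module, so the model's weights are saved
        self._other_variable = other_variable

      # Signature "score": a float32 batch of shape (batch, time, 1).
      @tf.function(input_signature=[tf.TensorSpec(shape=(None, None, 1), dtype=tf.float32)])
      def score(self, waveform):
        result = self.model(waveform)
        return {"scores": result}  # dict keys become the signature's output names

      # Signature "metadata": takes no inputs and returns a constant.
      @tf.function(input_signature=[])
      def metadata(self):
        return {"other_variable": self._other_variable}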
    

    And then save the module (not the model):

    module = MyModule(model, 1234)
    # export_path: the directory the SavedModel is written to
    tf.saved_model.save(module, export_path,
                        signatures={"score": module.score, "metadata": module.metadata})
    

    Tested with a Keras model on TF2.
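    The sketch below checks the export with a minimal round trip. It saves the module as shown above, reloads it with tf.saved_model.load, and calls each signature by its key. It rests on two assumptions that are not from the answer: a hypothetical toy Sequential model (one Dense(3) layer) and a temporary directory standing in for export_path. Loaded signature functions accept keyword arguments only. Those arguments are named after the Python parameters (here waveform), and each function returns a dict keyed as in the corresponding return statement.

```python
import tempfile

import numpy as np
import tensorflow as tf


class MyModule(tf.Module):
    """Wraps a Keras model; each tf.function becomes one signature def."""

    def __init__(self, model, other_variable):
        super().__init__()
        self.model = model
        self._other_variable = other_variable

    @tf.function(input_signature=[tf.TensorSpec(shape=(None, None, 1), dtype=tf.float32)])
    def score(self, waveform):
        return {"scores": self.model(waveform)}

    @tf.function(input_signature=[])
    def metadata(self):
        return {"other_variable": self._other_variable}


# Hypothetical toy model: Dense(3) applied per timestep to (batch, time, 1) input.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(None, 1)),
    tf.keras.layers.Dense(3),
])

module = MyModule(model, 1234)
export_path = tempfile.mkdtemp()  # assumption: any empty, writable directory
tf.saved_model.save(
    module,
    export_path,
    signatures={"score": module.score, "metadata": module.metadata},
)

# Reload and call the signatures by key, using keyword arguments.
loaded = tf.saved_model.load(export_path)
print(sorted(loaded.signatures.keys()))  # ['metadata', 'score']

waveform = tf.constant(np.zeros((2, 16, 1), dtype=np.float32))
scores = loaded.signatures["score"](waveform=waveform)["scores"]
print(tuple(scores.shape))  # (2, 16, 3)

meta = loaded.signatures["metadata"]()["other_variable"]
print(int(meta.numpy()))  # 1234
```

    The same signature keys and their input/output specs can also be inspected from the shell with `saved_model_cli show --dir <export_path> --all`.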
