Tensor is not an element of this graph; deploying Keras model

佛祖请我去吃肉 2020-12-28 15:30

I'm deploying a Keras model and sending test data to it through a Flask API. I have two files:

First, my Flask app:

# Let's start up the Flask

6 Answers
  •  别那么骄傲
    2020-12-28 16:29

    It's much simpler to wrap your Keras model in a class that keeps track of its own graph and session. This prevents the problems that multiple threads, processes, or models can cause, which is almost certainly what is causing your error. Other solutions will work, but this one is by far the most general, scalable, and catch-all:

    import os
    import logging

    import tensorflow as tf  # TensorFlow 1.x API (tf.Session, tf.get_default_graph)
    from keras.models import model_from_json

    logger = logging.getLogger(__name__)


    class NeuralNetwork:
        def __init__(self):
            # Capture the graph and a session bound to it, so every call can re-enter both
            self.graph = tf.get_default_graph()
            self.session = tf.Session(graph=self.graph)
            # the folder in which the model and weights are stored
            self.model_folder = os.path.join(os.path.abspath("src"), "static")
            self.model = None
            # For some reason, in a Flask app the graph/session must be entered once here,
            # otherwise calls from other threads hang
            with self.graph.as_default():
                with self.session.as_default():
                    logger.info("neural network initialised")

        def load(self, file_name=None):
            """
            :param file_name: [model_file_name, weights_file_name]
            :return: True on success, False otherwise
            """
            if file_name is None:
                logger.error("load() expects [model_file_name, weights_file_name]")
                return False
            with self.graph.as_default():
                with self.session.as_default():
                    try:
                        model_name, weights_name = file_name

                        if model_name is not None:
                            # load the model architecture from JSON
                            json_file_path = os.path.join(self.model_folder, model_name)
                            with open(json_file_path, "r") as json_file:
                                self.model = model_from_json(json_file.read())
                        if weights_name is not None:
                            if self.model is None:
                                raise RuntimeError("load the model before its weights")
                            # load the weights into the model
                            weights_path = os.path.join(self.model_folder, weights_name)
                            self.model.load_weights(weights_path)
                        logger.info("Neural Network loaded:")
                        logger.info("\tNeural Network model: %s", model_name)
                        logger.info("\tNeural Network weights: %s", weights_name)
                        return True
                    except Exception:
                        logger.exception("failed to load the neural network")
                        return False

        def predict(self, x):
            if self.model is None:
                raise RuntimeError("call load() before predict()")
            # Re-enter the graph and session the model was loaded into,
            # whichever thread (e.g. a Flask request thread) is calling
            with self.graph.as_default():
                with self.session.as_default():
                    return self.model.predict(x)
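
To see why re-entering the captured graph on every call matters, here is a dependency-free toy model of the failure. It is a sketch under a stated simplification: as the TensorFlow 1.x `Graph.as_default` documentation notes, the default graph is a property of the current thread, so a new thread that never enters a graph falls back to the global one. `ToyGraph`, `run`, `ToyModel`, and `call_in_thread` are hypothetical names, plain strings stand in for tensors, and the toy model is built in its own graph (not the global one) so that the mismatch becomes visible.

```python
import threading
from contextlib import contextmanager

# Toy stand-in for TF 1.x's default-graph stack: one stack per thread.
_local = threading.local()


def _stack():
    # Each thread lazily gets its own, initially empty, stack.
    if not hasattr(_local, "stack"):
        _local.stack = []
    return _local.stack


class ToyGraph:
    """Hypothetical graph: just a name plus the set of 'tensors' it owns."""

    def __init__(self, name):
        self.name = name
        self.tensors = set()

    @contextmanager
    def as_default(self):
        # Push this graph for the current thread only, pop it on exit.
        _stack().append(self)
        try:
            yield self
        finally:
            _stack().pop()


# Fallback used by any thread that has not pushed a graph of its own.
GLOBAL_GRAPH = ToyGraph("global")


def get_default_graph():
    stack = _stack()
    return stack[-1] if stack else GLOBAL_GRAPH


def run(tensor):
    """Evaluate a 'tensor' against whatever graph is default in this thread."""
    graph = get_default_graph()
    if tensor not in graph.tensors:
        raise ValueError(f"Tensor {tensor!r} is not an element of this graph")
    return f"{tensor} evaluated in {graph.name}"


class ToyModel:
    def __init__(self):
        # Build the model in its own graph and keep a reference to it.
        self.graph = ToyGraph("model")
        with self.graph.as_default():
            get_default_graph().tensors.add("output")

    def predict_naive(self):
        # Relies on the calling thread already having the right default graph.
        return run("output")

    def predict(self):
        # Re-enters the captured graph on every call, from any thread.
        with self.graph.as_default():
            return run("output")


def call_in_thread(fn):
    """Run fn on a fresh thread (like a request worker); return its value or error."""
    result = {}

    def target():
        try:
            result["value"] = fn()
        except ValueError as exc:
            result["error"] = str(exc)

    worker = threading.Thread(target=target)
    worker.start()
    worker.join()
    return result


model = ToyModel()
print(call_in_thread(model.predict_naive))
# {'error': "Tensor 'output' is not an element of this graph"}
print(call_in_thread(model.predict))
# {'value': 'output evaluated in model'}
```

The naive call fails on a fresh thread because that thread falls back to the global graph; the wrapped call succeeds because it pushes the model's graph first, which is the same idea as the `with self.graph.as_default()` blocks in the `NeuralNetwork` class.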
    
