Can't import frozen graph with BatchNorm layer

上瘾入骨i 2021-02-20 13:41

I have trained a Keras model based on this repo.

After training, I save the model as checkpoint files like this:

    sess = tf.keras.backend.get_session()


        
3 answers
  •  情书的邮戳
    2021-02-20 13:44

    I just resolved the same issue. Combining a few answers (1, 2, 3), I realized that the issue originates from the BatchNorm layer's working state: training or inference. To resolve it, you just need to add one line before loading (or building) your model:

    keras.backend.set_learning_phase(0)
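    For intuition, here is a minimal plain-Python sketch of the two BatchNorm modes. It is an illustration, not Keras's actual implementation: the function name `batch_norm` and its list-of-scalars interface are hypothetical, while the defaults `momentum=0.99` and `eps=1e-3` match Keras's `BatchNormalization`. In training mode the layer normalizes with the current batch's statistics and updates its moving averages; in inference mode it only reads the stored moving statistics. Those moving-average updates are variable assignments in a TF graph, which is likely why a graph built in training mode breaks once its variables are frozen into constants.

```python
import math


def batch_norm(xs, moving_mean, moving_var, gamma=1.0, beta=0.0,
               eps=1e-3, momentum=0.99, training=False):
    """Hypothetical BatchNorm over a list of scalar activations.

    Returns (outputs, new_moving_mean, new_moving_var).
    """
    if training:
        # Training mode: normalize with the statistics of this batch...
        mean = sum(xs) / len(xs)
        var = sum((x - mean) ** 2 for x in xs) / len(xs)
        # ...and update the moving averages. In a TF graph these updates are
        # variable-assignment ops, which do not survive freezing cleanly.
        moving_mean = momentum * moving_mean + (1 - momentum) * mean
        moving_var = momentum * moving_var + (1 - momentum) * var
    else:
        # Inference mode: reuse the stored moving statistics, update nothing.
        mean, var = moving_mean, moving_var
    outs = [gamma * (x - mean) / math.sqrt(var + eps) + beta for x in xs]
    return outs, moving_mean, moving_var


# Same inputs, two modes.
print(batch_norm([1.0, 2.0, 3.0], 0.0, 1.0, training=True))
print(batch_norm([1.0, 2.0, 3.0], 0.0, 1.0, training=False))
```

    With `training=True` the outputs depend on the batch itself and the moving statistics change; with `training=False` the moving statistics are returned unchanged and nothing needs to be written back, which is what a frozen inference graph can represent.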
    

    A complete example that exports the model:

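    # Written against the TF 1.x API (graph mode, tf.Session)
    import tensorflow as tf
    from tensorflow.python.framework import graph_io
    from tensorflow.keras.applications.inception_v3 import InceptionV3
    
    
    def freeze_graph(graph, session, output):
        with graph.as_default():
            # Prune nodes only useful during training (e.g. Identity, CheckNumerics)
            graphdef_inf = tf.graph_util.remove_training_nodes(graph.as_graph_def())
            # Replace every variable with a constant holding its current value
            graphdef_frozen = tf.graph_util.convert_variables_to_constants(session, graphdef_inf, output)
            graph_io.write_graph(graphdef_frozen, ".", "frozen_model.pb", as_text=False)
    
    # This line is the most important one: it must run BEFORE the model is built,
    # so that BatchNorm is constructed in inference mode (learning phase 0)
    tf.keras.backend.set_learning_phase(0)
    
    base_model = InceptionV3()
    
    session = tf.keras.backend.get_session()
    
    # The loading script needs these node names (e.g. "input_1", "predictions/Softmax")
    INPUT_NODE = base_model.inputs[0].op.name
    OUTPUT_NODE = base_model.outputs[0].op.name
    print(INPUT_NODE, OUTPUT_NODE)
    freeze_graph(session.graph, session, [out.op.name for out in base_model.outputs])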
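    To double-check an export, one option is to scan the frozen GraphDef for nodes that usually appear only when the learning phase was not fixed: the `keras_learning_phase` placeholder, `AssignMovingAvg` update ops, or `Switch`/`Merge` nodes produced by a `tf.cond`. The sketch below is a hypothetical heuristic (the helper name and the marker lists are assumptions, not a TensorFlow API). It takes plain `(name, op)` pairs so it can be fed from any GraphDef, for example `[(n.name, n.op) for n in graph_def.node]`.

```python
# Hypothetical markers of a Keras graph built in training mode.
SUSPICIOUS_NAME_PARTS = ("keras_learning_phase", "AssignMovingAvg")
SUSPICIOUS_OPS = {"Switch", "Merge"}  # typically emitted by tf.cond


def find_training_nodes(nodes):
    """Return the names of nodes that look training-only.

    `nodes` is an iterable of (name, op_type) pairs.
    """
    flagged = []
    for name, op in nodes:
        if op in SUSPICIOUS_OPS or any(part in name for part in SUSPICIOUS_NAME_PARTS):
            flagged.append(name)
    return flagged


# Example with made-up node names shaped like a Keras graph.
example = [
    ("input_1", "Placeholder"),
    ("keras_learning_phase", "PlaceholderWithDefault"),
    ("bn/cond/Switch", "Switch"),
    ("bn/AssignMovingAvg", "AssignSubVariableOp"),
    ("predictions/Softmax", "Softmax"),
]
print(find_training_nodes(example))
```

    An empty result is consistent with an inference-mode export but does not prove it, and `Switch`/`Merge` nodes can also come from control flow unrelated to the learning phase.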
    

    To load the *.pb model:

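    # Written against the TF 1.x API (tf.GraphDef, tf.gfile, tf.Session)
    from PIL import Image
    import numpy as np
    import tensorflow as tf
    
    # https://i.imgur.com/tvOB18o.jpg
    im = Image.open("/home/chichivica/Pictures/eagle.jpg").convert("RGB").resize((299, 299), Image.BICUBIC)
    # InceptionV3 expects pixels scaled to [-1, 1] (as inception_v3.preprocess_input does)
    im = np.asarray(im, dtype=np.float32) / 127.5 - 1.0
    im = im[None, ...]  # add a batch dimension: (1, 299, 299, 3)
    
    graph_def = tf.GraphDef()
    
    with tf.gfile.GFile("frozen_model.pb", "rb") as f:
        graph_def.ParseFromString(f.read())
    
    graph = tf.Graph()
    
    with graph.as_default():
        # These names must match INPUT_NODE / OUTPUT_NODE from the export script
        net_inp, net_out = tf.import_graph_def(
            graph_def, return_elements=["input_1", "predictions/Softmax"]
        )
        with tf.Session(graph=graph) as sess:
            out = sess.run(net_out.outputs[0], feed_dict={net_inp.outputs[0]: im})
            print(np.argmax(out))  # index of the top-scoring ImageNet class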
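    On preprocessing: the Keras InceptionV3 weights expect pixels scaled to [-1, 1] rather than [0, 1]; `tf.keras.applications.inception_v3.preprocess_input` computes `x / 127.5 - 1`. In practice you would call that function (or do the division on the NumPy array directly); the hypothetical helper below only spells out the arithmetic in plain Python.

```python
def inception_scale(pixels):
    """Map 8-bit pixel values from [0, 255] to [-1, 1] (x / 127.5 - 1)."""
    return [p / 127.5 - 1.0 for p in pixels]


# Black, mid-grey and white map to the ends and centre of the range.
print(inception_scale([0, 127.5, 255]))
```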
    
