Can't import frozen graph with BatchNorm layer

Asked by -上瘾入骨i on 2021-02-20 13:41

I have trained a Keras model based on this repo.

After training, I save the model as checkpoint files like this:

    sess = tf.keras.backend.get_session()
    saver = tf.train.Saver()                      # typical TF 1.x saver (the exact saving code is assumed)
    saver.save(sess, "./checkpoints/model.ckpt")  # illustrative checkpoint path

3 Answers
  • 2021-02-20 13:44

    Just resolved the same issue. I combined a few answers (1, 2, 3) and realized that the issue originates from the batch norm layer's working state: training vs. inference. So, to resolve the issue, you just need to place one line before loading your model:

    keras.backend.set_learning_phase(0)
    

    Complete example to export a model:

    import tensorflow as tf
    from tensorflow.python.framework import graph_io
    from tensorflow.keras.applications.inception_v3 import InceptionV3
    
    
    def freeze_graph(graph, session, output):
        with graph.as_default():
            # Strip training-only ops, then bake the variables into constants
            graphdef_inf = tf.graph_util.remove_training_nodes(graph.as_graph_def())
            graphdef_frozen = tf.graph_util.convert_variables_to_constants(session, graphdef_inf, output)
            graph_io.write_graph(graphdef_frozen, ".", "frozen_model.pb", as_text=False)
    
    tf.keras.backend.set_learning_phase(0)  # the most important line: build the model in inference mode
    
    base_model = InceptionV3()
    
    session = tf.keras.backend.get_session()
    
    # Node names, useful for locating the input/output ops at import time
    INPUT_NODE = base_model.inputs[0].op.name
    OUTPUT_NODE = base_model.outputs[0].op.name
    freeze_graph(session.graph, session, [out.op.name for out in base_model.outputs])
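
    If you are unsure of the node names, they can be listed from the frozen graph before importing it. A minimal sketch, assuming the frozen_model.pb written above:

    import tensorflow as tf
    
    graph_def = tf.GraphDef()
    with tf.gfile.GFile("frozen_model.pb", "rb") as f:
        graph_def.ParseFromString(f.read())
    
    # Print every node name to locate the input and output ops
    for node in graph_def.node:
        print(node.name)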
    

    To load the *.pb model:

    from PIL import Image
    import numpy as np
    import tensorflow as tf
    
    # https://i.imgur.com/tvOB18o.jpg
    im = Image.open("/home/chichivica/Pictures/eagle.jpg").resize((299, 299), Image.BICUBIC)
    im = np.array(im) / 255.0  # scale pixel values to [0, 1]
    im = im[None, ...]         # add a batch dimension
    
    graph_def = tf.GraphDef()
    
    with tf.gfile.GFile("frozen_model.pb", "rb") as f:
        graph_def.ParseFromString(f.read())
    
    graph = tf.Graph()
    
    with graph.as_default():
        # return_elements hands back the input and output ops by name
        net_inp, net_out = tf.import_graph_def(
            graph_def, return_elements=["input_1", "predictions/Softmax"]
        )
        with tf.Session(graph=graph) as sess:
            out = sess.run(net_out.outputs[0], feed_dict={net_inp.outputs[0]: im})
            print(np.argmax(out))
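
    The tensors can equivalently be fetched by name. Note that tf.import_graph_def prefixes imported node names with "import/" by default; passing name="" avoids the prefix. A sketch under that assumption:

    graph = tf.Graph()
    with graph.as_default():
        tf.import_graph_def(graph_def, name="")  # no "import/" prefix
        with tf.Session(graph=graph) as sess:
            inp = graph.get_tensor_by_name("input_1:0")
            prob = graph.get_tensor_by_name("predictions/Softmax:0")
            print(np.argmax(sess.run(prob, feed_dict={inp: im})))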
    
  • 2021-02-20 13:48

    Thanks for pointing out the main issue! I found that keras.backend.set_learning_phase(0) does not always work, at least in my case.

    Another approach might be to mark every layer non-trainable: for l in keras_model.layers: l.trainable = False (a sketch follows).
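
    A minimal sketch of that approach before freezing (MobileNet here is an illustrative stand-in for your trained model):

    import tensorflow as tf
    from tensorflow.keras.applications.mobilenet import MobileNet
    
    keras_model = MobileNet()  # illustrative stand-in for your trained model
    
    # Mark every layer non-trainable instead of calling set_learning_phase(0)
    for l in keras_model.layers:
        l.trainable = False
    
    session = tf.keras.backend.get_session()
    frozen = tf.graph_util.convert_variables_to_constants(
        session, session.graph.as_graph_def(),
        [out.op.name for out in keras_model.outputs])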

  • 2021-02-20 13:49

    This is a bug in TensorFlow 1.1x and, as another answer stated, it is caused by the batch norm layer's internal training vs. inference state. In TF 1.14.0 you actually get a cryptic error when trying to freeze a batch norm layer.

    Using set_learning_phase(0) puts the batch norm layer (and probably others, like dropout) into inference mode; if it is set before training, the batch norm layer will not work during training, leading to reduced accuracy.

    My solution is this:

    1. Create the model using a function (do not use K.set_learning_phase(0)):
    def create_model():
        inputs = Input(...)
        ...
        return model
    
    model = create_model()
    
    2. Train the model.
    3. Save the weights: model.save_weights("weights.h5")
    4. Clear the session (important so layer names stay the same) and set the learning phase to 0:
    K.clear_session()
    K.set_learning_phase(0)
    
    5. Recreate the model and load the weights:
    model = create_model()
    model.load_weights("weights.h5")
    
    6. Freeze as before (a combined sketch follows).
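
    Putting the steps together, a sketch (create_model here is a placeholder architecture; substitute your own):

    import tensorflow as tf
    from tensorflow.keras import backend as K
    
    def create_model():
        # Placeholder architecture; use the same function for training and export
        inputs = tf.keras.Input(shape=(224, 224, 3))
        x = tf.keras.layers.Conv2D(8, 3)(inputs)
        x = tf.keras.layers.BatchNormalization()(x)
        x = tf.keras.layers.GlobalAveragePooling2D()(x)
        outputs = tf.keras.layers.Dense(10, activation="softmax")(x)
        return tf.keras.Model(inputs, outputs)
    
    # Steps 1-3: build, train (omitted), save weights
    model = create_model()
    # model.fit(...)
    model.save_weights("weights.h5")
    
    # Step 4: reset the graph so layer names match, then force inference mode
    K.clear_session()
    K.set_learning_phase(0)
    
    # Step 5: rebuild the identical model and restore the trained weights
    model = create_model()
    model.load_weights("weights.h5")
    
    # Step 6: freeze as before (see the first answer's freeze_graph)
    session = K.get_session()
    frozen = tf.graph_util.convert_variables_to_constants(
        session, session.graph.as_graph_def(),
        [out.op.name for out in model.outputs])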