Can't import frozen graph with BatchNorm layer

-上瘾入骨i 2021-02-20 13:41

I have trained a Keras model based on this repo.

After training, I save the model as checkpoint files like this:

 sess=tf.keras.backend.get_session(         


        
3 answers
  •  梦谈多话
    2021-02-20 13:49

    This is a bug in TensorFlow 1.1x and, as another answer stated, it comes from the batch norm layer's internal training vs. inference state. In TF 1.14.0 you actually get a cryptic error when trying to freeze a batch norm layer.

    Calling K.set_learning_phase(0) before training puts the batch norm layer (and probably others, such as dropout) into inference mode, so the batch norm layer will not train properly, leading to reduced accuracy.
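
To make the training vs. inference distinction concrete, here is a toy sketch in plain Python. It is not the Keras implementation: the class name `ToyBatchNorm`, the `momentum` and `eps` values, and the sample batch are all assumptions chosen for illustration. It only shows the general idea that a batch norm layer normalizes with batch statistics in training mode and with stored moving averages in inference mode.

```python
import math


class ToyBatchNorm:
    """Toy 1-D batch norm for illustration only (hypothetical, not Keras code)."""

    def __init__(self, momentum=0.9, eps=1e-3):
        self.momentum = momentum      # assumed value, for illustration
        self.eps = eps                # assumed value, for illustration
        self.moving_mean = 0.0        # statistics used in inference mode
        self.moving_var = 1.0

    def __call__(self, batch, training):
        if training:
            # Training mode: normalize with the statistics of this batch...
            mean = sum(batch) / len(batch)
            var = sum((x - mean) ** 2 for x in batch) / len(batch)
            # ...and update the moving averages kept for inference.
            m = self.momentum
            self.moving_mean = m * self.moving_mean + (1 - m) * mean
            self.moving_var = m * self.moving_var + (1 - m) * var
        else:
            # Inference mode: use the stored moving averages, update nothing.
            mean, var = self.moving_mean, self.moving_var
        return [(x - mean) / math.sqrt(var + self.eps) for x in batch]


bn = ToyBatchNorm()
batch = [10.0, 12.0, 14.0]

train_out = bn(batch, training=True)    # analogous to learning phase 1
infer_out = bn(batch, training=False)   # analogous to learning phase 0

print("training mode: ", train_out)
print("inference mode:", infer_out)
```

With this toy layer, the same batch comes out centered on zero in training mode but far from zero in inference mode after a single update, because the moving averages have barely moved from their initial values. That is the kind of mismatch that fixing the learning phase at the wrong time can cause.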

    My solution is this:

    1. Create the model using a function (do not use K.set_learning_phase(0)):

        def create_model():
            inputs = Input(...)
            ...
            return model

        model = create_model()

    2. Train model
    3. Save weights: model.save_weights("weights.h5")
    4. Clear session (important so layer names are the same) and set learning phase to 0:

        K.clear_session()
        K.set_learning_phase(0)

    5. Recreate model and load weights:

        model = create_model()
        model.load_weights("weights.h5")

    6. Freeze as before
