BatchNormalization in Keras

Posted by Deadly on 2019-12-05 16:06:12

You do not need to manually update the moving mean and variance if you are using the BatchNormalization layer. Keras takes care of updating these parameters during training (model.fit_generator and friends) and of keeping them fixed during testing (model.predict and model.evaluate).

Keras also keeps track of the learning phase so different codepaths run during training and validation/testing.
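To make the two code paths concrete, here is a toy single-feature sketch in plain Python. It is not the Keras implementation; the class name TinyBatchNorm is made up for illustration, and the update rule (moving = momentum * moving + (1 - momentum) * batch statistic) is the usual exponential moving average that I am assuming here.

```python
class TinyBatchNorm:
    """Toy single-feature batch normalization; NOT the Keras implementation."""

    def __init__(self, momentum=0.99, epsilon=1e-3):
        self.momentum = momentum
        self.epsilon = epsilon
        self.gamma = 1.0            # learned scale (kept fixed in this sketch)
        self.beta = 0.0             # learned offset (kept fixed in this sketch)
        self.moving_mean = 0.0
        self.moving_variance = 1.0

    def __call__(self, batch, training):
        if training:
            # Training path: normalize with the statistics of this batch ...
            mean = sum(batch) / len(batch)
            var = sum((x - mean) ** 2 for x in batch) / len(batch)
            # ... and fold them into the moving averages.
            m = self.momentum
            self.moving_mean = m * self.moving_mean + (1 - m) * mean
            self.moving_variance = m * self.moving_variance + (1 - m) * var
        else:
            # Inference path: use the stored averages and leave them untouched.
            mean, var = self.moving_mean, self.moving_variance
        scale = self.gamma / (var + self.epsilon) ** 0.5
        return [scale * (x - mean) + self.beta for x in batch]


bn = TinyBatchNorm(momentum=0.9)
bn([1.0, 2.0, 3.0], training=True)        # updates the moving statistics
after_training = (bn.moving_mean, bn.moving_variance)
bn([10.0, 20.0, 30.0], training=False)    # only reads the moving statistics
after_inference = (bn.moving_mean, bn.moving_variance)
```

After the second call the stored statistics are unchanged, which is the behavior you get from model.predict and model.evaluate without doing anything yourself.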

If you just need to overwrite the weights of an existing model with new values, you can do the following:

w = model.get_layer('batchnorm_layer_name').get_weights()
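# Order (with center=True and scale=True): [gamma, beta, moving_mean, moving_variance]
for j in range(len(w[0])):
    gamma = w[0][j]
    beta = w[1][j]
    run_mean = w[2][j]
    run_var = w[3][j]
    # new_run_mean_value and new_run_var_value are placeholders for your own values
    w[2][j] = new_run_mean_value
    w[3][j] = new_run_var_value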

model.get_layer('batchnorm_layer_name').set_weights(w)
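One thing to watch: the fourth array of a Keras BatchNormalization layer holds the moving variance, not a standard deviation. If your new statistics are standard deviations, they need to be squared first. A minimal sketch of that step, assuming the weight list has all four entries; replace_moving_stats is a hypothetical helper, not part of Keras:

```python
def replace_moving_stats(weights, new_mean, new_std):
    """Return a copy of a BatchNormalization weight list with new moving statistics.

    `weights` is assumed to be ordered [gamma, beta, moving_mean, moving_variance],
    which is the case when the layer uses both center=True and scale=True.
    """
    gamma, beta, old_mean, old_var = weights
    if len(new_mean) != len(old_mean) or len(new_std) != len(old_var):
        raise ValueError("new statistics must have one entry per channel")
    new_var = [s * s for s in new_std]  # standard deviation -> variance
    return [gamma, beta, list(new_mean), new_var]


# Plain lists stand in for the arrays returned by get_weights().
old = [[1.0, 1.0], [0.0, 0.0], [0.5, 0.5], [1.0, 1.0]]
new = replace_moving_stats(old, new_mean=[2.0, 3.0], new_std=[2.0, 0.5])
```

The last two entries come back as plain Python lists, so they would probably have to be converted to arrays before the result is handed to set_weights.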

There are two interpretations of the question. The first assumes that the goal is to use the high-level training API; that question was answered by Matias Valdenegro.

The second, as discussed in the comments, is whether it is possible to use batch normalization with a standard TensorFlow optimizer, as discussed in the tutorial "Keras as a simplified interface to TensorFlow" in the section "Collecting trainable weights and state updates". As mentioned there, the update ops are accessible in layer.updates and not in tf.GraphKeys.UPDATE_OPS. So if you have a Keras model in TensorFlow, you can optimize it with a standard TensorFlow optimizer and batch normalization like this:

update_ops = model.updates
with tf.control_dependencies(update_ops):
    train_op = optimizer.minimize(loss)

and then use a TensorFlow session to run the train_op. To distinguish the training and evaluation modes of the batch normalization layer, you need to feed the learning phase state of the Keras engine (see "Different behaviors during training and testing" on the same tutorial page as given above). This works, for example, like this:

...
# train: learning phase 1, so batch statistics are used and the moving averages are updated
lo, _ = tf_sess.run(fetches=[loss, train_op],
                    feed_dict={tf_batch_data: bd,
                               tf_batch_labels: bl,
                               tf.keras.backend.learning_phase(): 1})

...

# eval: learning phase 0, so the stored moving averages are used
lo = tf_sess.run(fetches=loss,
                 feed_dict={tf_batch_data: bd,
                            tf_batch_labels: bl,
                            tf.keras.backend.learning_phase(): 0})

I tried this in TensorFlow 1.12 and it works with models containing batch normalization. Given my existing TensorFlow code and the approach of TensorFlow 2.0, I was tempted to use this approach myself. However, since it is not mentioned in the TensorFlow documentation, I am not sure it will be supported in the long term, so I finally decided not to use it and to invest a little more effort in changing the code to use the high-level API.
