How do I update moving mean and moving variance in keras BatchNormalization?
I found this in tensorflow documentation, but I don't know where to put train_op
or how to work it with keras models:
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = optimizer.minimize(loss)
None of the posts I found say what to do with train_op, or whether it can be used with model.compile.
You do not need to manually update the moving mean and variance if you are using the BatchNormalization layer. Keras takes care of updating these parameters during training, and of keeping them fixed during testing (when using model.predict and model.evaluate, same as with model.fit_generator and friends).
Keras also keeps track of the learning phase, so different code paths run during training and during validation/testing.
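For illustration, here is a minimal sketch of that high-level workflow (the layer name, shapes, and random data below are made up):
import numpy as np
from tensorflow import keras

# Toy model with a BatchNormalization layer between two Dense layers.
model = keras.Sequential([
    keras.layers.Dense(16, activation='relu', input_shape=(8,)),
    keras.layers.BatchNormalization(name='batchnorm_layer_name'),
    keras.layers.Dense(1, activation='sigmoid'),
])
model.compile(optimizer='adam', loss='binary_crossentropy')

x = np.random.rand(64, 8).astype('float32')
y = np.random.randint(0, 2, size=(64, 1))

# fit() runs in training mode: batch statistics are used and the
# moving mean/variance are updated automatically.
model.fit(x, y, epochs=1, verbose=0)

# evaluate()/predict() run in inference mode: the stored moving
# mean/variance are used and left untouched.
model.evaluate(x, y, verbose=0)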
If you just need to update the weights of an existing model with some new values, you can do the following:
w = model.get_layer('batchnorm_layer_name').get_weights()
# Weight order: [gamma, beta, moving_mean, moving_variance]
for j in range(len(w[0])):
    gamma = w[0][j]     # scale for channel j
    beta = w[1][j]      # offset for channel j
    run_mean = w[2][j]  # current moving mean for channel j
    run_var = w[3][j]   # current moving variance for channel j
    w[2][j] = new_run_mean_value1  # your new moving mean (placeholder)
    w[3][j] = new_run_var_value2   # your new moving variance (placeholder)
model.get_layer('batchnorm_layer_name').set_weights(w)
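To sanity-check such an edit, you can read the weights back afterwards; a small sketch (again, 'batchnorm_layer_name' is a placeholder for your layer's actual name):
bn = model.get_layer('batchnorm_layer_name')
gamma, beta, moving_mean, moving_variance = bn.get_weights()
print('moving mean:    ', moving_mean)
print('moving variance:', moving_variance)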
There are two interpretations of the question. The first assumes that the goal is to use the high-level training API; that case is answered by Matias Valdenegro above.
The second - as discussed in the comments - is whether it is possible to use batch normalization with a standard TensorFlow optimizer, as discussed in the tutorial "Keras as a simplified TensorFlow interface", section "Collecting trainable weights and state updates". As mentioned there, the update ops are accessible in layer.updates, not in tf.GraphKeys.UPDATE_OPS. In fact, if you have a Keras model in TensorFlow, you can optimize with a standard TensorFlow optimizer and batch normalization like this:
update_ops = model.updates
with tf.control_dependencies(update_ops):
    train_op = optimizer.minimize(loss)
and then run the train_op in a TensorFlow session. To switch the batch normalization layer between its training and evaluation modes, you need to feed the learning phase state of the Keras engine (see "Different behaviors during training and testing" on the same tutorial page as above). That works, for example, like this:
...
# train (learning_phase 1 selects training-mode behavior)
lo, _ = tf_sess.run(fetches=[loss, train_op],
                    feed_dict={tf_batch_data: bd,
                               tf_batch_labels: bl,
                               tensorflow.keras.backend.learning_phase(): 1})
...
# eval (learning_phase 0 selects inference-mode behavior)
lo = tf_sess.run(fetches=[loss],
                 feed_dict={tf_batch_data: bd,
                            tf_batch_labels: bl,
                            tensorflow.keras.backend.learning_phase(): 0})
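For completeness, here is a hedged sketch of the graph setup that the snippet above assumes; the names tf_sess, tf_batch_data, tf_batch_labels, model, loss, optimizer, and train_op mirror the code, while the architecture, shapes, and optimizer choice are made up:
import numpy as np
import tensorflow
from tensorflow import keras

# Placeholders for a batch of inputs and labels (shapes are illustrative).
tf_batch_data = tensorflow.placeholder(tensorflow.float32, shape=(None, 8))
tf_batch_labels = tensorflow.placeholder(tensorflow.float32, shape=(None, 1))

# A Keras model applied directly to the TensorFlow placeholder tensor.
model = keras.Sequential([
    keras.layers.Dense(16, activation='relu'),
    keras.layers.BatchNormalization(),
    keras.layers.Dense(1, activation='sigmoid'),
])
predictions = model(tf_batch_data)

loss = tensorflow.reduce_mean(
    keras.losses.binary_crossentropy(tf_batch_labels, predictions))
optimizer = tensorflow.train.AdamOptimizer()

# The moving mean/variance updates live in model.updates, not in
# tf.GraphKeys.UPDATE_OPS, so they are attached to train_op explicitly.
with tensorflow.control_dependencies(model.updates):
    train_op = optimizer.minimize(loss)

tf_sess = keras.backend.get_session()
tf_sess.run(tensorflow.global_variables_initializer())

# One illustrative random batch.
bd = np.random.rand(32, 8).astype('float32')
bl = np.random.randint(0, 2, size=(32, 1)).astype('float32')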
I tried this in TensorFlow 1.12 and it works with models containing batch normalization. Given my existing TensorFlow code, and in light of the approaching TensorFlow 2.0, I was tempted to use this approach myself. However, since it is not mentioned in the TensorFlow documentation, I am not sure it will be supported in the long term, so I finally decided against it and invested a little more effort in changing the code to use the high-level API.
Source: https://stackoverflow.com/questions/50164572/batchnormalization-in-keras