I noticed there are batch normalization functions already in the API for TensorFlow. One thing I don't understand, though, is how to change the procedure between training and inference.
You are right: tf.nn.batch_normalization provides only the basic functionality for implementing batch normalization. You have to add the extra logic yourself to keep track of the moving means and variances during training, and to use those stored means and variances during inference. You can look at this example for a very general implementation, but here is a quick version that doesn't use gamma (the scale parameter):
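import tensorflow as tf
from tensorflow.python.training import moving_averages

# This fragment is meant to sit inside a method: `image` is the input tensor,
# `shape` the shape of the per-channel statistics, `is_training` a Python bool,
# and `self.decay` the moving-average decay rate.
beta = tf.Variable(tf.zeros(shape), name='beta')
moving_mean = tf.Variable(tf.zeros(shape), name='moving_mean',
                          trainable=False)
moving_variance = tf.Variable(tf.ones(shape),
                              name='moving_variance',
                              trainable=False)
control_inputs = []
if is_training:
    # Normalize with the statistics of the current batch and
    # fold them into the moving averages.
    mean, variance = tf.nn.moments(image, [0, 1, 2])
    update_moving_mean = moving_averages.assign_moving_average(
        moving_mean, mean, self.decay)
    update_moving_variance = moving_averages.assign_moving_average(
        moving_variance, variance, self.decay)
    control_inputs = [update_moving_mean, update_moving_variance]
else:
    # At inference time, use the statistics accumulated during training.
    mean = moving_mean
    variance = moving_variance
# The control dependency makes the moving-average updates run
# whenever the normalized output is evaluated.
with tf.control_dependencies(control_inputs):
    return tf.nn.batch_normalization(
        image, mean=mean, variance=variance, offset=beta,
        scale=None, variance_epsilon=0.001)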
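
To make the switch between the two modes concrete, here is a rough framework-free sketch of the same logic for a single scalar feature. The class name `ScalarBatchNorm` and the default `decay=0.9` are made up for illustration, and the update is written as a plain exponential moving average, which is my assumption about what the moving-average op amounts to (some TensorFlow versions may additionally apply zero-debiasing).

```python
import math


class ScalarBatchNorm:
    """Hypothetical helper: batch norm for one scalar feature, no gamma."""

    def __init__(self, decay=0.9, epsilon=0.001):
        self.decay = decay
        self.epsilon = epsilon
        self.beta = 0.0
        # Same initial values as the moving_mean / moving_variance variables.
        self.moving_mean = 0.0
        self.moving_variance = 1.0

    def __call__(self, batch, is_training):
        if is_training:
            # Batch statistics (population variance, i.e. divide by N).
            mean = sum(batch) / len(batch)
            variance = sum((x - mean) ** 2 for x in batch) / len(batch)
            # Assumed update rule: moving = decay * moving + (1 - decay) * value
            self.moving_mean = (self.decay * self.moving_mean
                                + (1 - self.decay) * mean)
            self.moving_variance = (self.decay * self.moving_variance
                                    + (1 - self.decay) * variance)
        else:
            # Inference: reuse the statistics accumulated during training.
            mean, variance = self.moving_mean, self.moving_variance
        inv_std = 1.0 / math.sqrt(variance + self.epsilon)
        return [(x - mean) * inv_std + self.beta for x in batch]


bn = ScalarBatchNorm(decay=0.5)
train_out = bn([1.0, 3.0], is_training=True)   # batch stats, updates averages
infer_out = bn([1.0], is_training=False)       # stored averages, no update
print(train_out, infer_out, bn.moving_mean, bn.moving_variance)
```

The point of the sketch is only that the training call both normalizes with the batch statistics and updates the stored averages, while the inference call reads the stored averages and leaves them untouched.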