Batch Normalization in tensorflow

情歌与酒 2021-02-06 13:34

I noticed there are batch normalization functions already in the TensorFlow API. One thing I don't understand, though, is how to change the procedure between training and inference.

1 Answer
  •  囚心锁ツ
    2021-02-06 14:09

    You are right: tf.nn.batch_normalization provides only the basic functionality for implementing batch normalization. You have to add the extra logic yourself: keep track of the moving means and variances during training, and use those accumulated means and variances during inference. You can look at this example for a very general implementation, but a quick version that doesn't use gamma (the learned scale) is here:

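      # Written for TF1-style graph mode (tf.compat.v1 in TF2).
      import tensorflow as tf
      from tensorflow.python.training import moving_averages

      def batch_norm(image, shape, is_training, decay=0.999):
        """Batch norm without gamma. `image` is NHWC; `shape` is [channels]."""
        # Learned offset (beta); there is no gamma, so scale=None below.
        beta = tf.Variable(tf.zeros(shape), name='beta')
        # Population statistics: not trained, only updated as moving averages.
        moving_mean = tf.Variable(tf.zeros(shape), name='moving_mean',
                                  trainable=False)
        moving_variance = tf.Variable(tf.ones(shape), name='moving_variance',
                                      trainable=False)
        control_inputs = []
        if is_training:
          # Per-channel statistics of the current batch (over N, H, W).
          mean, variance = tf.nn.moments(image, [0, 1, 2])
          update_moving_mean = moving_averages.assign_moving_average(
              moving_mean, mean, decay)
          update_moving_variance = moving_averages.assign_moving_average(
              moving_variance, variance, decay)
          control_inputs = [update_moving_mean, update_moving_variance]
        else:
          # Inference: use the statistics accumulated during training.
          mean = moving_mean
          variance = moving_variance
        # Make the moving-average updates run whenever the output is computed.
        with tf.control_dependencies(control_inputs):
          return tf.nn.batch_normalization(
              image, mean=mean, variance=variance, offset=beta,
              scale=None, variance_epsilon=0.001)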
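    To see the training/inference switch without TensorFlow, here is a minimal NumPy sketch of the same logic. It is an illustration under stated assumptions, not TensorFlow's implementation: the class name SimpleBatchNorm and its parameters are hypothetical, the input is assumed to be NHWC, and the moving averages use the plain exponential-moving-average rule moving -= (1 - decay) * (moving - batch_value) that assign_moving_average is built on, ignoring its optional zero-debiasing.

```python
import numpy as np


class SimpleBatchNorm:
    """Hypothetical, framework-free sketch of batch norm without gamma.

    Assumptions: NHWC input, per-channel statistics, and the plain EMA update
        moving -= (1 - decay) * (moving - batch_value)
    (no zero-debiasing).
    """

    def __init__(self, channels, decay=0.999, eps=0.001):
        self.beta = np.zeros(channels)            # learned offset
        self.moving_mean = np.zeros(channels)     # population mean estimate
        self.moving_variance = np.ones(channels)  # population variance estimate
        self.decay = decay
        self.eps = eps

    def __call__(self, x, is_training):
        if is_training:
            # Batch statistics over batch, height and width (axes 0, 1, 2);
            # np.var uses ddof=0, like tf.nn.moments.
            mean = x.mean(axis=(0, 1, 2))
            variance = x.var(axis=(0, 1, 2))
            # Update the exponential moving averages in place.
            self.moving_mean -= (1 - self.decay) * (self.moving_mean - mean)
            self.moving_variance -= (1 - self.decay) * (
                self.moving_variance - variance)
        else:
            # Inference: only read the accumulated population statistics.
            mean = self.moving_mean
            variance = self.moving_variance
        return (x - mean) / np.sqrt(variance + self.eps) + self.beta


# Usage: train on batches drawn from N(5, 2^2), then normalize at inference.
rng = np.random.default_rng(0)
bn = SimpleBatchNorm(channels=3, decay=0.9)
for _ in range(200):
    bn(rng.normal(5.0, 2.0, size=(8, 4, 4, 3)), is_training=True)
y = bn(rng.normal(5.0, 2.0, size=(8, 4, 4, 3)), is_training=False)
```

    Training calls update the moving statistics, while inference calls only read them, so the moving mean and variance drift toward roughly 5 and 4 here and stay fixed once training stops.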
    
