TensorFlow `tf.layers.batch_normalization` doesn't add update ops to `tf.GraphKeys.UPDATE_OPS`

醉梦人生 2021-02-15 01:51

The following code (copy/paste runnable) illustrates using tf.layers.batch_normalization.

import tensorflow          


        
1 Answer
  • 2021-02-15 02:22

    Just change your code to run in training mode (set the training flag to True), as the quoted note says:

    Note: when training, the moving_mean and moving_variance need to be updated. By default the update ops are placed in tf.GraphKeys.UPDATE_OPS, so they need to be added as a dependency to the train_op.

     import tensorflow as tf
     bn = tf.layers.batch_normalization(tf.constant([0.0]), training=True)
     print(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
    

    will output:

    [<tf.Tensor 'batch_normalization/AssignMovingAvg:0' shape=(1,) dtype=float32_ref>,
     <tf.Tensor 'batch_normalization/AssignMovingAvg_1:0' shape=(1,) dtype=float32_ref>]
    

    and gamma and beta end up in the TRAINABLE_VARIABLES collection:

    print(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES))
    
    [<tf.Variable 'batch_normalization/gamma:0' shape=(1,) dtype=float32_ref>, 
     <tf.Variable 'batch_normalization/beta:0' shape=(1,) dtype=float32_ref>]
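
The quoted note says that moving_mean and moving_variance are updated during training, which is what the two AssignMovingAvg ops above are for. As a rough illustration of that update, here is a plain-Python sketch of an exponential moving average. It assumes the layer defaults (momentum 0.99, moving_mean starting at 0, moving_variance starting at 1) and uses the population variance of the batch. The helper name update_moving_stats is hypothetical and not part of TensorFlow, and the real ops may differ in detail.

```python
# Plain-Python sketch of the moving-average update done in training mode.
# `update_moving_stats` is a hypothetical helper, not a TensorFlow function.

def update_moving_stats(moving_mean, moving_var, batch, momentum=0.99):
    """Return the moving statistics after seeing one training batch."""
    n = len(batch)
    batch_mean = sum(batch) / n
    # Population variance of the batch (an assumption of this sketch).
    batch_var = sum((v - batch_mean) ** 2 for v in batch) / n
    # Exponential moving average: keep `momentum` of the old value and
    # blend in (1 - momentum) of the statistic from the current batch.
    new_mean = momentum * moving_mean + (1.0 - momentum) * batch_mean
    new_var = momentum * moving_var + (1.0 - momentum) * batch_var
    return new_mean, new_var


# Assumed initial values: moving_mean = 0, moving_variance = 1.
moving_mean, moving_var = 0.0, 1.0
for batch in ([1.0, 3.0], [2.0, 6.0]):
    moving_mean, moving_var = update_moving_stats(moving_mean, moving_var, batch)
print(moving_mean, moving_var)
```

With training left at its default of False, the layer only reads the moving statistics, so there is nothing to update and the UPDATE_OPS collection stays empty. In a TensorFlow 1.x graph, the usual way to make the updates run on every step is to build the train op inside `with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):`.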
    