Keras BatchNormalization only works for constant batch dim when axis=0?

Submitted by 断了今生、忘了曾经 on 2019-12-13 04:26:37

Question


The following code shows one approach that works and another that fails.

BatchNorm on axis=0 should not depend on the batch size, or if it does, the docs should state so explicitly.

In [118]: tf.__version__
Out[118]: '2.0.0-beta1'



import numpy as np
import tensorflow as tf

class M(tf.keras.Model):

    def __init__(self, axis):
        super().__init__()
        self.layer = tf.keras.layers.BatchNormalization(axis=axis, scale=False, center=True, input_shape=(6,))

    def call(self, x):
        out = self.layer(x)
        return out

def fails():
    m = M(axis=0)
    x = np.random.randn(3, 6).astype(np.float32)
    print(m(x))  # first call builds the layer's weights
    x = np.random.randn(2, 6).astype(np.float32)
    print(m(x))  # raises an error: batch size changed from 3 to 2

def ok():
    m = M(axis=1)
    x = np.random.randn(3, 6).astype(np.float32)
    print(m(x))
    x = np.random.randn(2, 6).astype(np.float32)
    print(m(x))  # works: the feature count (6) is unchanged

EDIT:

The axis argument is not the axis you might expect: it specifies the feature axis that is kept, not the axis that is reduced over.


Answer 1:


As stated in this answer and the Keras docs, the axis argument indicates the feature axis. This makes sense because we want to do feature-wise normalization, i.e. normalize each feature over the whole input batch. (This is analogous to the feature-wise normalization we may do on images, e.g. subtracting the "mean pixel" from all the images of a dataset.)
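Feature-wise normalization can be sketched in plain NumPy (a minimal illustration, not the actual Keras implementation; 1e-3 mirrors the layer's default epsilon): statistics are computed per feature, over every other axis.

```python
import numpy as np

x = np.random.randn(4, 6).astype(np.float32)  # (batch, features)

# Feature-wise normalization: one mean/variance per feature,
# computed over the batch axis.
mean = x.mean(axis=0)  # shape (6,)
var = x.var(axis=0)    # shape (6,)
x_norm = (x - mean) / np.sqrt(var + 1e-3)

print(x_norm.mean(axis=0))  # approximately 0 for every feature
```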

Now, the fails() method you have written fails on this line:

x = np.random.randn(2, 6).astype(np.float32)
print(m(x))

That's because you set the feature axis to 0 (the first axis) when building the model, so when the following lines run before the code above:

x = np.random.randn(3, 6).astype(np.float32)
print(m(x))

the layer's weights are built for 3 features (remember, you indicated the feature axis as 0, so an input of shape (3, 6) has 3 features). When you then give it an input tensor of shape (2, 6), it correctly raises an error: that tensor has 2 features, so the normalization cannot be done due to the mismatch.
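The same behavior can be reproduced with the bare layer, outside the model class (a sketch assuming TF 2.x, where the layer builds its weights on the first call):

```python
import numpy as np
import tensorflow as tf

# With axis=0, the "feature" axis of a (3, 6) input is the batch axis.
bn = tf.keras.layers.BatchNormalization(axis=0, scale=False, center=True)
bn(np.random.randn(3, 6).astype(np.float32))  # first call builds the weights
print(bn.beta.shape)  # one beta per "feature", i.e. shape (3,)

try:
    bn(np.random.randn(2, 6).astype(np.float32))  # now only 2 "features"
except Exception as err:
    print("shape mismatch:", err)
```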

On the other hand, the ok() method works because the feature axis is the last axis, so both input tensors have the same number of features (6). Normalization can therefore be done in both cases for all the features.
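With the default feature axis, the bare layer accepts varying batch sizes without issue (again a sketch assuming TF 2.x):

```python
import numpy as np
import tensorflow as tf

bn = tf.keras.layers.BatchNormalization(axis=-1, scale=False, center=True)
bn(np.random.randn(3, 6).astype(np.float32))  # builds weights from 6 features
print(bn.beta.shape)  # (6,): one parameter per feature

out = bn(np.random.randn(2, 6).astype(np.float32))  # different batch size: fine
print(out.shape)  # (2, 6)
```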



Source: https://stackoverflow.com/questions/57202668/keras-batchnormalization-only-works-for-constant-batch-dim-when-axis-0
