Question
The following code shows one way that works and another that fails.
BatchNormalization with axis=0 should not depend on the batch size, or if it does, the docs should state so explicitly.
In [118]: tf.__version__
Out[118]: '2.0.0-beta1'
import numpy as np
import tensorflow as tf

class M(tf.keras.Model):
    def __init__(self, axis):
        super().__init__()
        self.layer = tf.keras.layers.BatchNormalization(
            axis=axis, scale=False, center=True, input_shape=(6,))

    def call(self, x):
        out = self.layer(x)
        return out

def fails():
    m = M(axis=0)                                 # normalize with axis=0
    x = np.random.randn(3, 6).astype(np.float32)
    print(m(x))                                   # works: weights built on first call
    x = np.random.randn(2, 6).astype(np.float32)
    print(m(x))                                   # raises: batch dim changed from 3 to 2

def ok():
    m = M(axis=1)                                 # normalize with axis=1
    x = np.random.randn(3, 6).astype(np.float32)
    print(m(x))
    x = np.random.randn(2, 6).astype(np.float32)
    print(m(x))                                   # works despite the changed batch dim
EDIT:
The axis in the args is not the axis you think it is.
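To illustrate the gotcha, here is a minimal NumPy sketch of the two axis conventions (NumPy's axis is the dimension you reduce over; Keras' BatchNormalization axis is the dimension you keep):

    import numpy as np

    x = np.random.randn(3, 6).astype(np.float32)

    # NumPy convention: axis names the dimension you reduce OVER
    print(x.mean(axis=0).shape)   # (6,) -- one mean per column
    print(x.mean(axis=1).shape)   # (3,) -- one mean per row

    # Keras convention: BatchNormalization(axis=...) names the dimension
    # you KEEP (the feature axis); all other dimensions are reduced over.
    # So BatchNormalization(axis=0) on a (3, 6) input treats the 3 rows
    # as "features" and reduces over the 6 columns.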
Answer 1:
As stated in this answer and the Keras docs, the axis argument indicates the feature axis. This makes sense because we want to do feature-wise normalization, i.e. to normalize each feature over the whole input batch (consistent with the feature-wise normalization we may do on images, e.g. subtracting the "mean pixel" from all the images of a dataset).
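For intuition, here is a minimal NumPy sketch of feature-wise normalization on a (batch, features) input. It shows the statistics the layer computes at training time, not the layer's actual implementation; the epsilon matches Keras' documented default:

    import numpy as np

    x = np.random.randn(4, 6).astype(np.float32)   # (batch=4, features=6)

    mean = x.mean(axis=0)                          # shape (6,): one mean per feature
    var = x.var(axis=0)                            # shape (6,): one variance per feature
    x_norm = (x - mean) / np.sqrt(var + 1e-3)      # 1e-3 is Keras' default epsilon
    # with center=True, a learned beta of shape (6,) is then added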
Now, the fails() method you have written fails on these lines:

    x = np.random.randn(2, 6).astype(np.float32)
    print(m(x))
That's because you set the feature axis to 0, i.e. the first axis, when building the model. Therefore, when the following lines get executed before the code above:

    x = np.random.randn(3, 6).astype(np.float32)
    print(m(x))

the layer's weights are built for 3 features (remember, you indicated the feature axis as 0, so an input of shape (3, 6) has 3 features). So when you give it an input tensor of shape (2, 6), it correctly raises an error: that tensor has 2 features, and the normalization cannot be done due to this mismatch.
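You can verify this by inspecting the weights the layer builds; a quick sketch reusing the M class from the question (the exact exception type may vary across TF versions):

    m = M(axis=0)
    m(np.random.randn(3, 6).astype(np.float32))    # first call builds the weights

    # with scale=False, center=True the layer holds beta, moving_mean and
    # moving_variance, each of shape (3,): one entry per "feature" on axis 0
    print([w.shape for w in m.layer.weights])

    try:
        m(np.random.randn(2, 6).astype(np.float32))
    except Exception as e:                         # 2 "features" vs. the 3 built for
        print(type(e).__name__, e)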
On the other hand, the ok() method works because there the feature axis is the last axis, and both input tensors have the same number of features, i.e. 6. So normalization can be done in both cases, over all the features.
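The same check for the working case (for a 2-D input, axis=1 is equivalent to the default axis=-1):

    m = M(axis=1)
    m(np.random.randn(3, 6).astype(np.float32))
    print([w.shape for w in m.layer.weights])      # each (6,): one entry per feature

    m(np.random.randn(2, 6).astype(np.float32))    # fine: still 6 features per row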
Source: https://stackoverflow.com/questions/57202668/keras-batchnormalization-only-works-for-constant-batch-dim-when-axis-0