Keras BatchNormalization Only Works For Constant Batch Dim When Axis=0?
Solution 1:
As stated in this answer and the Keras docs, the axis
argument indicates the feature axis. This makes sense because we want feature-wise normalization, i.e. to normalize each feature over the whole input batch (in line with the feature-wise normalization we might apply to images, e.g. subtracting the "mean pixel" from all the images of a dataset).
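For intuition, here is a rough NumPy sketch (not code from the question) of what feature-wise normalization over a batch means:
import numpy as np

x = np.random.randn(4, 6).astype(np.float32)   # a batch of 4 samples, 6 features
mean = x.mean(axis=0)                          # per-feature mean, shape (6,)
std = x.std(axis=0)                            # per-feature std, shape (6,)
x_norm = (x - mean) / (std + 1e-3)             # each feature normalized over the batch

print(x_norm.mean(axis=0))  # roughly 0 for every feature
print(x_norm.std(axis=0))   # roughly 1 for every feature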
Now, the fails()
method you have written fails at this call:
x = np.random.randn(2, 6).astype(np.float32)
print(m(x))
That's because you have set the feature axis to 0, i.e. the first axis, when building the model. Therefore, when the following lines are executed before the code above:
x = np.random.randn(3, 6).astype(np.float32)
print(m(x))
the layer's weights are built for 3 features (remember, you indicated the feature axis as 0, so an input of shape (3, 6) has 3 features). So when you then give it an input tensor of shape (2, 6), the layer correctly raises an error: that tensor has 2 features, and the normalization cannot be done because of this mismatch.
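A minimal sketch that reproduces this failure, assuming the model is just a single BatchNormalization layer (the question's fails() may differ slightly):
import numpy as np
import tensorflow as tf

m = tf.keras.layers.BatchNormalization(axis=0)

x = np.random.randn(3, 6).astype(np.float32)
print(m(x))   # builds the layer: weights get shape (3,), one per "feature"

x = np.random.randn(2, 6).astype(np.float32)
print(m(x))   # raises an error: axis 0 now has size 2, but the weights expect 3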
On the other hand, the ok()
method works because the feature axis is the last axis, and therefore both input tensors have the same number of features, i.e. 6. So normalization can be done in both cases, over all the features.
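The same sketch with the default axis=-1 accepts both batch sizes, because the feature axis (the last one) has size 6 in both tensors:
import numpy as np
import tensorflow as tf

m = tf.keras.layers.BatchNormalization(axis=-1)

print(m(np.random.randn(3, 6).astype(np.float32)))  # builds weights of shape (6,)
print(m(np.random.randn(2, 6).astype(np.float32)))  # works: still 6 features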