Question
I am using Stochastic Weight Averaging (SWA) with Batch Normalization layers in TensorFlow 2.2. For batch norm I use tf.keras.layers.BatchNormalization. For SWA I use my own code to average the weights (I wrote it before tfa.optimizers.SWA appeared). I have read in multiple sources that when using batch norm with SWA we must run a forward pass to make certain data (the running mean and standard deviation of the activations, and/or momentum values?) available to the batch norm layers. What I do not understand - despite a lot of reading - is exactly what needs to be done and how. Specifically:
- When must the forward/prediction pass be run? At the end of each mini-batch, end of each epoch, end of all training?
- When the forward pass is run, how are the running mean & stdev values made available to the batch norm layers?
- Is this process performed magically by the tfa.optimizers.SWA class?
Answer 1:
When must the forward/prediction pass be run? At the end of each mini-batch, end of each epoch, end of all training?
At the end of training. Think of it like this: SWA swaps your final weights for a running average, but every batch norm layer still holds statistics computed from the old weights. So we need to run a forward pass to let them catch up.
When the forward pass is run, how are the running mean & stdev values made available to the batch norm layers?
During a normal forward pass (prediction) the running mean and standard deviation are not updated. So what we actually need to do is run the network as if we were training it, but without updating the weights. This is what the paper means when it says to run the forward pass in "training mode".
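In Keras this can be done by hand: calling the model with training=True updates each BatchNormalization layer's moving statistics while computing no gradients, so the trainable weights stay untouched. A minimal sketch, where model and train_dataset are placeholder names, the dataset is assumed to yield (x, y) tuples, and enough batches are assumed for the exponential moving averages to re-converge:

```python
import tensorflow as tf

def refresh_bn_statistics(model, dataset):
    # Reset the moving statistics so values computed from the
    # pre-SWA weights do not linger in the moving averages.
    for layer in model.layers:
        if isinstance(layer, tf.keras.layers.BatchNormalization):
            layer.moving_mean.assign(tf.zeros_like(layer.moving_mean))
            layer.moving_variance.assign(tf.ones_like(layer.moving_variance))

    # Forward passes in training mode update the moving mean/variance;
    # no loss or gradient step is taken, so the weights do not change.
    for x_batch, _ in dataset:
        model(x_batch, training=True)

refresh_bn_statistics(model, train_dataset)
```

Note that model.layers only reaches top-level layers; for nested models you would have to walk model.submodules instead.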
The easiest way I know of to achieve this is to train one additional epoch with the learning rate set to 0.
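For example, something like the following, with model, train_dataset and the loss choice as placeholders; plain SGD is assumed so that a zero learning rate really makes the gradient step a no-op:

```python
import tensorflow as tf

# One extra epoch at learning rate 0: fit() runs in training mode, so the
# BatchNormalization layers update their moving mean/variance from the
# averaged weights, while the zero-step SGD leaves the weights unchanged.
model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=0.0),
              loss="sparse_categorical_crossentropy")
model.fit(train_dataset, epochs=1)
```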
Is this process performed magically by the tfa.optimizers.SWA class?
I don't know. But if you are using TensorFlow Keras, I have made a Keras SWA callback that does it as in the paper, including the learning rate schedules.
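For reference, a hedged sketch of how tfa.optimizers.SWA is commonly combined with a manual batch norm refresh; the hyperparameters and the model/train_dataset names are placeholders, and whether the wrapper itself touches the batch norm statistics is exactly the open question above:

```python
import tensorflow as tf
import tensorflow_addons as tfa

# Wrap a base optimizer; averaging starts after `start_averaging` steps
# and the average is updated every `average_period` steps.
base_opt = tf.keras.optimizers.SGD(learning_rate=0.01)
swa_opt = tfa.optimizers.SWA(base_opt, start_averaging=1000, average_period=10)

model.compile(optimizer=swa_opt, loss="sparse_categorical_crossentropy")
model.fit(train_dataset, epochs=10)

# Copy the running averages into the model's weights...
swa_opt.assign_average_vars(model.variables)
# ...and then refresh the batch norm statistics separately,
# e.g. with the zero-learning-rate epoch shown above.
```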
Source: https://stackoverflow.com/questions/62855224/how-implement-batch-norm-with-swa-in-tensorflow