How to implement Batch Norm with SWA in TensorFlow?

Submitted by 偶尔善良 on 2021-02-11 06:13:06

Question


I am using Stochastic Weight Averaging (SWA) with Batch Normalization layers in TensorFlow 2.2. For Batch Norm I use tf.keras.layers.BatchNormalization. For SWA I use my own code to average the weights (I wrote it before tfa.optimizers.SWA appeared). I have read in multiple sources that when using batch norm with SWA, we must run a forward pass to make certain data (the running mean and standard deviation of the activations, and/or momentum values?) available to the batch norm layers. What I do not understand, despite a lot of reading, is exactly what needs to be done and how. Specifically:

  1. When must the forward/prediction pass be run? At the end of each mini-batch, end of each epoch, end of all training?
  2. When the forward pass is run, how are the running mean & stdev values made available to the batch norm layers?
  3. Is this process performed magically by the tfa.optimizers.SWA class?

Answer 1:


When must the forward/prediction pass be run? At the end of each mini-batch, end of each epoch, end of all training?

At the end of training. Think of it like this: SWA works by swapping your final weights for a running average of the weights. But the statistics stored in the batch norm layers were still computed with your old weights, so we need to run a forward pass to let them catch up.
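To make the "running average" concrete, here is a minimal sketch of the averaging step, assuming the weights are handled as lists of NumPy arrays such as those returned by `model.get_weights()`. The helper name `update_swa_weights` and the toy snapshots are hypothetical and only for illustration. This is not the code of `tfa.optimizers.SWA` or of the asker's implementation.

```python
import numpy as np


def update_swa_weights(swa_weights, new_weights, num_averaged):
    """Return the running average after folding in one more weight snapshot.

    swa_weights / new_weights: lists of arrays, e.g. from model.get_weights().
    num_averaged: number of snapshots already included in swa_weights.
    """
    return [
        (swa * num_averaged + new) / (num_averaged + 1)
        for swa, new in zip(swa_weights, new_weights)
    ]


# Three hypothetical weight snapshots of a model with a single 2x2 weight array.
snapshots = [[np.full((2, 2), v)] for v in (1.0, 2.0, 6.0)]

swa_weights = snapshots[0]
for n, weights in enumerate(snapshots[1:], start=1):
    swa_weights = update_swa_weights(swa_weights, weights, n)

print(swa_weights[0])  # every entry is the mean of 1, 2 and 6, i.e. 3.0
```

Note that `model.get_weights()` also returns the non-trainable batch norm statistics, so averaging that list averages those statistics too. They still do not correspond to the averaged weights, which is why the extra pass is needed.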

When the forward pass is run, how are the running mean & stdev values made available to the batch norm layers?

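During a normal forward pass (prediction), the running mean and standard deviation are not updated. So what we actually need to do is run the network as if we were training it, so that the batch norm layers update their running statistics (in Keras, the layer's moving_mean and moving_variance), but without updating the weights. This is what the paper refers to when it says to run the forward pass in "training mode".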

The easiest way I know to achieve this is to train for one additional epoch with the learning rate set to 0.
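As an alternative to the zero-learning-rate epoch, a rough sketch of the same idea is to call the model with `training=True` on each batch, with no gradient or optimizer step. The function name `refresh_batch_norm_stats`, the toy model, and the random data are assumptions made for illustration, and the sketch assumes eager execution in TensorFlow 2.x.

```python
import numpy as np
import tensorflow as tf


def refresh_batch_norm_stats(model, dataset, epochs=1):
    """Run training-mode forward passes so BatchNormalization layers update
    their moving mean/variance. No gradients are computed and no optimizer
    step is taken, so the trainable weights stay as they are."""
    for _ in range(epochs):
        for x_batch, _ in dataset:
            model(x_batch, training=True)


# Toy model and data (hypothetical, for illustration only).
inputs = tf.keras.Input(shape=(4,))
dense = tf.keras.layers.Dense(8)
bn = tf.keras.layers.BatchNormalization()
outputs = tf.keras.layers.Dense(1)(bn(dense(inputs)))
model = tf.keras.Model(inputs, outputs)

x = np.random.rand(256, 4).astype("float32")
y = np.random.rand(256, 1).astype("float32")
dataset = tf.data.Dataset.from_tensor_slices((x, y)).batch(32)

kernel_before = dense.kernel.numpy().copy()
mean_before = bn.moving_mean.numpy().copy()

# In a real run, the SWA-averaged weights would be set on the model first,
# e.g. with model.set_weights(...), and then the statistics refreshed.
refresh_batch_norm_stats(model, dataset)

print("dense kernel unchanged:", np.allclose(kernel_before, dense.kernel.numpy()))
print("moving mean changed:", not np.allclose(mean_before, bn.moving_mean.numpy()))
```

Keras updates these statistics as an exponential moving average (the layer's `momentum` defaults to 0.99), so a pass over only a few batches moves them just part of the way toward the new values. More batches or several passes may be needed.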

Is this process performed magically by the tfa.optimizers.SWA class?

I don't know. But if you are using TensorFlow Keras, I have made this Keras SWA callback that does it as in the paper, including the learning rate schedules.



Source: https://stackoverflow.com/questions/62855224/how-implement-batch-norm-with-swa-in-tensorflow
