I'd like to know the possible ways to implement batch normalization layers that synchronize batch statistics when training on multiple GPUs.
Caffe May
I'm not sure I fully understand your question, but provided you set up your variable scope properly, the tf.GraphKeys.UPDATE_OPS collection should automatically contain the batch_norm update ops for each of your towers. If all of the update_ops are applied synchronously, they will be implicitly averaged by the parameter server; all you have to do is make sure the updates are applied before you average and apply gradients (if I understand your intentions correctly).
Because of the variable scope, each set of update ops will update the same variables, so to synchronize the update ops all you need to do is gate your gradient calculation on the complete set of update ops. You should also encapsulate all of your batch norm layers in a single name_scope to avoid grabbing any extraneous ops in UPDATE_OPS. Code skeleton below:
    import tensorflow as tf  # TF 1.x graph-mode API

    # `devices` and `averaging_device` are device strings you supply.
    update_ops = []
    for i, device in enumerate(devices):
        with tf.variable_scope('foo', reuse=bool(i > 0)):
            with tf.name_scope('tower_%d' % i) as name_scope:
                with tf.device(device):
                    # Put as many batch_norm layers as you want here
                    pass
                # Collect only the update ops created inside this tower's name scope
                update_ops.extend(tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                                    name_scope))
    # make gradient calculation ops here
    with tf.device(averaging_device):
        with tf.control_dependencies(update_ops):
            # average and apply gradients.
            pass
If you wanna try this on some existing code, try just deleting the if i == 0 line here: https://github.com/tensorflow/models/blob/master/tutorials/image/cifar10_estimator/cifar10_main.py#L115
You're going to see some slowdown (we usually only use one tower to compute batch norm statistics for this reason), but it should do what you want.
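To get a feel for why taking the statistics from a single tower is often acceptable, here is a minimal pure-Python sketch (not TensorFlow code) of the exponential moving-average step that batch norm update ops apply to the moving mean. The decay value, the simulated towers, the Gaussian toy data and the helper names are all assumptions made for illustration.

```python
import random


def ema_update(moving, batch_value, decay=0.9):
    """One moving-average step: decay * moving + (1 - decay) * batch_value."""
    return decay * moving + (1.0 - decay) * batch_value


def simulate(num_towers=4, steps=500, per_tower_batch=32, seed=0):
    """Track a moving mean fed by tower 0 only vs. by the average of all towers.

    Every tower draws its batch from the same (hypothetical) distribution,
    a Gaussian with mean 5 and standard deviation 2.
    """
    rng = random.Random(seed)
    single = 0.0    # moving mean updated from tower 0 only
    averaged = 0.0  # moving mean updated from the average over all towers
    for _ in range(steps):
        tower_means = []
        for _ in range(num_towers):
            batch = [rng.gauss(5.0, 2.0) for _ in range(per_tower_batch)]
            tower_means.append(sum(batch) / len(batch))
        single = ema_update(single, tower_means[0])
        averaged = ema_update(averaged, sum(tower_means) / num_towers)
    return single, averaged


single, averaged = simulate()
print("tower 0 only:", single)
print("all towers:  ", averaged)
```

When every tower sees data from the same distribution, both estimates settle near the true mean; the averaged one is simply less noisy. This only concerns the moving averages, not the statistics used to normalize each training batch.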
A specialized Keras layer, SyncBatchNormalization, has been available since TF 2.2: https://www.tensorflow.org/api_docs/python/tf/keras/layers/experimental/SyncBatchNormalization
I've figured out a way to implement sync batch norm in pure TensorFlow and pure Python.
The code makes it possible to train PSPNet on Cityscapes and get comparable performance.
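The poster's code is not reproduced here. As a rough illustration of what synchronizing batch statistics involves, the following pure-Python sketch (no TensorFlow) combines per-device partial sums into the mean and variance of the whole global batch, which are the statistics a synchronized batch norm layer normalizes with. The function names and the toy data are hypothetical, and in a real multi-GPU setup the totals would typically come from a cross-device all-reduce rather than a Python loop.

```python
def local_sums(batch):
    """Per-device partial statistics: count, sum, and sum of squares."""
    return len(batch), sum(batch), sum(x * x for x in batch)


def global_moments(per_device_batches):
    """Combine per-device partial sums into the global mean and biased variance."""
    count = 0
    total = 0.0
    total_sq = 0.0
    for batch in per_device_batches:
        n, s, sq = local_sums(batch)
        count += n
        total += s
        total_sq += sq
    mean = total / count
    # Var[x] = E[x^2] - (E[x])^2, computed over the whole global batch
    variance = total_sq / count - mean * mean
    return mean, variance


def normalize(batch, mean, variance, eps=1e-5):
    """Normalize one device's batch with the shared (synchronized) statistics."""
    scale = (variance + eps) ** 0.5
    return [(x - mean) / scale for x in batch]


# Two hypothetical towers whose local batches have very different means.
towers = [[1.0, 2.0, 3.0], [7.0, 8.0, 9.0]]
mean, variance = global_moments(towers)
normalized = [normalize(batch, mean, variance) for batch in towers]
print(mean, variance)
print(normalized)
```

Without synchronization, each tower would normalize with its own local mean (2.0 and 8.0 here) instead of the shared mean of 5.0, so the two towers would produce differently scaled activations for the same layer.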