Ways to implement multi-GPU BN layers with synchronized means and variances

盖世英雄少女心 2021-02-05 22:53

I'd like to know the possible ways to implement batch normalization layers that synchronize batch statistics across devices when training on multiple GPUs.

Caffe May

3 answers
  •  忘了有多久
    2021-02-05 23:18

    I'm not sure I fully understand your question, but provided you set up your variable scope properly, the tf.GraphKeys.UPDATE_OPS collection should automatically contain the batch_norm update ops for each of your towers. If all of the update_ops are applied synchronously, they will be implicitly averaged by the parameter server; all you have to do is make sure the updates are applied before you average and apply gradients (if I understand your intentions correctly).

    Because the variable scope is shared across towers, each set of update ops will update the same variables, so to synchronize the update ops all you need to do is gate your gradient calculation on the complete set of update ops. You should also encapsulate all of your batch norm layers in a single name_scope to avoid grabbing any extraneous ops in UPDATE_OPS. Code skeleton below:

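    import tensorflow as tf  # TF 1.x graph-mode API

    # devices (list of device strings) and averaging_device (a device string)
    # are expected to be defined by the caller.
    update_ops = []
    for i, device in enumerate(devices):
      # Create the variables on tower 0 and reuse them on every later tower.
      with tf.variable_scope('foo', reuse=bool(i > 0)):
        with tf.name_scope('tower_%d' % i) as name_scope:
          with tf.device(device):
            # Put as many batch_norm layers as you want here.
            pass
          # Collect only the update ops created under this tower's name scope.
          update_ops.extend(tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                              name_scope))
    # Make gradient calculation ops here.
    with tf.device(averaging_device):
      # Gate the gradient step on every tower's batch norm update ops.
      with tf.control_dependencies(update_ops):
        # Average and apply gradients.
        pass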
    

    If you want to try this on some existing code, try just deleting the if i == 0 line here: https://github.com/tensorflow/models/blob/master/tutorials/image/cifar10_estimator/cifar10_main.py#L115

    You're going to see some slowdown (we usually only use one tower to compute batch norm statistics for this reason), but it should do what you want.
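
    For reference, here is a rough sketch of what "synchronized" batch statistics mean numerically, independent of any framework. It is an illustration, not TensorFlow code and not what the skeleton above does internally. The helper name combine_tower_stats and the sample activations are made up for the example, and plain Python lists stand in for per-GPU tensors. Each tower reports its sample count, mean, and biased variance, and the pieces are merged with the law of total variance so that the result matches the statistics of the full concatenated batch.

```python
from statistics import fmean, pvariance


def combine_tower_stats(tower_stats):
    """Merge per-tower (count, mean, biased variance) triples.

    Law of total variance: the global variance is the count-weighted mean
    of the per-tower variances plus the count-weighted spread of the
    per-tower means around the global mean.
    """
    total = sum(n for n, _, _ in tower_stats)
    mean = sum(n * m for n, m, _ in tower_stats) / total
    var = sum(n * (v + (m - mean) ** 2) for n, m, v in tower_stats) / total
    return mean, var


# Hypothetical activations for one channel, split unevenly across two towers.
tower_batches = [[1.0, 2.0, 3.0, 4.0], [10.0, 12.0]]
tower_stats = [(len(b), fmean(b), pvariance(b)) for b in tower_batches]

sync_mean, sync_var = combine_tower_stats(tower_stats)

# The merged statistics should agree with those of the concatenated batch.
full_batch = [x for b in tower_batches for x in b]
assert abs(sync_mean - fmean(full_batch)) < 1e-9
assert abs(sync_var - pvariance(full_batch)) < 1e-9
```

    In a real multi-GPU setup, the counts, means, and variances would have to be exchanged between devices (for example with an all-reduce) before normalizing, which is one source of the slowdown mentioned above.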
