Memory reduction TensorFlow TPU v2/v3 bfloat16

*爱你&永不变心* posted on 2020-05-30 03:12:45

Question


My model is too big to get a batch size >64 on the normal v2 TPU devices. The troubleshooting site mentions that upcoming TensorFlow versions will have bfloat16 support. Are the newly supported TensorFlow versions 1.9–1.12 able to use bfloat16 now, and if so, is there a limited set of optimizers I can use? I haven't found any further documentation on this, but I saw bfloat16 being used in the tensor2tensor model, so I guess there must be a way.

Furthermore, I read that TPU v3 supports bigger models as well, but that the model would need minimal changes; however, I can't find any documentation on what needs to be changed.

I'm already using Adafactor and have tried reducing my layers; if you have any further memory-reduction tips, that would be great too. I'm using image matrices and word vectors (float32 as of now) as input.


Answer 1:


You can use bfloat16 with TPUs. There are two main things to do:

  1. Cast the input to bfloat16 in your input pipeline
  2. Wrap your network in a bfloat16 scope and cast the outputs back to float32 for further calculations.

Here is a code snippet that illustrates the necessary changes:

import tensorflow as tf


def input_fn(params):

  def dataset_parser(value):
    """Parses an ImageNet record from a serialized string Tensor."""
    # image_bytes and label come from decoding `value`; the record
    # parsing itself is elided in this snippet.
    image = image_preprocessing_fn(
        image_bytes=image_bytes,
        is_training=is_training)

    if use_bfloat16:
      # Step 1: cast the input to bfloat16 while still in the input pipeline.
      image = tf.cast(image, tf.bfloat16)

    return image, label
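
  # --- Not part of the original answer: a hypothetical sketch of how the
  # parser above might be wired into a tf.data pipeline for TPUEstimator.
  # The file pattern is a made-up placeholder.
  filenames = tf.gfile.Glob('gs://my-bucket/train-*')
  dataset = tf.data.TFRecordDataset(filenames)
  # The parser performs the bfloat16 cast per record.
  dataset = dataset.map(dataset_parser, num_parallel_calls=8)
  # TPUEstimator passes the per-shard batch size via params.
  dataset = dataset.batch(params['batch_size'], drop_remainder=True)
  return dataset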


# In TF 1.x, bfloat16_scope ships with the TPU contrib package:
from tensorflow.contrib.tpu.python.tpu import bfloat16


def resnet_model_fn(features, labels, mode, params):
  """The model_fn for ResNet to be used with TPUEstimator."""

  # This nested function allows us to avoid duplicating the logic which
  # builds the network, for different values of --precision.
  def build_network():
    network = resnet_model.resnet_v1(
        resnet_depth=FLAGS.resnet_depth,
        num_classes=LABEL_CLASSES,
        data_format=FLAGS.data_format)
    return network(
        inputs=features, is_training=(mode == tf.estimator.ModeKeys.TRAIN))

  if FLAGS.precision == 'bfloat16':
    # Step 2: run the forward pass inside a bfloat16 scope...
    with bfloat16.bfloat16_scope():
      logits = build_network()
    # ...then cast the outputs back to float32 for the loss and any
    # further float32 computation.
    logits = tf.cast(logits, tf.float32)
  elif FLAGS.precision == 'float32':
    logits = build_network()

You can also see the second condition illustrated in the reference TPU models.
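
Because the logits are cast back to float32 before the loss is computed, the optimizer itself runs in full precision, so the choice of optimizer is generally not restricted by the bfloat16 cast. A minimal sketch of the downstream step, assuming a standard classification loss and optimizer (none of these lines are from the original answer):

loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
optimizer = tf.train.MomentumOptimizer(learning_rate=0.1, momentum=0.9)
# On TPU, the optimizer is wrapped so gradients are aggregated across shards.
optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)
train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())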



Source: https://stackoverflow.com/questions/53458833/memory-reduction-tensorflow-tpu-v2-v3-bfloat16
