How to use py_func with a function that returns a dict

后端 未结 1 489
有刺的猬
有刺的猬 2021-02-08 03:45

I'm writing an input pipeline using tf.data.Dataset. I'd like to use Python code to load and transform my samples; the code returns a dictionary of tensors. Unfortunately, […] (the rest of the question is truncated in this copy)

相关标签:
1条回答
  • 2021-02-08 03:48

    This answer is in response to Celso Franca's comment.

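    I did find a way, but instead of returning a dict, the Python function returns a serialized tf.train.Example (via tf_example.SerializeToString()), which is then parsed back into a dict of tensors with parse_single_example.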

    The functions below were used to process BERT input on the fly. This worked great and saved me many hours of upfront preprocessing, without any loss of performance during training.

    def _convert(label, text):
        """Turn one (label, text) CSV row into a serialized tf.train.Example.

        Meant to run inside ``tf.py_function``: ``label`` and ``text`` arrive
        as eager tensors and are unwrapped with ``.numpy()``.

        The text is tokenized and framed as ``[CLS] tokens [SEP]``. The token
        part is truncated so the whole sequence fits in ``seq_length``, and
        ``input_ids`` / ``input_mask`` / ``segment_ids`` are then zero-padded
        up to ``seq_length``.

        Returns:
            The bytes produced by ``SerializeToString()``. ``tf.py_function``
            can only return real TF dtypes such as tf.string, not a dict, so
            the example is serialized here and parsed again downstream.
        """
        raw_label = label.numpy()
        raw_text = text.numpy()

        text_tokens = tokenizer.tokenize(raw_text)
        # Leave room for the [CLS] and [SEP] markers.
        max_text_tokens = seq_length - 2
        if len(text_tokens) > max_text_tokens:
            text_tokens = text_tokens[0:max_text_tokens]

        tokens = ["[CLS]"] + list(text_tokens) + ["[SEP]"]
        # Only one text segment here, so every position (markers included)
        # gets segment id 0.
        segment_ids = [0] * len(tokens)

        input_ids = tokenizer.convert_tokens_to_ids(tokens)
        # 1 marks a real token the model attends to; 0 marks padding.
        input_mask = [1] * len(input_ids)

        # Zero-pad all three sequences up to seq_length (empty list when the
        # sequence is already full, so nothing is appended).
        padding = [0] * (seq_length - len(input_ids))
        input_ids.extend(padding)
        input_mask.extend(padding)
        segment_ids.extend(padding)

        for sequence in (input_ids, input_mask, segment_ids):
            assert len(sequence) == seq_length

        label_id = label_map[raw_label]
        features = collections.OrderedDict([
            ("input_ids", create_int_feature(input_ids)),
            ("input_mask", create_int_feature(input_mask)),
            ("segment_ids", create_int_feature(segment_ids)),
            ("label_ids", create_int_feature([label_id])),
            ("is_real_example", create_int_feature([int(True)])),
        ])
        example = tf.train.Example(features=tf.train.Features(feature=features))
        return example.SerializeToString()
    
      def _decode_record(record):
          """Parse one serialized tf.train.Example back into a dict of tensors.

          ``record`` is the scalar string emitted by ``_convert`` through
          ``tf.py_function``. The feature spec ``name_to_features`` is defined
          outside this snippet (presumably a dict of ``tf.io.FixedLenFeature``
          specs — TODO confirm).

          tf.train.Example only stores integers as tf.int64, but the TPU only
          supports tf.int32, so every int64 feature is cast down to int32.
          Other dtypes are passed through unchanged.

          Uses ``tf.io.parse_single_example`` / ``tf.cast`` instead of the
          TF1-only ``tf.parse_single_example`` / ``tf.to_int32``, which were
          removed in TF 2.
          """
          example = tf.io.parse_single_example(record, name_to_features)
          return {
              name: tf.cast(tensor, tf.int32) if tensor.dtype == tf.int64 else tensor
              for name, tensor in example.items()
          }
    
      def input_fn(params):
          """Build the tf.data pipeline that feeds the model.

          Reads every CSV file matching ``file_pattern`` (skipping each file's
          header row) and keeps only the label and text columns. Each row is
          converted to model features by ``_convert``, run as Python through
          ``tf.py_function``. The result is parsed back into a dict of tensors
          by ``_decode_record`` and then batched.

          Args:
              params: mapping that must provide ``"batch_size"`` (presumably
                  supplied by a (TPU)Estimator — verify against caller).

          Returns:
              A ``tf.data.Dataset`` of batched feature dicts.
          """
          filenames = tf.data.Dataset.list_files(file_pattern)
          label_col = processor.get_label_col()
          text_col = processor.get_text_col()

          # CsvDataset requires select_cols to be strictly increasing, and
          # record_defaults must follow that same order. Read the two columns
          # in file order and restore (label, text) afterwards, so a text
          # column that precedes the label column also works.
          label_first = label_col < text_col
          if label_first:
              select_cols = [label_col, text_col]
              record_defaults = [tf.float32, tf.string]
          else:
              select_cols = [text_col, label_col]
              record_defaults = [tf.string, tf.float32]

          def _read_csv(filename):
              """Read the selected columns of one CSV file, skipping its header."""
              return tf.data.experimental.CsvDataset(
                  filename,
                  record_defaults,
                  select_cols=select_cols,
                  field_delim=delimiter,
                  header=True)

          # parallel_interleave lives in tf.data.experimental; the tf.contrib
          # alias is deprecated and gone in TF 2.
          d = filenames.apply(
              tf.data.experimental.parallel_interleave(_read_csv, cycle_length=2))
          if not label_first:
              d = d.map(lambda text, label: (label, text))
          if is_training:
              # NOTE: repeat() before shuffle() lets the shuffle buffer mix
              # records across epoch boundaries.
              d = d.repeat()
              d = d.shuffle(buffer_size=100)
          # py_function cannot return a dict, so _convert returns a serialized
          # tf.train.Example that _decode_record turns back into a dict.
          d = d.map(lambda label, text: tf.py_function(_convert, [label, text], tf.string))
          d = d.map(_decode_record)
          d = d.batch(batch_size=params["batch_size"], drop_remainder=drop_remainder)
          return d
    
    
    0 讨论(0)
提交回复
热议问题