Batch sequential data coming from multiple TFRecord files with tf.data


Question


Let's consider a dataset split into multiple TFRecord files:

  • 1.tfrecord,
  • 2.tfrecord,
  • etc.

I would like to generate sequences of size t (say, 3) consisting of consecutive elements from the same TFRecord file; a sequence must never contain elements belonging to different TFRecord files.

For instance, if we have two TFRecord files containing data like:

  • 1.tfrecord: {0, 1, 2, ..., 7}
  • 2.tfrecord: {1000, 1001, 1002, ..., 1007}

without any shuffling, I would like to get the following batches:

  • 1st batch: 0, 1, 2,
  • 2nd batch: 1, 2, 3,
  • ...
  • i-th batch: 5, 6, 7,
  • (i+1)-th batch: 1000, 1001, 1002,
  • (i+2)-th batch: 1001, 1002, 1003,
  • ...
  • j-th batch: 1005, 1006, 1007,
  • (j+1)-th batch: 0, 1, 2,
  • etc.

I know how to generate sequence data using tf.data.Dataset.window or tf.data.Dataset.batch, but I do not know how to prevent a sequence from containing elements from different files.
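
One direction I have considered is to build a separate dataset per file and window each of them before combining, for example with flat_map over a dataset of file names. A minimal sketch of that idea (reusing parse_fn from the attempt below; I am not sure this is the idiomatic approach):

import tensorflow as tf

filenames = tf.data.Dataset.from_tensor_slices(['0.tfrecord', '1.tfrecord'])

def window_single_file(filename):
    # Window within one file only, so a window can never span two files.
    # parse_fn is the parsing function defined in the attempt below.
    return tf.data.TFRecordDataset(filename, compression_type='GZIP') \
        .map(parse_fn) \
        .window(3, 1, 1, True) \
        .flat_map(lambda w: w.batch(3))

data = filenames.flat_map(window_single_file).repeat(-1).batch(16)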

I'm looking for a scalable solution, i.e. one that still works with hundreds of TFRecord files.
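
If the per-file idea above is sound, I assume flat_map could also be swapped for tf.data.Dataset.interleave so that many files are read in round-robin order while each window still stays inside a single file (cycle_length=4 below is an arbitrary choice):

data = filenames.interleave(window_single_file, cycle_length=4)\
                .repeat(-1)\
                .batch(16)

With interleave, consecutive windows come from different files, so the output order would differ from the sequential example above, but no single window would mix files.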

Below is my failed attempt (fully reproducible example):

import tensorflow as tf

# ****************************
# Generate toy TFRecord files

def _create_example(i):
    # Wrap integer i as a tf.train.Example with a single int64 feature 'data'.
    features = tf.train.Features(feature={'data': tf.train.Feature(int64_list=tf.train.Int64List(value=[i]))})
    return tf.train.Example(features=features)

def parse_fn(serialized_example):
    # Parse one serialized Example back into its scalar int64 'data' value.
    return tf.parse_single_example(serialized_example, {'data': tf.FixedLenFeature([], tf.int64)})['data']


num_tf_records = 2
records_per_file = 8
options = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.GZIP)
# Write 0.tfrecord with values 0..7 and 1.tfrecord with values 1000..1007.
for i in range(num_tf_records):
    with tf.python_io.TFRecordWriter('%i.tfrecord' % i, options=options) as writer:
        for j in range(records_per_file):
            example = _create_example(j + 1000 * i)
            writer.write(example.SerializeToString())
# ****************************


# Read both files back as one flat dataset of int64 scalars.
data = tf.data.TFRecordDataset(['0.tfrecord', '1.tfrecord'], compression_type='GZIP')\
            .map(parse_fn)

# Sliding windows of size 3, shift 1, stride 1, dropping incomplete windows;
# each window (a nested dataset) is flattened into a tensor of 3 elements.
data = data.window(3, 1, 1, True)\
           .repeat(-1)\
           .flat_map(lambda x: x.batch(3))\
           .batch(16)

data_it = data.make_initializable_iterator()
next_element = data_it.get_next()

with tf.Session() as sess:
    sess.run(data_it.initializer)
    print(sess.run(next_element))

which outputs:

[[   0    1    2]   # good
 [   1    2    3]   # good
 [   2    3    4]   # good
 [   3    4    5]   # good
 [   4    5    6]   # good
 [   5    6    7]   # good
 [   6    7 1000]   # bad: mixes elements from 0.tfrecord and 1.tfrecord
 [   7 1000 1001]   # bad
 [1000 1001 1002]   # good
 [1001 1002 1003]   # good
 [1002 1003 1004]   # good
 [1003 1004 1005]   # good
 [1004 1005 1006]   # good
 [1005 1006 1007]   # good
 [   0    1    2]   # good
 [   1    2    3]]  # good

Source: https://stackoverflow.com/questions/55110489/batch-sequential-data-coming-from-multiple-tfrecord-files-with-tf-data
