Question
Let's consider a dataset split into multiple TFRecord files:
- 1.tfrecord
- 2.tfrecord
- etc.
I would like to generate sequences of size t (say 3) consisting of consecutive elements from the same TFRecord file; I do not want a sequence to have elements belonging to different TFRecord files.
For instance, if we have two TFRecord files containing data like:
- 1.tfrecord: {0, 1, 2, ..., 7}
- 2.tfrecord: {1000, 1001, 1002, ..., 1007}
without any shuffling, I would like to get the following batches:
- 1st batch: 0, 1, 2
- 2nd batch: 1, 2, 3
- ...
- i-th batch: 5, 6, 7
- (i+1)-th batch: 1000, 1001, 1002
- (i+2)-th batch: 1001, 1002, 1003
- ...
- j-th batch: 1005, 1006, 1007
- (j+1)-th batch: 0, 1, 2
- etc.
I know how to generate sequence data using tf.data.Dataset.window or tf.data.Dataset.batch, but I do not know how to prevent a sequence from containing elements from different files. I'm looking for a scalable solution, i.e. one that still works with hundreds of TFRecord files. The sliding-window pattern I mean is sketched right after this paragraph.
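For reference, a minimal sketch of that pattern on a toy in-memory dataset (tf.data.Dataset.range is just a stand-in for parsed TFRecord data; TF 1.x style to match the script below):

import tensorflow as tf

ds = tf.data.Dataset.range(5)  # toy stand-in for one parsed file: 0..4
# sliding windows of size 3, shift 1, stride 1, dropping partial windows
ds = ds.window(3, 1, 1, True).flat_map(lambda w: w.batch(3))
it = ds.make_initializable_iterator()
nxt = it.get_next()
with tf.Session() as sess:
    sess.run(it.initializer)
    for _ in range(3):
        print(sess.run(nxt))  # [0 1 2], then [1 2 3], then [2 3 4]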
Below is my failed attempt (fully reproducible example):
import tensorflow as tf

# ****************************
# Generate toy TFRecord files
def _create_example(i):
    # Wrap a single int64 value in a tf.train.Example
    example = tf.train.Features(feature={'data': tf.train.Feature(int64_list=tf.train.Int64List(value=[i]))})
    return tf.train.Example(features=example)

def parse_fn(serialized_example):
    # Recover the int64 value from a serialized Example
    return tf.parse_single_example(serialized_example, {'data': tf.FixedLenFeature([], tf.int64)})['data']

num_tf_records = 2
records_per_file = 8
options = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.GZIP)
for i in range(num_tf_records):
    with tf.python_io.TFRecordWriter('%i.tfrecord' % i, options=options) as writer:
        for j in range(records_per_file):
            example = _create_example(j + 1000 * i)  # file i holds 1000*i .. 1000*i + 7
            writer.write(example.SerializeToString())
# ****************************

# ****************************
data = tf.data.TFRecordDataset(['0.tfrecord', '1.tfrecord'], compression_type='GZIP')\
    .map(lambda x: parse_fn(x))
data = data.window(3, 1, 1, True)\
    .repeat(-1)\
    .flat_map(lambda x: x.batch(3))\
    .batch(16)
data_it = data.make_initializable_iterator()
next_element = data_it.get_next()
with tf.Session() as sess:
    sess.run(data_it.initializer)
    print(sess.run(next_element))
which outputs:
[[ 0 1 2] # good
[ 1 2 3] # good
[ 2 3 4] # good
[ 3 4 5] # good
[ 4 5 6] # good
[ 5 6 7] # good
[ 6 7 1000] # bad – mix of elements from 0.tfrecord and 1.tfrecord
[ 7 1000 1001] # bad
[1000 1001 1002] # good
[1001 1002 1003] # good
[1002 1003 1004] # good
[1003 1004 1005] # good
[1004 1005 1006] # good
[1005 1006 1007] # good
[ 0 1 2] # good
[ 1 2 3]] # good
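One direction I have considered but not verified: window each file separately and chain the per-file streams with flat_map over a dataset of filenames, so that no window can span a file boundary. A rough sketch (reusing parse_fn from the script above; windows_for_file is just an illustrative name):

filenames = tf.data.Dataset.from_tensor_slices(['0.tfrecord', '1.tfrecord'])

def windows_for_file(fname):
    # One windowed dataset per file: windows never cross into another file
    ds = tf.data.TFRecordDataset(fname, compression_type='GZIP').map(parse_fn)
    return ds.window(3, 1, 1, True).flat_map(lambda w: w.batch(3))

data = filenames.flat_map(windows_for_file).repeat(-1).batch(16)

With hundreds of files, tf.data.Dataset.interleave could presumably replace flat_map to read several files concurrently (each window would still stay within one file), though that would change the order in which windows from different files appear in a batch.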
Source: https://stackoverflow.com/questions/55110489/batch-sequential-data-coming-from-multiple-tfrecord-files-with-tf-data