Suppose I have 3 tfrecord files, namely neg.tfrecord, pos1.tfrecord and pos2.tfrecord.
For each file I use
dataset = tf.data.TFRecordDataset(tfrecord_file)
so I end up with 3 separate Dataset objects.
My batch size is 400, made up of 200 neg examples, 100 pos1 examples and 100 pos2 examples. How can I build a single dataset that yields batches like this?
I will pass this dataset to Keras model.fit() (with eager execution enabled).
My TensorFlow version is 1.13.1.
Previously I created an iterator for each dataset and concatenated the batches manually after fetching them, but this was slow and GPU utilization stayed low.
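You can use interleave:
filenames = [tfrecord_file1, tfrecord_file2, tfrecord_file3]
dataset = tf.data.Dataset.from_tensor_slices(filenames).interleave(
    lambda x: tf.data.TFRecordDataset(x), cycle_length=len(filenames))  # cycle_length is required in TF 1.13
dataset = dataset.map(parse_fn)  # parse_fn: your own record-parsing function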
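A minimal, runnable sketch of the interleave approach, assuming eager execution (the default in TF 2.x; in 1.13 call tf.enable_eager_execution() first). The three files, the single int64 "label" feature and the write_records/parse_fn helpers are hypothetical stand-ins for the real data. It also shows one caveat: with block_length=1, interleave takes one record from each file in turn, so the stream is mixed 1:1:1, and batching it into 400 does not by itself guarantee exactly 200 neg / 100 pos1 / 100 pos2.

```python
import os
import tempfile

import tensorflow as tf

# Assumes eager execution (default in TF 2.x; tf.enable_eager_execution() in 1.13).
# Hypothetical feature spec: every record stores a single int64 "label".
FEATURES = {"label": tf.io.FixedLenFeature([], tf.int64)}


def write_records(path, label, count):
    """Hypothetical helper: write `count` tf.train.Example records carrying `label`."""
    with tf.io.TFRecordWriter(path) as writer:
        for _ in range(count):
            feature = {"label": tf.train.Feature(
                int64_list=tf.train.Int64List(value=[label]))}
            example = tf.train.Example(features=tf.train.Features(feature=feature))
            writer.write(example.SerializeToString())


def parse_fn(serialized):
    # Hypothetical parser matching FEATURES above; returns the scalar label.
    return tf.io.parse_single_example(serialized, FEATURES)["label"]


tmp = tempfile.mkdtemp()
filenames = []
for name, label in [("neg", 0), ("pos1", 1), ("pos2", 2)]:
    path = os.path.join(tmp, name + ".tfrecord")
    write_records(path, label, 4)
    filenames.append(path)

# cycle_length=3 keeps all three files open at once; block_length=1 pulls
# one record from each file before moving on to the next (round-robin).
dataset = tf.data.Dataset.from_tensor_slices(filenames).interleave(
    lambda f: tf.data.TFRecordDataset(f), cycle_length=3, block_length=1)
dataset = dataset.map(parse_fn)

labels = [int(x.numpy()) for x in dataset]
print(labels)
```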
Alternatively, you can try tf.data.experimental.parallel_interleave, which fetches records from several files in parallel. See:
https://www.tensorflow.org/api_docs/python/tf/data/TFRecordDataset#interleave
https://www.tensorflow.org/api_docs/python/tf/data/experimental/parallel_interleave
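If every batch must contain exactly 200 neg, 100 pos1 and 100 pos2 records, a different technique than interleave may fit better: batch each file on its own with Dataset.batch, pair the sub-batches with Dataset.zip, and concatenate them in a map. The sketch below makes the same assumptions as before (eager execution, hypothetical files with a single int64 "label" feature, hypothetical helper names). Unlike the manual per-iterator concatenation described in the question, all of the work stays inside the tf.data pipeline.

```python
import os
import tempfile

import tensorflow as tf

# Assumes eager execution (default in TF 2.x; tf.enable_eager_execution() in 1.13).
# Hypothetical feature spec and file contents, for illustration only.
FEATURES = {"label": tf.io.FixedLenFeature([], tf.int64)}


def write_records(path, label, count):
    """Hypothetical helper: write `count` tf.train.Example records carrying `label`."""
    with tf.io.TFRecordWriter(path) as writer:
        for _ in range(count):
            feature = {"label": tf.train.Feature(
                int64_list=tf.train.Int64List(value=[label]))}
            example = tf.train.Example(features=tf.train.Features(feature=feature))
            writer.write(example.SerializeToString())


def parse_fn(serialized):
    # Hypothetical parser matching FEATURES above.
    return tf.io.parse_single_example(serialized, FEATURES)["label"]


def per_file_batches(path, batch_size):
    # repeat() before batch() so a smaller file never ends the stream early;
    # every batch drawn from this dataset has exactly `batch_size` rows.
    return tf.data.TFRecordDataset(path).map(parse_fn).repeat().batch(batch_size)


tmp = tempfile.mkdtemp()
paths = {}
for name, label in [("neg", 0), ("pos1", 1), ("pos2", 2)]:
    paths[name] = os.path.join(tmp, name + ".tfrecord")
    write_records(paths[name], label, 10)

# One sub-batch per file with the desired sizes, then concatenate them
# along the batch axis into a single batch of 400.
dataset = tf.data.Dataset.zip((
    per_file_batches(paths["neg"], 200),
    per_file_batches(paths["pos1"], 100),
    per_file_batches(paths["pos2"], 100),
)).map(lambda neg, pos1, pos2: tf.concat([neg, pos1, pos2], axis=0))
dataset = dataset.prefetch(1)

batch = next(iter(dataset)).numpy()
counts = {label: int((batch == label).sum()) for label in (0, 1, 2)}
print(batch.shape, counts)
```

For Keras, parse_fn would typically return a (features, label) pair; the zipped map function then receives three such pairs and would need to concatenate features and labels separately. Because of repeat(), this dataset is infinite, so model.fit() needs steps_per_epoch.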