How to combine multiple datasets into one dataset?

断了今生、忘了曾经 submitted on 2019-12-24 11:22:31

Question


Suppose I have 3 TFRecord files: neg.tfrecord, pos1.tfrecord, and pos2.tfrecord.

I use

dataset = tf.data.TFRecordDataset(tfrecord_file)

for each file, which gives me 3 separate Dataset objects.

My batch size is 400, and each batch should contain 200 neg examples, 100 pos1 examples, and 100 pos2 examples. How can I build a dataset that produces batches like this?

I will pass this dataset to Keras fit() with eager execution enabled.

My TensorFlow version is 1.13.1.

Previously, I created an iterator for each dataset and manually concatenated the data after fetching it, but this was inefficient and GPU utilization stayed low.
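A minimal sketch of one way to get exactly 200/100/100 per batch while keeping the concatenation inside the tf.data pipeline instead of doing it by hand after fetching from iterators. This is a swapped-in technique (batch each source separately, zip, then concatenate), not the interleave approach from the answer. The in-memory datasets and the helper make_mixed_batches are hypothetical stand-ins; in the real pipeline each source would be a TFRecordDataset mapped through the asker's parse function.

```python
import tensorflow as tf

# Hypothetical stand-ins for the three parsed sources. In the real pipeline
# these would be tf.data.TFRecordDataset("neg.tfrecord").map(parse_fn), etc.
# Values 0/1/2 make it easy to count where each element came from.
neg = tf.data.Dataset.from_tensor_slices([0] * 1000)
pos1 = tf.data.Dataset.from_tensor_slices([1] * 1000)
pos2 = tf.data.Dataset.from_tensor_slices([2] * 1000)


def make_mixed_batches(neg, pos1, pos2, n_neg=200, n_pos1=100, n_pos2=100):
    """Hypothetical helper: batches of n_neg + n_pos1 + n_pos2 with a fixed mix."""
    # Batch each source on its own so each contributes a fixed count.
    neg_b = neg.batch(n_neg, drop_remainder=True)
    pos1_b = pos1.batch(n_pos1, drop_remainder=True)
    pos2_b = pos2.batch(n_pos2, drop_remainder=True)
    # zip emits one (neg, pos1, pos2) triple of sub-batches at a time and
    # stops as soon as any of the three inputs is exhausted.
    zipped = tf.data.Dataset.zip((neg_b, pos1_b, pos2_b))
    # Concatenate along the batch axis (200 + 100 + 100 = 400), then prefetch
    # so input preparation can overlap with training steps.
    mixed = zipped.map(lambda a, b, c: tf.concat([a, b, c], axis=0))
    return mixed.prefetch(1)


mixed = make_mixed_batches(neg, pos1, pos2)
for batch in mixed.take(1):
    print(batch.shape)
```

If each element is a (features, label) pair, the map function would need to concatenate each component separately. Because zip stops at the shortest input, calling .shuffle(...).repeat() on each source before batching is one way to keep the smaller positive files from ending the epoch early; with an infinite dataset, Keras fit() then needs steps_per_epoch.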


Answer 1:


You can use interleave:

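import tensorflow as tf

filenames = [tfrecord_file1, tfrecord_file2]
# Open one TFRecordDataset per file; cycle_length = how many files are read concurrently
dataset = tf.data.Dataset.from_tensor_slices(filenames).interleave(
    lambda x: tf.data.TFRecordDataset(x), cycle_length=len(filenames))
dataset = dataset.map(parse_fn)
...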

Or you can try parallel interleave (tf.data.experimental.parallel_interleave). See:
https://www.tensorflow.org/api_docs/python/tf/data/TFRecordDataset#interleave
https://www.tensorflow.org/api_docs/python/tf/data/experimental/parallel_interleave
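A minimal, self-contained sketch of the interleave approach, showing the order in which records come out. The three tiny files, their contents, and the write_records helper are hypothetical, created only so the example runs without the asker's data; the sketch uses plain interleave rather than parallel_interleave.

```python
import os
import tempfile

import tensorflow as tf


def write_records(path, records):
    """Hypothetical helper: write each bytes object as one TFRecord entry."""
    with tf.io.TFRecordWriter(path) as writer:
        for record in records:
            writer.write(record)


# Three tiny hypothetical files standing in for neg/pos1/pos2.tfrecord.
tmp_dir = tempfile.mkdtemp()
filenames = []
for name in ["neg", "pos1", "pos2"]:
    path = os.path.join(tmp_dir, name + ".tfrecord")
    write_records(path, [("%s-%d" % (name, i)).encode() for i in range(3)])
    filenames.append(path)

# One TFRecordDataset per file; with cycle_length equal to the number of
# files, interleave takes block_length records from each file in turn.
dataset = tf.data.Dataset.from_tensor_slices(filenames).interleave(
    lambda f: tf.data.TFRecordDataset(f),
    cycle_length=len(filenames),
    block_length=1,
)

records = [r.numpy().decode() for r in dataset]
print(records)
```

With block_length=1 the files are read round-robin, so batching this dataset with .batch(400) would give roughly one third of each file per batch rather than the 200/100/100 split; interleave mixes the sources but does not by itself enforce a per-batch ratio.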



Source: https://stackoverflow.com/questions/55154836/how-to-combine-multiple-datasets-into-one-dataset
