How can I return the same batch twice from a tensorflow dataset iterator?

匆匆过客 submitted on 2019-12-10 07:54:43

Question


I am converting some legacy code to use the Dataset API. This code uses feed_dict to feed one batch to the train operation (three times, in fact) and then recalculates the losses for display using the same batch. So I need an iterator that returns the exact same batch two (or more) times. Unfortunately, I can't find a way to do this with TensorFlow datasets. Is it possible?
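To pin down the requirement, here is a framework-free sketch of the legacy access pattern. It is not TensorFlow code. `legacy_loop`, `train_step` and `compute_loss` are hypothetical stand-ins for the `sess.run(..., feed_dict=...)` calls, and the "loss" is just a placeholder mean. The point it shows is that every batch is consumed several times in a row: three times for training, then once more for the displayed loss.

```python
from typing import Callable, Iterable, List, Sequence


def legacy_loop(
    batches: Iterable[Sequence[float]],
    train_step: Callable[[Sequence[float]], None],
    compute_loss: Callable[[Sequence[float]], float],
    train_steps_per_batch: int = 3,
) -> List[float]:
    """Feed each batch to train_step several times, then compute its loss once."""
    losses = []
    for batch in batches:
        # With feed_dict, the same Python object can simply be fed again.
        for _ in range(train_steps_per_batch):
            train_step(batch)
        # Recompute the loss on that very same batch for display.
        losses.append(compute_loss(batch))
    return losses


# Hypothetical stand-ins: record which batch training saw, and use the mean as a fake loss.
seen = []
losses = legacy_loop(
    batches=[[1.0, 2.0], [3.0, 4.0]],
    train_step=lambda b: seen.append(tuple(b)),
    compute_loss=lambda b: sum(b) / len(b),
)
print(seen)    # [(1.0, 2.0), (1.0, 2.0), (1.0, 2.0), (3.0, 4.0), (3.0, 4.0), (3.0, 4.0)]
print(losses)  # [1.5, 3.5]
```

A Dataset iterator, by contrast, hands out the next element on every pull, so the pipeline itself has to emit each batch the required number of times.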


Answer 1:


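You can repeat individual elements of a Dataset by combining Dataset.flat_map(), Dataset.from_tensors() and Dataset.repeat(): from_tensors() wraps each element in a one-element dataset, repeat() turns that into NUM_REPEATS copies, and flat_map() concatenates the results in order. For example, to repeat each element twice: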

import tensorflow as tf

NUM_REPEATS = 2
dataset = tf.data.Dataset.range(10)  # ...or the output of `.batch()`, etc.

# Repeat each element of `dataset` NUM_REPEATS times.
dataset = dataset.flat_map(
    lambda x: tf.data.Dataset.from_tensors(x).repeat(NUM_REPEATS))
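As a sanity check on the ordering, here is a plain-Python model of what that `flat_map` pipeline produces. It is an illustration, not TensorFlow code. `repeat_each` is a hypothetical helper, and the two-batch list is made-up data. The second half applies it to the question's case: assuming the training steps and the loss evaluation are separate `sess.run` calls, each of which pulls one element from `get_next()`, three training steps plus one loss evaluation means every batch must be emitted four times.

```python
import itertools
from typing import Iterable, Iterator, TypeVar

T = TypeVar("T")


def repeat_each(elements: Iterable[T], num_repeats: int) -> Iterator[T]:
    """Plain-Python analogue of
    dataset.flat_map(lambda x: Dataset.from_tensors(x).repeat(num_repeats))."""
    for x in elements:
        # from_tensors(x) is a one-element dataset; .repeat(n) yields it n times,
        # and flat_map concatenates those small datasets in order.
        yield from itertools.repeat(x, num_repeats)


print(list(repeat_each(range(5), 2)))  # [0, 0, 1, 1, 2, 2, 3, 3, 4, 4]

# Question's loop: 3 training pulls + 1 loss pull per batch.
TRAIN_STEPS = 3
NUM_REPEATS = TRAIN_STEPS + 1
batches = [[1, 2], [3, 4]]  # hypothetical, already-batched data
it = repeat_each(batches, NUM_REPEATS)

groups = []
for _ in batches:
    # Each group holds what the training steps and the loss evaluation would see.
    groups.append([next(it) for _ in range(NUM_REPEATS)])
print(groups)  # [[[1, 2], [1, 2], [1, 2], [1, 2]], [[3, 4], [3, 4], [3, 4], [3, 4]]]
```

Note that the repetition should be applied after `.batch()`, as the comment in the answer suggests; applying it before batching would instead duplicate individual examples inside each batch.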


Source: https://stackoverflow.com/questions/49358750/how-can-i-return-the-same-batch-twice-from-a-tensorflow-dataset-iterator