I am trying to use the from_generator interface for the Dataset API to inject multiple \"rounds\" of input into a graph.
On my first attempt, I used the repeat() function.
I think the problem stems from using tf.contrib.data.Dataset (which supports reinitialization) with tf.train.batch_join() (which uses TensorFlow queues and queue runners, and hence does not support reinitialization).
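To see the difference, note that a pipeline built entirely from Dataset ops can be restarted by re-running the iterator's initializer, with no queue runners involved. Here's a minimal sketch (assuming the TF 1.x tf.contrib.data API; the range dataset and three-round loop are placeholders for illustration):

import tensorflow as tf

# A trivial Dataset pipeline with a reinitializable iterator.
dataset = tf.contrib.data.Dataset.range(5)
iterator = dataset.make_initializable_iterator()
next_elem = iterator.get_next()

with tf.Session() as sess:
    for round_idx in range(3):          # three "rounds" of the same input
        sess.run(iterator.initializer)  # resets the iterator to the start
        while True:
            try:
                sess.run(next_elem)
            except tf.errors.OutOfRangeError:
                break                   # this round is exhausted

A queue-based tf.train.batch_join() pipeline has no analogous reset: once its queues are closed, they cannot be reopened within the same session.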
I'm not completely clear on what your code is doing, but I think you can implement the entire pipeline as a Dataset. Replace the following fragment of code:
my_iterator = MyIterator(iterations=iterations)

dataset = ds.Dataset.from_generator(my_iterator,
                                    output_types=my_iterator.output_types,
                                    output_shapes=my_iterator.output_shapes)

# dataset = dataset.repeat(count=repetitions)

iterator = dataset.make_initializable_iterator()
next_elem = iterator.get_next()

# Change `constant` to 1 or 2 to see that the batching is more predictable.
ripple_adds = [(tf.stack((next_elem[0], next_elem[1] + constant)),)
               for constant in ripple_add_coefficients]

batch = tf.train.batch_join(ripple_adds, batch_size=batch_size,
                            enqueue_many=False, name="sink_queue")
...with something like the following:
my_iterator = MyIterator(iterations=iterations)

dataset = tf.contrib.data.Dataset.from_generator(my_iterator,
                                                 output_types=my_iterator.output_types,
                                                 output_shapes=my_iterator.output_shapes)

# Map each (x, y) element to a nested Dataset of `num_ripples` stacked pairs,
# then flatten the nested datasets back into a single stream of elements.
def ripple_add_map_func(x, y):
  return (tf.contrib.data.Dataset.range(num_ripples)
          .map(lambda r: tf.stack([x, y + r])))

dataset = dataset.flat_map(ripple_add_map_func).batch(batch_size)

iterator = dataset.make_initializable_iterator()
batch = iterator.get_next()
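For completeness, here's one way to drive this pipeline across multiple rounds, continuing from the snippet above. The MyIterator class below is a hypothetical stand-in (your real class presumably wraps your own data); the only assumptions are that it is callable, yields (x, y) pairs, and exposes the output_types and output_shapes attributes used above:

import tensorflow as tf

# Hypothetical stand-in for MyIterator: a callable that yields scalar
# (x, y) pairs and advertises their types and shapes to from_generator().
class MyIterator(object):
    def __init__(self, iterations):
        self.iterations = iterations
        self.output_types = (tf.int64, tf.int64)
        self.output_shapes = (tf.TensorShape([]), tf.TensorShape([]))

    def __call__(self):
        for i in range(self.iterations):
            yield (i, i)

# Re-running the initializer starts a fresh round of input, which is the
# reinitialization that tf.train.batch_join() could not provide.
with tf.Session() as sess:
    for _ in range(3):  # three rounds, for illustration
        sess.run(iterator.initializer)
        while True:
            try:
                print(sess.run(batch))
            except tf.errors.OutOfRangeError:
                break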