Resetting a TensorFlow graph after OutOfRangeError when using Dataset

借酒劲吻你 2021-01-22 14:38

I am trying to use the from_generator interface of the Dataset API to inject multiple "rounds" of input into a graph.

On my first attempt, I used the repeat() function.

1 Answer
  • 2021-01-22 14:46

    I think the problem stems from using tf.contrib.data.Dataset (which supports reinitialization) with tf.train.batch_join() (which uses TensorFlow queues and queue-runners, and hence does not support reinitialization).
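    The distinction matters because an initializable iterator can simply be re-initialized after the end-of-data signal, whereas queue-runners cannot be restarted once their queues are closed. As a rough plain-Python analogy (all names here are illustrative, not from the question): catching the end-of-data exception and re-creating the iterator plays the role of re-running `iterator.initializer` after `tf.errors.OutOfRangeError`.

    ```python
    def make_source():
        # Stands in for the Dataset: a fresh iterator over the same data.
        return iter([1, 2, 3])

    totals = []
    for round_num in range(2):          # two "rounds" of input
        it = make_source()              # analogous to sess.run(iterator.initializer)
        total = 0
        while True:
            try:
                total += next(it)       # analogous to sess.run(next_elem)
            except StopIteration:       # analogous to tf.errors.OutOfRangeError
                break
        totals.append(total)
    print(totals)  # [6, 6]
    ```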

    I'm not completely clear what your code is doing, but I think you can implement the entire pipeline as a Dataset. Replace the following fragment of code:

    my_iterator = MyIterator(iterations=iterations)
    dataset = ds.Dataset.from_generator(my_iterator,
                                        output_types=my_iterator.output_types,
                                        output_shapes=my_iterator.output_shapes)
    #dataset = dataset.repeat(count=repetitions)
    iterator = dataset.make_initializable_iterator()
    next_elem = iterator.get_next()

    # change the constant to 1 or 2 to see that the batching is more predictable
    ripple_adds = [(tf.stack((next_elem[0], next_elem[1] + constant)),)
                   for constant in ripple_add_coefficients]
    batch = tf.train.batch_join(ripple_adds, batch_size=batch_size,
                                enqueue_many=False, name="sink_queue")
    

    ...with something like the following:

    my_iterator = MyIterator(iterations=iterations)
    dataset = tf.contrib.data.Dataset.from_generator(my_iterator,
                                                     output_types=my_iterator.output_types,
                                                     output_shapes=my_iterator.output_shapes)
    
    def ripple_add_map_func(x, y):
      return (tf.contrib.data.Dataset.range(num_ripples)
              .map(lambda r: tf.stack([x, y + r])))
    
    dataset = dataset.flat_map(ripple_add_map_func).batch(batch_size)
    
    iterator = dataset.make_initializable_iterator()
    batch = iterator.get_next()
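    To see what the flat_map step produces, here is a plain-Python sketch of the same computation (the generator contents and `num_ripples` value are made up for illustration): for each `(x, y)` element, the inner dataset yields `num_ripples` rippled pairs, and flat_map concatenates these small datasets in order.

    ```python
    num_ripples = 3

    def my_generator():
        # Stands in for MyIterator: yields (x, y) pairs.
        yield (0, 10)
        yield (1, 20)

    def ripple_add_map_func(x, y):
        # Mirrors Dataset.range(num_ripples).map(lambda r: tf.stack([x, y + r])).
        return [(x, y + r) for r in range(num_ripples)]

    # flat_map: concatenate the per-element datasets in order.
    flat = [pair
            for (x, y) in my_generator()
            for pair in ripple_add_map_func(x, y)]
    print(flat)  # [(0, 10), (0, 11), (0, 12), (1, 20), (1, 21), (1, 22)]
    ```

    Batching then simply groups consecutive elements of this flattened stream, which is why it is more predictable than the queue-based batch_join version.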
    