How to iterate a dataset several times using TensorFlow's Dataset API?

Submitted by 随声附和 on 2020-01-12 01:48:14

Question


How can I iterate over the values in a dataset several times? (The dataset is created with TensorFlow's Dataset API.)

import tensorflow as tf

dataset = tf.contrib.data.Dataset.range(100)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
sess = tf.Session()
epoch = 10

for i in range(epoch):
   for j in range(100):
      value = sess.run(next_element)
      assert j == value
      print(j)

Error message:

tensorflow.python.framework.errors_impl.OutOfRangeError: End of sequence
 [[Node: IteratorGetNext = IteratorGetNext[output_shapes=[[]], output_types=[DT_INT64], _device="/job:localhost/replica:0/task:0/cpu:0"](OneShotIterator)]]

How to make this work?


Answer 1:


First of all, I advise you to read the Dataset guide, which describes the Dataset API in detail.

Your question is about iterating over the data several times. Here are two solutions for that:

  1. Iterate over all epochs at once; you get no information about where each individual epoch ends:
import tensorflow as tf

epoch   = 10
dataset = tf.data.Dataset.range(100)
dataset = dataset.repeat(epoch)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
sess = tf.Session()

num_batch = 0
j = 0
while True:
    try:
        value = sess.run(next_element)
        assert j == value
        j += 1
        num_batch += 1
        if j > 99: # new epoch
            j = 0
    except tf.errors.OutOfRangeError:
        break

print ("Num Batch: ", num_batch)
  2. The second option tells you when each epoch ends, so you can, for example, check the validation loss:
import tensorflow as tf

epoch = 10
dataset = tf.data.Dataset.range(100)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
sess = tf.Session()

num_batch = 0

for e in range(epoch):
    print ("Epoch: ", e)
    j = 0
    sess.run(iterator.initializer)
    while True:
        try:
            value = sess.run(next_element)
            assert j == value
            j += 1
            num_batch += 1
        except tf.errors.OutOfRangeError:
            break

print ("Num Batch: ", num_batch)
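
The difference between the two options can be sketched with a plain-Python analogy. This is not TensorFlow code, and the function names are made up for illustration: the first option behaves like one long concatenated stream, while the second creates a fresh iterator for every epoch, so the end of each epoch is visible to the caller.

```python
import itertools


def repeat_style(num_elements, epochs):
    # Models dataset.repeat(epochs): one long stream, epoch boundaries are not signalled.
    stream = itertools.chain.from_iterable(
        range(num_elements) for _ in range(epochs)
    )
    return sum(1 for _ in stream)


def reinitialize_style(num_elements, epochs):
    # Models running iterator.initializer at the start of each epoch.
    per_epoch_counts = []
    for _ in range(epochs):
        fresh_iterator = iter(range(num_elements))  # restarted for every epoch
        per_epoch_counts.append(sum(1 for _ in fresh_iterator))
    return per_epoch_counts


print(repeat_style(100, 10))             # total number of elements seen
print(sum(reinitialize_style(100, 10)))  # same total, but counted epoch by epoch
```

Both approaches visit the same 10 * 100 elements; only the second one lets you run extra work, such as validation, between epochs.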



Answer 2:


If your TensorFlow version is 1.3 or later, I recommend the high-level API tf.train.MonitoredTrainingSession. The session created by this API takes care of tf.errors.OutOfRangeError for you, in combination with sess.should_stop(). In most training situations you need to shuffle the data and fetch one batch per step, so I have added both to the following code.

import tensorflow as tf

epoch = 10
dataset = tf.data.Dataset.range(100)
dataset = dataset.shuffle(buffer_size=100) # comment this line if you don't want to shuffle data
dataset = dataset.batch(batch_size=32)     # batch_size=1 if you want to get only one element per step
dataset = dataset.repeat(epoch)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

num_batch = 0
with tf.train.MonitoredTrainingSession() as sess:
    while not sess.should_stop():
        value = sess.run(next_element)
        num_batch += 1
        print("Num Batch: ", num_batch)
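
One detail worth checking in the code above is the order of batch and repeat. Because batch comes before repeat, every epoch ends with a smaller final batch (100 = 32 + 32 + 32 + 4). The following arithmetic sketch estimates how many batches the loop should see in each ordering. It assumes the final partial batch is kept, which is the default, and the helper names are hypothetical.

```python
import math


def batches_batch_then_repeat(num_elements, batch_size, epochs):
    # batch() first: each epoch is split on its own, then the epochs are repeated.
    return math.ceil(num_elements / batch_size) * epochs


def batches_repeat_then_batch(num_elements, batch_size, epochs):
    # repeat() first: batches are cut from one long stream and may span epochs.
    return math.ceil(num_elements * epochs / batch_size)


print(batches_batch_then_repeat(100, 32, 10))
print(batches_repeat_then_batch(100, 32, 10))
```

With the values used in this answer, the first ordering gives 40 batches and the second gives 32, so the final value of num_batch depends on which ordering you choose.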



Answer 3:


Try this

while True:
  try:
    print(sess.run(next_element))
  except tf.errors.OutOfRangeError:
    break

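Whenever the dataset iterator reaches the end of the data, it raises tf.errors.OutOfRangeError. You can catch it with except and then start the dataset from the beginning. Note that restarting requires an iterator that can be re-initialized, such as one created with make_initializable_iterator(); a one-shot iterator cannot be restarted.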
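
This end-of-data behaviour is also why the loop in the question fails in its second epoch. A plain-Python analogy, not TensorFlow code, shows the same effect: iter(range(...)) stands in for the one-shot iterator, next() for sess.run(next_element), and StopIteration for tf.errors.OutOfRangeError. The function name is hypothetical.

```python
def count_completed_epochs(num_elements, epochs):
    one_shot = iter(range(num_elements))  # can be consumed only once
    completed = 0
    try:
        for _ in range(epochs):
            for j in range(num_elements):
                value = next(one_shot)  # raises StopIteration once exhausted
                assert value == j
            completed += 1
    except StopIteration:
        # The iterator ran out before all epochs were finished.
        pass
    return completed


print(count_completed_epochs(100, 10))
```

Only the first epoch completes, because nothing ever resets the iterator. The fixes in the other answers either make the stream longer (repeat) or reset the iterator at the start of each epoch.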



Source: https://stackoverflow.com/questions/47067401/how-to-iterate-a-dataset-several-times-using-tensorflows-dataset-api
