Question
How do I output the values in a dataset several times, i.e. iterate over it for several epochs? (The dataset is created with TensorFlow's Dataset API.)
import tensorflow as tf
dataset = tf.contrib.data.Dataset.range(100)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
sess = tf.Session()
epoch = 10
for i in range(epoch):
    for j in range(100):
        value = sess.run(next_element)
        assert j == value
        print(j)
Error message:
tensorflow.python.framework.errors_impl.OutOfRangeError: End of sequence
[[Node: IteratorGetNext = IteratorGetNext[output_shapes=[[]], output_types=[DT_INT64], _device="/job:localhost/replica:0/task:0/cpu:0"](OneShotIterator)]]
How can I make this work?
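The failure can be reproduced with a plain-Python analogy (this is not TensorFlow code). A one-shot iterator behaves like a Python iterator over `range(100)`: it can be consumed only once. The loop asks for `epoch * 100 = 1000` elements, so the 101st request hits the end of the data. Python signals that with `StopIteration`, where TensorFlow raises `OutOfRangeError`. The helper name `requests_served` is hypothetical.

```python
def requests_served(n_elements, n_requests):
    """Count how many requests a single-use iterator answers before it is exhausted."""
    it = iter(range(n_elements))  # stands in for make_one_shot_iterator()
    served = 0
    for _ in range(n_requests):
        try:
            next(it)  # stands in for sess.run(next_element)
        except StopIteration:  # TensorFlow raises OutOfRangeError at this point
            break
        served += 1
    return served

epoch = 10
print(requests_served(100, epoch * 100))  # 100: the second epoch never gets an element
```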
Answer 1:
First of all, I advise you to read the Dataset guide, which describes the Dataset API in detail.
Your question is about iterating over the data several times. Here are two solutions:
- Iterate over all epochs at once, with no signal at the end of each individual epoch:
import tensorflow as tf
epoch = 10
dataset = tf.data.Dataset.range(100)
dataset = dataset.repeat(epoch)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
sess = tf.Session()
num_batch = 0
j = 0
while True:
    try:
        value = sess.run(next_element)
        assert j == value
        j += 1
        num_batch += 1
        if j > 99:  # new epoch
            j = 0
    except tf.errors.OutOfRangeError:
        break
print("Num Batch: ", num_batch)
- The second option tells you when each epoch ends, so you can, for example, check the validation loss:
import tensorflow as tf
epoch = 10
dataset = tf.data.Dataset.range(100)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
sess = tf.Session()
num_batch = 0
for e in range(epoch):
    print("Epoch: ", e)
    j = 0
    sess.run(iterator.initializer)
    while True:
        try:
            value = sess.run(next_element)
            assert j == value
            j += 1
            num_batch += 1
        except tf.errors.OutOfRangeError:
            break
print("Num Batch: ", num_batch)
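The bookkeeping in the first option (resetting `j` after 99) can be modelled in plain Python, as an analogy rather than TensorFlow code. `repeat(epoch)` behaves like concatenating the passes into one stream, so the epoch index and the position inside it can both be recovered from a single running counter with `divmod`. The helper `walk_repeated` is hypothetical.

```python
import itertools

def walk_repeated(n_elements, epochs):
    """Walk a flat stream of `epochs` concatenated passes over range(n_elements)."""
    # Analogy for Dataset.range(n_elements).repeat(epochs): the passes are joined
    # into one stream, and nothing marks where one epoch ends and the next begins.
    stream = itertools.chain.from_iterable(range(n_elements) for _ in range(epochs))
    num_batch = 0
    epochs_seen = 0
    for value in stream:
        # Recover (epoch index, position inside the epoch) from the running count.
        epoch, j = divmod(num_batch, n_elements)
        assert value == j
        epochs_seen = epoch + 1
        num_batch += 1
    return num_batch, epochs_seen

print(walk_repeated(100, 10))  # (1000, 10)
```

With `divmod` there is no separate `j` to reset, which is one way to replace the `if j > 99` check. The second option avoids this arithmetic altogether, because each `OutOfRangeError` marks the end of an epoch.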
Answer 2:
If your TensorFlow version is 1.3+, I recommend the high-level API tf.train.MonitoredTrainingSession. The session created by this API handles tf.errors.OutOfRangeError automatically, so you can simply loop while sess.should_stop() is False, with no explicit try/except. In most training situations you also need to shuffle the data and fetch one batch per step, so I have added both to the following code.
import tensorflow as tf
epoch = 10
dataset = tf.data.Dataset.range(100)
dataset = dataset.shuffle(buffer_size=100)  # remove this line if you don't want to shuffle the data
dataset = dataset.batch(batch_size=32)  # use batch_size=1 to get one element per step
dataset = dataset.repeat(epoch)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
num_batch = 0
with tf.train.MonitoredTrainingSession() as sess:
    while not sess.should_stop():
        value = sess.run(next_element)
        num_batch += 1
print("Num Batch: ", num_batch)
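The order of the calls matters for the number of steps. A plain-Python model of this pipeline (shuffling only reorders elements, so it is left out) shows that with `batch(32)` before `repeat(10)`, each 100-element epoch yields batches of 32, 32, 32 and 4, so the loop runs 40 steps. With `repeat` first, the 1000 elements are batched as one stream into 31 full batches plus one of 8, and some batches straddle two epochs. `batch_sizes` is a hypothetical helper, and the model assumes `batch()` keeps the final partial batch, which is its default behaviour.

```python
def batch_sizes(n_elements, batch_size, epochs, batch_before_repeat=True):
    """Sizes of the batches a range/batch/repeat pipeline yields (plain-Python model)."""
    def split(total):
        # Keep the final partial batch, as batch() does by default.
        full, rest = divmod(total, batch_size)
        return [batch_size] * full + ([rest] if rest else [])

    if batch_before_repeat:
        # batch() then repeat(): every epoch ends with its own partial batch.
        return split(n_elements) * epochs
    # repeat() then batch(): epochs are joined first, so batches can straddle epochs.
    return split(n_elements * epochs)

sizes = batch_sizes(100, 32, 10)
print(len(sizes), sizes[:4])  # 40 [32, 32, 32, 4]
print(len(batch_sizes(100, 32, 10, batch_before_repeat=False)))  # 32
```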
Answer 3:
Try this (here next_element = iterator.get_next(), as in the question):
while True:
    try:
        print(sess.run(next_element))
    except tf.errors.OutOfRangeError:
        break
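Whenever the dataset iterator reaches the end of the data, it raises tf.errors.OutOfRangeError. You can catch it with except and then start the dataset again from the beginning. Restarting requires an iterator that can be re-initialized (make_initializable_iterator() followed by sess.run(iterator.initializer)), because a one-shot iterator cannot be restarted.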
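Catching the error and restarting amounts to running one fresh pass per epoch. The control flow can be sketched as a plain-Python analogy (not TensorFlow code), where creating a new Python iterator plays the role of `sess.run(iterator.initializer)` and `StopIteration` plays the role of `OutOfRangeError`. The helper `run_epochs` is hypothetical.

```python
def run_epochs(n_elements, epochs):
    """Start a fresh pass for each epoch and stop each pass at end-of-data."""
    counts = []
    for _ in range(epochs):
        it = iter(range(n_elements))  # analogous to sess.run(iterator.initializer)
        seen = 0
        while True:
            try:
                value = next(it)  # analogous to sess.run(next_element)
            except StopIteration:  # TensorFlow raises OutOfRangeError here
                break  # end of this epoch: a natural place to run validation
            assert value == seen
            seen += 1
        counts.append(seen)
    return counts

counts = run_epochs(100, 10)
print(sum(counts), counts[0])  # 1000 100
```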
Source: https://stackoverflow.com/questions/47067401/how-to-iterate-a-dataset-several-times-using-tensorflows-dataset-api