I'm using tf.train.shuffle_batch() to create batches of input images. It includes a min_after_dequeue parameter that makes sure there's a specified number of elements inside t
You are correct that running the RandomShuffleQueue.close() operation will stop the dequeuing threads from blocking when there are fewer than min_after_dequeue elements in the queue.
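To make that rule concrete, here is a small pure-Python toy model. The class `ToyShuffleQueue` and its exceptions are hypothetical and are not TensorFlow's implementation. It assumes that an open queue only allows a dequeue of `n` elements while at least `min_after_dequeue + n` are present, and that a closed queue lets the remaining elements drain until fewer than `n` are left.

```python
import random


class WouldBlock(Exception):
    """Stands in for a dequeue that TensorFlow would block on."""


class QueueExhausted(Exception):
    """Stands in for tf.errors.OutOfRangeError on a closed, drained queue."""


class ToyShuffleQueue:
    """Hypothetical toy model of RandomShuffleQueue's dequeue rules."""

    def __init__(self, min_after_dequeue, seed=0):
        self.min_after_dequeue = min_after_dequeue
        self.items = []
        self.closed = False
        self.rng = random.Random(seed)

    def enqueue_many(self, values):
        if self.closed:
            raise RuntimeError("queue is closed")
        self.items.extend(values)

    def close(self):
        self.closed = True

    def dequeue_many(self, n):
        if not self.closed and len(self.items) < self.min_after_dequeue + n:
            # An open queue keeps min_after_dequeue elements in reserve.
            raise WouldBlock()
        if len(self.items) < n:
            # A closed queue with fewer than n elements left is exhausted.
            raise QueueExhausted()
        # Remove n elements chosen uniformly at random.
        return [self.items.pop(self.rng.randrange(len(self.items)))
                for _ in range(n)]
```

In the real queue, the "would block" case waits for the QueueRunner threads to enqueue more elements instead of raising; the toy raises so the behaviour can be checked without threads.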
The tf.train.shuffle_batch() function creates a tf.train.QueueRunner whose background threads enqueue elements into the shuffle queue. If you start the queue runners as follows, passing a tf.train.Coordinator, you will be able to close the queue cleanly (based on the example here):
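```python
import tensorflow as tf  # TensorFlow 1.x queue-based input API

sess = tf.Session()
coord = tf.train.Coordinator()
# start_queue_runners() returns the threads it started.
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
try:
    while not coord.should_stop():
        sess.run(train_op)  # train_op is your training op, defined elsewhere
except tf.errors.OutOfRangeError:
    pass  # Raised once the input queue is closed and exhausted.
finally:
    # When done, ask the threads to stop.
    coord.request_stop()
# And wait for them to actually do it.
coord.join(threads)
```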
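The same stop/join protocol can be sketched with the standard library alone. In this toy, a `threading.Event` plays the role of the Coordinator and `queue.Queue` plays the role of the input queue; the functions `producer` and `run` are hypothetical and only mirror the pattern, not how tf.train.Coordinator is implemented.

```python
import queue
import threading


def producer(stop_event, q, worker_id):
    """Background 'queue runner': keep enqueuing until asked to stop."""
    i = 0
    while not stop_event.is_set():
        try:
            # Time out so the thread re-checks the stop flag regularly
            # instead of blocking forever on a full queue.
            q.put((worker_id, i), timeout=0.05)
            i += 1
        except queue.Full:
            pass


def run(num_steps=20, num_threads=2):
    stop_event = threading.Event()  # plays the role of the Coordinator
    q = queue.Queue(maxsize=8)
    threads = [threading.Thread(target=producer, args=(stop_event, q, k),
                                daemon=True)
               for k in range(num_threads)]
    for t in threads:
        t.start()
    consumed = []
    try:
        while not stop_event.is_set() and len(consumed) < num_steps:
            consumed.append(q.get())  # the "training step"
    finally:
        stop_event.set()  # like coord.request_stop()
    for t in threads:
        t.join(timeout=1.0)  # like coord.join(threads)
    return consumed, threads
```

The timeout on `put` matters: a producer stuck on a full queue would otherwise never see the stop request, which is the situation the Coordinator avoids by closing (cancelling) the TensorFlow queue when a stop is requested.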
Here's the code that I eventually got to work, although with a bunch of warnings that elements I enqueued were cancelled.
```python
import tensorflow as tf

# `filenames`, `label_list`, `FLAGS` and `num_preprocess_threads` are assumed
# to be defined elsewhere in the author's load_input module.
def inputs():
    lv = tf.constant(label_list)
    label_fifo = tf.FIFOQueue(len(filenames), tf.int32, shapes=[[]])
    # if eval_data:
    #     num_epochs = 1
    # else:
    #     num_epochs = None
    # shuffle=False keeps the filenames in the same order as the labels.
    file_fifo = tf.train.string_input_producer(filenames, shuffle=False,
                                              capacity=len(filenames))
    # Running this op enqueues one label per file.
    label_enqueue = label_fifo.enqueue_many([lv])
    reader = tf.WholeFileReader()
    key, value = reader.read(file_fifo)
    image = tf.image.decode_jpeg(value, channels=3)
    image.set_shape([128, 128, 3])
    label = label_fifo.dequeue()
    images, label_batch = tf.train.shuffle_batch(
        [image, label],
        batch_size=FLAGS.batch_size,
        num_threads=num_preprocess_threads,
        capacity=FLAGS.min_queue_size + 3 * FLAGS.batch_size,
        min_after_dequeue=FLAGS.min_queue_size)
    return label_enqueue, images, label_batch
```
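This approach relies on the filename queue and the label queue being dequeued in the same order. A toy model with `collections.deque` shows the pairing; `read_pairs` is a hypothetical helper, it is single-threaded, and it assumes the filename queue refills itself each epoch (as string_input_producer does with num_epochs=None) while the label queue refills only when label_enqueue is run again.

```python
import random
from collections import deque


def read_pairs(filenames, labels, num_reads, shuffle_files=False, seed=0):
    """Hypothetical toy model: pair each dequeued filename with the next label."""
    rng = random.Random(seed)
    file_q, label_q = deque(), deque()
    pairs = []
    for _ in range(num_reads):
        if not file_q:
            # The filename queue starts a new epoch on its own.
            epoch = list(filenames)
            if shuffle_files:
                rng.shuffle(epoch)
            file_q.extend(epoch)
        if not label_q:
            # Plays the role of sess.run(label_enqueue).
            label_q.extend(labels)
        pairs.append((file_q.popleft(), label_q.popleft()))
    return pairs
```

With `shuffle_files=False` every file keeps its own label across epochs; with `shuffle_files=True` the labels still come out in their original order while the files do not, so pairs can be mismatched. This toy does not model `num_threads > 1` in shuffle_batch.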
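```python
# In the eval file:
import math

import numpy as np
import tensorflow as tf

import load_input  # the author's input module containing inputs()

label_enqueue, images, labels = load_input.inputs()
# Build the model (top_k_op), create `sess` and restore from a checkpoint here.
coord = tf.train.Coordinator()
threads = []
try:
    for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
        threads.extend(qr.create_threads(sess, coord=coord, daemon=True,
                                         start=True))
    # Create the size ops once rather than adding new ops on every step.
    size_ops = [qr.queue.size()
                for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS)]
    num_iter = int(math.ceil(FLAGS.num_examples / float(FLAGS.batch_size)))
    true_count = 0  # Counts the number of correct predictions.
    total_sample_count = num_iter * FLAGS.batch_size
    sess.run(label_enqueue)  # Enqueue the first epoch of labels.
    step = 0
    while step < num_iter and not coord.should_stop():
        end_epoch = False
        if step > 0:
            # Check whether a queue is about to hold too few elements.
            for size in sess.run(size_ops):
                if size - FLAGS.batch_size < FLAGS.min_queue_size:
                    end_epoch = True
        if end_epoch:
            # Enqueue more labels so that we can finish.
            sess.run(label_enqueue)
        # Actually run the step.
        predictions = sess.run([top_k_op])
        true_count += np.sum(predictions)
        step += 1
except Exception as e:  # Report errors to the coordinator.
    coord.request_stop(e)
finally:
    coord.request_stop()
    coord.join(threads, stop_grace_period_secs=10)
```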
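The loop's bookkeeping can be checked with plain arithmetic. The helpers `eval_plan` and `should_top_up` below are hypothetical names that restate the `num_iter` computation and the "top up the labels" condition from the eval loop.

```python
import math


def eval_plan(num_examples, batch_size):
    """Return (num_iter, total_sample_count) as computed in the eval loop."""
    num_iter = int(math.ceil(num_examples / float(batch_size)))
    return num_iter, num_iter * batch_size


def should_top_up(queue_size, batch_size, min_queue_size):
    """True if taking one more batch would leave fewer than min_queue_size."""
    return queue_size - batch_size < min_queue_size
```

For example, `eval_plan(10000, 128)` gives 79 steps and 10112 samples. Assuming `num_examples` equals the number of files, the pipeline therefore has to read past the end of the first epoch (on top of whatever the shuffle queue prefetches), which is one reason the label queue needs topping up before evaluation finishes.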
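There is an optional argument to tf.train.shuffle_batch(), allow_smaller_final_batch:
"allow_smaller_final_batch: (Optional) Boolean. If True, allow the final batch to be smaller if there are insufficient items left in the queue."
This only takes effect once the queue has been closed (for example, when the input producer reaches its num_epochs limit). With the default of False, the leftover elements that cannot fill a complete batch are discarded.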
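A toy model of what the flag changes, assuming the queue has already been closed and ignoring shuffling: `batches` is a hypothetical helper, not part of TensorFlow. Without the flag the remainder is dropped; with it the remainder becomes one short final batch.

```python
def batches(items, batch_size, allow_smaller_final_batch=False):
    """Hypothetical toy model of batching from a closed queue."""
    items = list(items)
    full = len(items) - len(items) % batch_size  # elements in complete batches
    out = [items[start:start + batch_size]
           for start in range(0, full, batch_size)]
    if allow_smaller_final_batch and full < len(items):
        # Emit the leftover elements as one smaller batch.
        out.append(items[full:])
    return out
```

One practical consequence in TensorFlow is that, with the flag set, the batch dimension of the output can no longer be assumed to equal batch_size, so downstream code should not hard-code it.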