TensorFlow shuffle_batch() blocks at end of epoch

醉话见心 2021-01-23 00:57

I'm using tf.train.shuffle_batch() to create batches of input images. It includes a min_after_dequeue parameter that makes sure there's a specified number of elements inside t

3 answers
  • 2021-01-23 01:48

    You are correct that running the RandomShuffleQueue.close() operation will stop the dequeuing threads from blocking when there are fewer than min_after_dequeue elements in the queue.

    The tf.train.shuffle_batch() function creates a tf.train.QueueRunner that performs operations on the queue in a background thread. If you start it as follows, passing a tf.train.Coordinator, you will be able to close the queue cleanly (based on the example here):

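    import tensorflow as tf  # TensorFlow 1.x queue-based input API

    sess = tf.Session()
    coord = tf.train.Coordinator()
    # Keep the returned thread list so that it can be joined below.
    threads = tf.train.start_queue_runners(sess, coord=coord)
    
    while not coord.should_stop():
      sess.run(train_op)
    # When done, ask the threads to stop.
    coord.request_stop()
    # And wait for them to actually do it.
    coord.join(threads)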
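
To make the blocking rule concrete, here is a minimal pure-Python sketch of the behaviour described above. It is a toy model, not TensorFlow code: the class ToyShuffleQueue and its can_dequeue() helper are hypothetical names invented for illustration, and where a real queue would block, the toy raises an error instead.

```python
import random


class ToyShuffleQueue:
    """Toy, non-blocking model of a shuffle queue (not TensorFlow code)."""

    def __init__(self, min_after_dequeue, seed=0):
        self.min_after_dequeue = min_after_dequeue
        self.items = []
        self.closed = False
        self.rng = random.Random(seed)

    def enqueue_many(self, values):
        if self.closed:
            raise RuntimeError("queue is closed")
        self.items.extend(values)

    def close(self):
        self.closed = True

    def can_dequeue(self, n):
        """Return True if dequeuing n items would proceed without blocking."""
        if self.closed:
            # After close(), the min_after_dequeue requirement is dropped.
            return len(self.items) >= n
        return len(self.items) - n >= self.min_after_dequeue

    def dequeue_many(self, n):
        if not self.can_dequeue(n):
            raise RuntimeError("a real queue would block here")
        self.rng.shuffle(self.items)
        batch, self.items = self.items[:n], self.items[n:]
        return batch


queue = ToyShuffleQueue(min_after_dequeue=5)
queue.enqueue_many(range(8))
first = queue.dequeue_many(3)        # 8 - 3 = 5 items remain, so this is allowed
blocked = not queue.can_dequeue(3)   # 5 - 3 = 2 < 5, so an open queue would block
queue.close()
second = queue.dequeue_many(3)       # allowed once the queue is closed
```

The point of the sketch is only the condition in can_dequeue(): while the queue is open, a dequeue must leave at least min_after_dequeue elements behind, and closing the queue lifts that requirement.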
    
  • 2021-01-23 01:52

    Here's the code that I eventually got to work, although with a bunch of warnings that elements I enqueued were cancelled.

    lv = tf.constant(label_list)
    
    label_fifo = tf.FIFOQueue(len(filenames),tf.int32,shapes=[[]])
    # if eval_data:
        # num_epochs = 1
    # else:
        # num_epochs = None
    file_fifo = tf.train.string_input_producer(filenames, shuffle=False, capacity=len(filenames))
    label_enqueue = label_fifo.enqueue_many([lv])
    
    
    reader = tf.WholeFileReader()
    result.key, value = reader.read(file_fifo)
    image = tf.image.decode_jpeg(value, channels=3)
    image.set_shape([128,128,3])
    result.uint8image = image
    result.label = label_fifo.dequeue()
    
    images, label_batch = tf.train.shuffle_batch(
      [result.uint8image, result.label],
      batch_size=FLAGS.batch_size,
      num_threads=num_preprocess_threads,
      capacity=FLAGS.min_queue_size + 3 * FLAGS.batch_size,
      min_after_dequeue=FLAGS.min_queue_size)
    
    # In the eval file:
    import math
    label_enqueue, images, labels = load_input.inputs()
    # Restore from the checkpoint in between.
    coord = tf.train.Coordinator()
    try:
      threads = []
      for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
        threads.extend(qr.create_threads(sess, coord=coord, daemon=True,
                                         start=True))
    
      num_iter = int(math.ceil(FLAGS.num_examples / FLAGS.batch_size))
      true_count = 0  # Counts the number of correct predictions.
      total_sample_count = num_iter * FLAGS.batch_size
    
      sess.run(label_enqueue)
      step = 0
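      while step < num_iter and not coord.should_stop():
        end_epoch = False
        if step > 0:
            for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
                # Check whether the queue holds too few elements for the next batch.
                size = qr._queue.size().eval(session=sess)
                if size - FLAGS.batch_size < FLAGS.min_queue_size:
                    end_epoch = True
        if end_epoch:
            # Enqueue more labels so that the final batches can be dequeued.
            sess.run(label_enqueue)
        # Actually run the step.
        predictions = sess.run([top_k_op])
        step += 1  # Without this increment the loop never terminates.
    finally:
      # Close the try block: stop and join the queue-runner threads.
      coord.request_stop()
      coord.join(threads)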
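
The size check in the loop above can be worked through with plain arithmetic. The following is a rough sketch under simplifying assumptions that are mine, not the answer's: one epoch of examples is available up front, nothing is re-enqueued, and the queue stays open. The function name first_stalling_step and the numbers are hypothetical.

```python
import math


def first_stalling_step(num_examples, batch_size, min_queue_size):
    """Return the 0-based step at which a dequeue would stall, or None.

    Assumes one epoch of examples, no refill, and an open queue.
    """
    num_iter = math.ceil(num_examples / batch_size)
    remaining = num_examples
    for step in range(num_iter):
        # Same test as in the eval loop: would the dequeue leave too few elements?
        if remaining - batch_size < min_queue_size:
            return step
        remaining -= batch_size
    return None


stall = first_stalling_step(num_examples=1000, batch_size=128, min_queue_size=400)
no_stall = first_stalling_step(num_examples=1024, batch_size=128, min_queue_size=0)
```

With these illustrative numbers, the first call finds a stall partway through the epoch, because the minimum queue size is reached well before the examples run out. The second call never stalls, since the example count divides evenly into batches and no minimum is enforced.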
    
  • 2021-01-23 02:00

    tf.train.shuffle_batch() has an optional argument, allow_smaller_final_batch:

    "allow_smaller_final_batch: (Optional) Boolean. If True, allow the final batch to be smaller if there are insufficient items left in the queue."
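
As a quick illustration of what the flag changes, here is a pure-Python sketch of the batch sizes one epoch would produce. It models only the arithmetic, not TensorFlow itself; the helper name batch_sizes is hypothetical, and the sketch assumes the queue is closed once the epoch's examples run out, so that a short final batch can be emitted.

```python
def batch_sizes(num_examples, batch_size, allow_smaller_final_batch):
    """Return the sizes of the batches produced from one epoch of examples."""
    full, leftover = divmod(num_examples, batch_size)
    sizes = [batch_size] * full
    if leftover and allow_smaller_final_batch:
        # The leftover examples form one short final batch.
        sizes.append(leftover)
    return sizes


strict = batch_sizes(10, 4, allow_smaller_final_batch=False)
relaxed = batch_sizes(10, 4, allow_smaller_final_batch=True)
```

With 10 examples and a batch size of 4, the strict setting yields two full batches and leaves two examples undelivered, while the relaxed setting adds a third batch containing those two.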
