TensorFlow: does tf.train.batch automatically load the next batch when the batch has finished training?

别跟我提以往 2020-12-25 14:30

For instance, after I have created my operations, fed the batch data through the operations and run them, does tf.train.batch automatically feed in another batch of data to the session?

2 Answers
  •  一生所求
    2020-12-25 15:30

    ... does tf.train.batch automatically feed in another batch of data to the session?

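    No. Nothing is fed to your operations automatically. You must call sess.run(...) again to load a new batch: every call that evaluates the batch tensors dequeues exactly one new batch. (The queue-runner threads do fill the input queue in the background, but a batch is only dequeued when you run it.)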
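    The "you must call sess.run(...) again" behaviour can be modeled without TensorFlow. The sketch below is a hypothetical plain-Python analogy, not TensorFlow's implementation: start_prefetcher plays the role of the queue-runner threads (filling a bounded queue in the background) and run_batch plays the role of one sess.run(batch) call. Both function names and the _CLOSED marker are made up for illustration.

```python
import queue
import threading
from typing import List, Optional

# Hypothetical marker standing in for "the input queue has been closed".
_CLOSED = object()


def start_prefetcher(samples: List[int], capacity: int) -> queue.Queue:
    """Toy stand-in for a queue runner: a daemon thread enqueues one epoch
    of samples into a bounded FIFO queue, then marks the queue as closed."""
    q: queue.Queue = queue.Queue(maxsize=capacity)

    def producer() -> None:
        for sample in samples:
            q.put(sample)  # blocks while the queue is full
        q.put(_CLOSED)

    threading.Thread(target=producer, daemon=True).start()
    return q


def run_batch(q: queue.Queue, batch_size: int) -> Optional[List[int]]:
    """Toy stand-in for one sess.run(batch): nothing is dequeued until this
    is called. Returns None once the queue is closed and fewer than
    batch_size samples remain (those leftover samples are dropped)."""
    batch: List[int] = []
    while len(batch) < batch_size:
        item = q.get()
        if item is _CLOSED:
            q.put(item)  # leave the marker in place for later calls
            return None
        batch.append(item)
    return batch


q = start_prefetcher(list(range(10)), capacity=4)
print(run_batch(q, 3))  # [0, 1, 2]
print(run_batch(q, 3))  # [3, 4, 5]
print(run_batch(q, 3))  # [6, 7, 8]
print(run_batch(q, 3))  # None: sample 9 alone cannot fill a batch
```

    The background thread keeps the queue topped up, but each batch only leaves the queue when the consumer explicitly asks for it, which is the same division of labour as queue runners versus sess.run(...).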

    Does this mean even without a loop, the next batch could be automatically fed?

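    No. tf.train.batch(..) dequeues exactly batch_size samples each time you call sess.run(...) on it; nothing is loaded for you in between. If you have, for example, 100 images and batch_size=30, you get 3 full batches of 30, i.e. you can call sess.run(batch) three times before fewer than 30 samples are left. With num_epochs=1 the input queue is then closed and the remaining 100 - 3*30 = 10 samples are discarded, so they never take part in training (the next sess.run(...) raises tf.errors.OutOfRangeError). If you do not want to lose them, use tf.train.batch(..., allow_smaller_final_batch=True): you then get 3 batches of 30 samples and 1 batch of 10 samples before the queue is exhausted. With num_epochs=None the input producer instead cycles forever, so the 10 leftover samples are simply combined with 20 samples from the next epoch in the fourth batch.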
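    The batch arithmetic above can be checked with a small plain-Python model. simulate_batches is a hypothetical helper, not a TensorFlow API: each list it yields stands for one sess.run(batch) call, num_epochs=None means "cycle forever", and it deliberately ignores shuffling and multi-threaded enqueueing.

```python
from itertools import islice
from typing import Iterator, List, Optional


def simulate_batches(samples: List[int], batch_size: int,
                     num_epochs: Optional[int],
                     allow_smaller_final_batch: bool = False) -> Iterator[List[int]]:
    """Plain-Python model of fixed-size batching over num_epochs passes."""
    if not samples:
        return

    def stream() -> Iterator[int]:
        # Yield the dataset num_epochs times (forever if num_epochs is None).
        epoch = 0
        while num_epochs is None or epoch < num_epochs:
            yield from samples
            epoch += 1

    it = stream()
    while True:
        batch = list(islice(it, batch_size))
        if len(batch) == batch_size:
            yield batch
            continue
        # Input exhausted ("queue closed"): a partial batch is only
        # returned when allow_smaller_final_batch=True, otherwise dropped.
        if batch and allow_smaller_final_batch:
            yield batch
        return


data = list(range(100))
print([len(b) for b in simulate_batches(data, 30, num_epochs=1)])
# [30, 30, 30]
print([len(b) for b in simulate_batches(data, 30, num_epochs=1,
                                        allow_smaller_final_batch=True)])
# [30, 30, 30, 10]
fourth = list(islice(simulate_batches(data, 30, num_epochs=None), 4))[3]
print(fourth == list(range(90, 100)) + list(range(20)))
# True: the last 10 samples of epoch 1 plus 20 samples of epoch 2
```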

    Let me also elaborate with a code sample:

    import tensorflow as tf  # TensorFlow 1.x queue-based input pipeline

    filenames = ["train.tfrecords"]  # hypothetical: path(s) to your TFRecord file(s)

    queue = tf.train.string_input_producer(filenames,
            num_epochs=1)  # only iterate through all samples in the dataset once

    reader = tf.TFRecordReader()  # or any reader you need
    _, example = reader.read(queue)

    # your_conversion_fn is a placeholder for your own parsing/decoding code;
    # tf.train.batch needs image and label to have fully defined shapes
    image, label = your_conversion_fn(example)

    # each sess.run(...) on images/labels now dequeues 100 image-label pairs
    # most tf ops are tuned to work on batches;
    # this is faster and also gives better results for e.g. gradient calculation
    images, labels = tf.train.batch([image, label], batch_size=100)

    with tf.Session() as sess:
        # "boilerplate" code; num_epochs creates a local variable,
        # so the local variables initializer is required
        sess.run([
            tf.local_variables_initializer(),
            tf.global_variables_initializer(),
        ])
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        try:
            # with num_epochs=None the input never runs out,
            # so this loop runs forever unless you break out of it
            while not coord.should_stop():
                # dequeue one batch from the input queue and "fetch" the
                # results of the computation graph into raw_images and raw_labels
                raw_images, raw_labels = sess.run([images, labels])
        except tf.errors.OutOfRangeError:
            # raised by sess.run(...) once all num_epochs have been read
            pass
        finally:
            coord.request_stop()
            coord.join(threads)
    
