For instance, after I have created my operations, fed the batch data through the operation and run the operation, does tf.train.batch automatically feed in another batch of data to the session?
No. Nothing happens automatically. You must call sess.run(...) again to load a new batch.
Does this mean even without a loop, the next batch could be automatically fed?
No. By default, each sess.run(batch) call on a tf.train.batch(...) op returns exactly batch_size samples. If you have, for example, 100 images and batch_size=30, then with num_epochs=1 you can call sess.run(batch) three times and get three full batches of 30 before the input queue runs out. The remaining 100-3*30=10 samples are then dropped. (Without an epoch limit the input queue simply starts over from the beginning, and the leftover samples are combined with samples from the next pass.) If you do not want to miss them, use tf.train.batch(..., allow_smaller_final_batch=True): you then get 3 batches of 30 samples and 1 batch of 10 samples before the input queue is exhausted.
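The arithmetic above can be checked with a small plain-Python sketch. This is not TensorFlow code: simulate_batches is a hypothetical helper that only mimics a single pass over the data, where each yielded batch stands in for one sess.run(batch) call.

```python
def simulate_batches(samples, batch_size, allow_smaller_final_batch=False):
    """Yield batches from one pass over `samples`.

    Hypothetical helper for illustration only, not a TensorFlow API.
    """
    for start in range(0, len(samples), batch_size):
        chunk = samples[start:start + batch_size]
        # A short final chunk is only emitted when explicitly allowed.
        if len(chunk) == batch_size or allow_smaller_final_batch:
            yield chunk


samples = list(range(100))  # stand-in for 100 images

dropped = [len(b) for b in simulate_batches(samples, 30)]
kept = [len(b) for b in simulate_batches(samples, 30, allow_smaller_final_batch=True)]

print(dropped)  # [30, 30, 30]
print(kept)     # [30, 30, 30, 10]
```

Just as with the generator here, nothing is produced until you ask for the next batch.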
Let me also elaborate with a code sample:
import tensorflow as tf  # TensorFlow 1.x queue-based input API

# filenames: your list of input file paths
queue = tf.train.string_input_producer(
    filenames,
    num_epochs=1)  # only iterate through all samples in the dataset once
reader = tf.TFRecordReader()  # or any reader you need
_, example = reader.read(queue)
# your_conversion_fn: your own parsing/decoding function
image, label = your_conversion_fn(example)

# each sess.run(batch) now loads 100 image-label pairs
# most tf ops are tuned to work on batches;
# this is faster and usually also gives better results, e.g. for gradient calculation
batch = tf.train.batch([image, label], batch_size=100)

with tf.Session() as sess:
    # "boilerplate" code
    # (num_epochs creates a local variable, so local variables must be initialized too)
    sess.run([
        tf.local_variables_initializer(),
        tf.global_variables_initializer(),
    ])
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    try:
        # loop until a stop is requested;
        # with num_epochs=None the input never runs out and this runs forever
        while not coord.should_stop():
            # will start reading, working data from input queue
            # and "fetch" the results of the computation graph
            # into raw_images and raw_labels
            raw_images, raw_labels = sess.run(batch)
    except tf.errors.OutOfRangeError:
        # raised when there are no more samples to read (epoch limit reached)
        pass
    finally:
        coord.request_stop()
        coord.join(threads)