How to switch between training and validation dataset with tf.MonitoredTrainingSession?

Submitted by 假如想象 on 2019-12-20 21:42:10

Question


I want to use a feedable iterator with the TensorFlow Dataset API so that I can switch to the validation data after some training steps. However, as soon as I switch to the validation data, the whole session ends.

The following code demonstrates what I want to do:

import tensorflow as tf


graph = tf.Graph()
with graph.as_default():
    training_ds = tf.data.Dataset.range(32).batch(4)
    validation_ds = tf.data.Dataset.range(8).batch(4)

    handle = tf.placeholder(tf.string, shape=[])
    iterator = tf.data.Iterator.from_string_handle(
        handle, training_ds.output_types, training_ds.output_shapes)
    next_element = iterator.get_next()

    training_iterator = training_ds.make_initializable_iterator()
    validation_iterator = validation_ds.make_initializable_iterator()


with graph.as_default():

    with tf.train.MonitoredTrainingSession() as sess:
        training_handle = sess.run(training_iterator.string_handle())
        validation_handle = sess.run(validation_iterator.string_handle())
        sess.run(training_iterator.initializer)
        count_training = 0
        while not sess.should_stop():
            x = sess.run(next_element, feed_dict={handle: training_handle})
            count_training += 1
            print('{} [training] {}'.format(count_training, x.shape))
            # print(x)

            # we do periodic validation
            if count_training % 4 == 0:
                sess.run(validation_iterator.initializer)
                count_validation = 0
                while not sess.should_stop():
                    y = sess.run(next_element, feed_dict={handle: validation_handle})
                    count_validation += 1
                    print('  {} [validation] {}'.format(count_validation, y.shape))
                    # print(y)

The training data has 32 elements batched by 4, which gives 8 batches. I run validation every 4 steps on 8 validation elements (2 batches), so I expect:

# 1 [training]
# 2 [training]
# 3 [training]
# 4 [training]
#      1 [validation]
#      2 [validation]
# 5 [training]
# 6 [training]
# 7 [training]
# 8 [training]
#      1 [validation]
#      2 [validation]
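The expected schedule follows from the batch arithmetic: 32 / 4 = 8 training batches, 8 / 4 = 2 validation batches, and a validation pass after training steps 4 and 8. A small pure-Python sketch that generates the same schedule without TensorFlow (`expected_log` is a hypothetical helper, not part of any library):

```python
def expected_log(n_train=32, n_val=8, batch_size=4, every=4):
    """Return the log lines the question expects (hypothetical helper, no TensorFlow)."""
    # Ceiling division: Dataset.batch keeps a final partial batch by default,
    # so a partial batch would still count as one step.
    train_batches = -(-n_train // batch_size)  # 8 with the defaults
    val_batches = -(-n_val // batch_size)      # 2 with the defaults
    lines = []
    for step in range(1, train_batches + 1):
        lines.append('{} [training]'.format(step))
        if step % every == 0:  # periodic validation pass
            for v in range(1, val_batches + 1):
                lines.append('  {} [validation]'.format(v))
    return lines


for line in expected_log():
    print(line)
```

With the defaults this produces 12 lines: eight training steps, plus two validation batches after steps 4 and 8.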

But it stops as soon as the first validation pass is done:

# 1 [training]
# 2 [training]
# 3 [training]
# 4 [training]
#      1 [validation]
#      2 [validation]

So how can I use this feedable iterator inside tf.MonitoredTrainingSession?


Answer 1:


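I would suggest catching the tf.errors.OutOfRangeError that is raised at the end of the validation dataset (the "Processing multiple epochs" section of the official tf.data guide describes another solution based on Dataset.repeat). That error is what ends your session: when it is used as a context manager, MonitoredTrainingSession suppresses OutOfRangeError, treating it as a sign that all input has been processed, so the whole with block exits quietly. Catching the error inside the validation loop hands control back to the training loop instead: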

# Replaces the outer while loop inside `with tf.train.MonitoredTrainingSession() as sess:`
while not sess.should_stop():
    x = sess.run(next_element, feed_dict={handle: training_handle})
    count_training += 1
    print('{} [training] {}'.format(count_training, x.shape))

    # we do periodic validation
    if count_training % 4 == 0:
        sess.run(validation_iterator.initializer)
        count_validation = 0
        while True:
            try:
                y = sess.run(next_element, feed_dict={handle: validation_handle})
                count_validation += 1
                print('  {} [validation] {}'.format(count_validation, y.shape))
            except tf.errors.OutOfRangeError:
                # validation data exhausted: leave this loop and resume training
                break

This piece of code prints:

1 [training] (4,)
2 [training] (4,)
3 [training] (4,)
4 [training] (4,)
  1 [validation] (4,)
  2 [validation] (4,)
5 [training] (4,)
6 [training] (4,)
7 [training] (4,)
8 [training] (4,)
  1 [validation] (4,)
  2 [validation] (4,)
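The early exit, and both fixes, can be reproduced without TensorFlow. The sketch below is a pure-Python model under stated assumptions: `OutOfRangeError`, `ToyIterator` and `ToyMonitoredSession` are hypothetical stand-ins, not TensorFlow classes, and the toy session models only one behavior, suppressing an `OutOfRangeError` that escapes its `with` block. The `'repeat'` mode is one possible reading of the repeat-based alternative: the validation data cycles forever (like `Dataset.repeat()`) and each pass reads a fixed, known number of batches, so the error is never raised during validation.

```python
class OutOfRangeError(Exception):
    """Hypothetical stand-in for tf.errors.OutOfRangeError."""


class ToyIterator:
    """Hypothetical model of an initializable iterator over n_batches batches."""

    def __init__(self, n_batches, repeat=False):
        self._n = n_batches
        self._repeat = repeat  # True models Dataset.repeat(): cycle forever
        self._pos = 0

    def initialize(self):
        self._pos = 0  # like sess.run(iterator.initializer)

    def get_next(self):
        if self._pos >= self._n:
            if not self._repeat:
                raise OutOfRangeError('end of sequence')
            self._pos = 0  # start the next epoch
        batch = self._pos
        self._pos += 1
        return batch


class ToyMonitoredSession:
    """Hypothetical model of one behavior only: an OutOfRangeError escaping
    the `with` block is suppressed, so the block simply ends."""

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        return exc_type is not None and issubclass(exc_type, OutOfRangeError)


def run(mode):
    """mode is 'uncaught' (question), 'catch' (answer) or 'repeat'."""
    log = []
    train = ToyIterator(8)  # 32 elements / batch 4
    val = ToyIterator(2, repeat=(mode == 'repeat'))  # 8 elements / batch 4
    with ToyMonitoredSession():
        train.initialize()
        step = 0
        while True:  # stands in for `while not sess.should_stop()`
            train.get_next()  # raises OutOfRangeError after 8 batches
            step += 1
            log.append('{} [training]'.format(step))
            if step % 4 != 0:
                continue
            if mode == 'repeat':
                for v in range(1, 3):  # fixed, known number of validation batches
                    val.get_next()
                    log.append('  {} [validation]'.format(v))
                continue
            val.initialize()
            v = 0
            while True:
                if mode == 'catch':
                    try:
                        val.get_next()
                    except OutOfRangeError:
                        break  # validation data exhausted: resume training
                else:
                    val.get_next()  # error escapes and ends the whole `with` block
                v += 1
                log.append('  {} [validation]'.format(v))
    return log


for mode in ('uncaught', 'catch', 'repeat'):
    print(mode, len(run(mode)))
```

`run('uncaught')` stops after the first validation pass (6 log lines), matching the question's output, while `run('catch')` and `run('repeat')` both produce the full 12-line schedule. The repeat variant avoids exception handling but hard-codes the number of validation batches, so it silently reads the wrong amount of data if the validation set size changes.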


Source: https://stackoverflow.com/questions/49095849/how-to-switch-between-training-and-validation-dataset-with-tf-monitoredtrainings
