Epoch counter with TensorFlow Dataset API

悲&欢浪女 2020-12-19 01:27

I'm changing my TensorFlow code from the old queue interface to the new Dataset API. In my old code I kept track of the epoch count by incrementing a tf.Variable.

4 Answers
  • 2020-12-19 02:16

    To add to @mrry's great answer: if you want to stay within the tf.data pipeline and also track the iteration within each epoch, you can try my solution below. If you have a non-unit batch size, I guess you would have to add the line data = data.batch(bs).

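    import itertools
    import tensorflow as tf  # TF 1.x API (tf.Session, make_one_shot_iterator)
    
    def step_counter():
        # Yields 0, 1, 2, ...; restarts whenever the inner dataset is re-iterated.
        for i in itertools.count():
            yield i
    
    num_examples = 3
    num_epochs = 2
    num_iters = num_examples * num_epochs
    
    features = tf.data.Dataset.range(num_examples)
    labels = tf.data.Dataset.range(num_examples)
    data = tf.data.Dataset.zip((features, labels))
    data = data.shuffle(num_examples)
    
    # Attach a per-epoch step index to every (x, y) pair.
    step = tf.data.Dataset.from_generator(step_counter, tf.int32)
    data = tf.data.Dataset.zip((data, step))
    
    # For each epoch index i, iterate `data` once more and tag every element with i.
    epoch = tf.data.Dataset.range(num_epochs)
    data = epoch.flat_map(
        lambda i: tf.data.Dataset.zip(
            (data, tf.data.Dataset.from_tensors(i).repeat())))
    
    # Note: the flat_map above already covers num_epochs epochs, so this repeats
    # everything again; only the first num_iters elements are consumed below.
    data = data.repeat(num_epochs)
    it = data.make_one_shot_iterator()
    example = it.get_next()
    
    with tf.Session() as sess:
        for _ in range(num_iters):
            ((x, y), st), ep = sess.run(example)
            print(f'step {st} \t epoch {ep} \t x {x} \t y {y}')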
    

    Prints:

    step 0   epoch 0     x 2     y 2
    step 1   epoch 0     x 0     y 0
    step 2   epoch 0     x 1     y 1
    step 0   epoch 1     x 2     y 2
    step 1   epoch 1     x 0     y 0
    step 2   epoch 1     x 1     y 1
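    To see why the step counter restarts at 0 in every epoch, here is a plain-Python model of the pipeline above (no TensorFlow; random.shuffle stands in for Dataset.shuffle, and model_pipeline is a hypothetical helper name). flat_map calls its function once per epoch index, and each call iterates the inner zipped dataset afresh, so the generator behind from_generator starts over each time.

```python
import itertools
import random


def step_counter():
    # Mirrors the answer's generator: 0, 1, 2, ... forever.
    yield from itertools.count()


def model_pipeline(num_examples, num_epochs, seed=0):
    """Plain-Python model of epoch.flat_map(lambda i: zip((data, repeat(i))))."""
    rng = random.Random(seed)
    for epoch in range(num_epochs):                    # Dataset.range(num_epochs)
        pairs = [(x, x) for x in range(num_examples)]  # zip((features, labels))
        rng.shuffle(pairs)                             # stand-in for data.shuffle(...)
        # A fresh step_counter() per epoch: the inner dataset is re-iterated,
        # so the step index starts again at 0.
        for pair, step in zip(pairs, step_counter()):
            yield (pair, step), epoch


for ((x, y), st), ep in model_pipeline(num_examples=3, num_epochs=2):
    print(f'step {st} \t epoch {ep} \t x {x} \t y {y}')
```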
    
  • 2020-12-19 02:21

    I extended numerica's example code to batches and replaced the itertools-based step counter with tf.data.Dataset.range(num_examples):

    import tensorflow as tf  # TF 1.x API (tf.Session, make_one_shot_iterator)
    
    num_examples = 5
    num_epochs = 4
    batch_size = 2
    num_iters = num_examples * num_epochs // batch_size
    
    features = tf.data.Dataset.range(num_examples)
    labels = tf.data.Dataset.range(num_examples)
    
    data = tf.data.Dataset.zip((features, labels))
    data = data.shuffle(num_examples)
    
    # Per epoch i: zip each (x, y) pair with the epoch index i and a step index
    # from Dataset.range(num_examples), which restarts at 0 for every epoch.
    epoch = tf.data.Dataset.range(num_epochs)
    data = epoch.flat_map(
        lambda i: tf.data.Dataset.zip((
            data,
            tf.data.Dataset.from_tensors(i).repeat(),
            tf.data.Dataset.range(num_examples)
        ))
    )
    
    # Flatten ((x, y), epoch, step) into (x, y, epoch, step)
    data = data.map(lambda samples, *cnts: samples + cnts)
    data = data.batch(batch_size=batch_size)
    
    it = data.make_one_shot_iterator()
    x, y, ep, st = it.get_next()
    
    with tf.Session() as sess:
        for _ in range(num_iters):
            x_, y_, ep_, st_ = sess.run([x, y, ep, st])
            print(f'step {st_}\t epoch {ep_} \t x {x_} \t y {y_}')
    
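    With num_examples = 5 and batch_size = 2, an epoch is not a multiple of the batch size, so some batches mix two epochs and ep_ then holds two different values. The plain-Python model below (no TensorFlow; random.shuffle stands in for Dataset.shuffle, and flattened_elements and batched are hypothetical helper names) reproduces the element order of the flattened, batched dataset to make this visible.

```python
import random


def flattened_elements(num_examples, num_epochs, seed=0):
    """Model of the flat_map + map steps: one (x, y, epoch, step) tuple per example."""
    rng = random.Random(seed)
    for epoch in range(num_epochs):
        xs = list(range(num_examples))
        rng.shuffle(xs)  # stand-in for data.shuffle(num_examples)
        # The step comes from Dataset.range(num_examples), so it restarts each epoch.
        for step, x in enumerate(xs):
            yield x, x, epoch, step


def batched(elements, batch_size):
    """Model of Dataset.batch: consecutive groups; a smaller final batch is kept."""
    batch = []
    for element in elements:
        batch.append(element)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:
        yield batch


num_examples, num_epochs, batch_size = 5, 4, 2
batches = list(batched(flattened_elements(num_examples, num_epochs), batch_size))
for b in batches:
    print('step', [e[3] for e in b], '\t epoch', [e[2] for e in b])
```

    The third batch holds the last example of epoch 0 and the first of epoch 1, which is why per-epoch bookkeeping has to read the epoch value per element rather than per batch.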
  • 2020-12-19 02:23

    TL;DR: Replace the definition of epoch_counter with the following:

    epoch_counter = tf.get_variable("epoch_counter", initializer=0.0,
                                    trainable=False, use_resource=True)
    

    There are some limitations around using TensorFlow variables inside tf.data.Dataset transformations. The principal limitation is that all variables must be "resource variables" and not the older "reference variables"; unfortunately tf.Variable still creates "reference variables" for backwards compatibility reasons.

    Generally speaking, I wouldn't recommend using variables in a tf.data pipeline if it's possible to avoid it. For example, you might be able to use Dataset.range() to define an epoch counter, and then do something like:

    # NUM_EPOCHS, pre_processing_func and data are placeholders for your own values.
    epoch_counter = tf.data.Dataset.range(NUM_EPOCHS)
    dataset = epoch_counter.flat_map(lambda i: tf.data.Dataset.zip(
        (pre_processing_func(data), tf.data.Dataset.from_tensors(i).repeat())))
    

    The above snippet attaches an epoch counter to every value as a second component.
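    The pattern relies on Dataset.zip stopping as soon as its shortest input is exhausted: from_tensors(i).repeat() is infinite, so each epoch's sub-dataset ends when the data does. Python's built-in zip and itertools.repeat behave the same way, which gives a small TensorFlow-free model of the snippet (epoch_tagged is a hypothetical name, and the list stands in for pre_processing_func(data)).

```python
import itertools


def epoch_tagged(data, num_epochs):
    """Model of range(num_epochs).flat_map(lambda i: zip((data, from_tensors(i).repeat())))."""
    for i in range(num_epochs):
        # itertools.repeat(i) is infinite, like from_tensors(i).repeat();
        # zip stops when `data` runs out, so each epoch yields len(data) items.
        # `data` must be re-iterable (a list here), as a tf.data.Dataset is.
        yield from zip(data, itertools.repeat(i))


processed = ['a', 'b', 'c']  # hypothetical stand-in for pre_processing_func(data)
print(list(epoch_tagged(processed, num_epochs=2)))
# [('a', 0), ('b', 0), ('c', 0), ('a', 1), ('b', 1), ('c', 1)]
```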

  • 2020-12-19 02:29

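    The line data = data.repeat(num_epochs) repeats a dataset that the flat_map has already expanded to num_epochs epochs (epoch counter included), so the iterator yields num_epochs times as many elements as the loop consumes. This is easy to see by replacing for _ in range(num_iters): with for _ in range(num_iters+1):; the extra element starts over at step 0, epoch 0.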
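    A quick TensorFlow-free check of the element count, tracking only (step, epoch), which shuffling does not change (epoch_expanded is a hypothetical helper name): the flat_map already yields num_examples * num_epochs elements, and the extra repeat multiplies that again.

```python
def epoch_expanded(num_examples, num_epochs):
    """(step, epoch) pairs produced by the flat_map over Dataset.range(num_epochs)."""
    return [(step, epoch)
            for epoch in range(num_epochs)
            for step in range(num_examples)]


num_examples, num_epochs = 3, 2
num_iters = num_examples * num_epochs

once = epoch_expanded(num_examples, num_epochs)
twice = once * num_epochs  # models the extra data.repeat(num_epochs)

print(len(once), len(twice))  # 6 12
print(twice[num_iters])       # (0, 0): step 0, epoch 0 again
```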
