Dataset API, Iterators and tf.contrib.data.rejection_resample

Asked 2021-01-03 06:28

[Edit #1 after @mrry's comment] I am using the (great & amazing) Dataset API along with tf.contrib.data.rejection_resample to set a specific distribution

2 Answers
  • 2021-01-03 07:01

    Below is a simple example demonstrating the usage of sample_from_datasets (thanks @Agade for the idea).

    import math
    import tensorflow as tf
    import numpy as np
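    # Note: this example assumes eager execution (TF 2.x, or TF 1.x with
    # tf.enable_eager_execution() called at startup), since print_dataset
    # calls .numpy() on dataset elements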
    
    
    def print_dataset(name, dataset):
        elems = np.array([v.numpy() for v in dataset])
        print("Dataset {} contains {} elements :".format(name, len(elems)))
        print(elems)
    
    
    def combine_datasets_balanced(dataset_smaller, size_smaller, dataset_bigger, size_bigger, batch_size):
        ds_smaller_repeated = dataset_smaller.repeat(count=int(math.ceil(size_bigger / size_smaller)))
        # we repeat the smaller dataset so that the 2 datasets are about the same size
        balanced_dataset = tf.data.experimental.sample_from_datasets([ds_smaller_repeated, dataset_bigger], weights=[0.5, 0.5])
        # each element of the resulting dataset is drawn at random (without replacement)
        # from the even dataset with probability 0.5 or from the odd dataset with probability 0.5
        balanced_dataset = balanced_dataset.take(2 * size_bigger).batch(batch_size)
        return balanced_dataset
    
    
    N, M = 3, 10
    even = tf.data.Dataset.range(0, 2 * N, 2).repeat(count=int(math.ceil(M / N)))
    odd = tf.data.Dataset.range(1, 2 * M, 2)
    even_odd = combine_datasets_balanced(even, N, odd, M, 2)
    
    print_dataset("even", even)
    print_dataset("odd", odd)
    print_dataset("even_odd", even_odd)
    
    Output:
    
    Dataset even contains 12 elements :  # 12 = 4 x N  (because of .repeat)
    [0 2 4 0 2 4 0 2 4 0 2 4]
    Dataset odd contains 10 elements :
    [ 1  3  5  7  9 11 13 15 17 19]
    Dataset even_odd contains 10 elements :  # 10 = 2 x M / 2  (2xM because of .take(2 * M) and /2 because of .batch(2))
    [[ 0  2]
     [ 1  4]
     [ 0  2]
     [ 3  4]
     [ 0  2]
     [ 4  0]
     [ 5  2]
     [ 7  4]
     [ 0  9]
     [ 2 11]] 
    
  • 2021-01-03 07:03

    Following @mrry's response, I was able to come up with a solution for using the Dataset API with tf.contrib.data.rejection_resample (using TF 1.3).

    The goal

    Given a feature/label dataset with some distribution, have the input pipeline reshape that distribution into a specific target distribution.

    Numerical example

    Let's assume we are building a network to classify features into one of 10 classes, and that we only have 100 features with some arbitrary distribution of labels:
    30 features labeled as class 1, 5 features labeled as class 2, and so forth. During training we do not want to prefer class 1 over class 2, so we would like each mini-batch to hold a uniform distribution over all classes.
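    To make this concrete, here is a small sketch of the two distributions involved. Only the counts 30 and 5 come from the example above; the remaining class counts are hypothetical:

        import numpy as np

        # Hypothetical label counts for the 100-feature dataset described above
        # (only 30 and 5 are given in the text; the rest are made up)
        label_counts = np.array([30, 5, 8, 12, 7, 6, 9, 10, 4, 9])
        empirical_dist = label_counts / label_counts.sum()  # the skewed input distribution
        target_dist = np.full(10, 0.1)                      # uniform target: 0.1 per class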

    The solution

    Using tf.contrib.data.rejection_resample allows us to set a specific distribution for our input pipeline.

    In the documentation it says that tf.contrib.data.rejection_resample takes:

    (1) Dataset - the dataset you want to balance,

    (2) class_func - a function that produces a numerical class label from each element of the original dataset,

    (3) target_dist - a vector of length equal to the number of classes, specifying the required new distribution,

    (4) some more optional values - skipped for now,

    and, as the documentation says, it returns a Dataset.
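
    As a sketch of what (2) and (3) might look like (an assumption, not from the original answer): if the input dataset yields (feature, label) pairs with integer labels in [0, 10), they could simply be the following, where the names correspond to the self.class_mapping_function / self.target_distribution used in the snippet below:

        # Sketch only: assumes elements are (feature, label) pairs with
        # integer labels in [0, 10)
        class_mapping_function = lambda feature, label: label  # (2) class_func
        target_distribution = [0.1] * 10                       # (3) uniform target_dist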

    It turns out that the shape of the input Dataset is different from the shape of the output Dataset: the returned Dataset (as implemented in TF 1.3) yields (class, data) pairs. As a consequence, it should be mapped back by the user like this:

        balanced_dataset = tf.contrib.data.rejection_resample(input_dataset,
                                                              self.class_mapping_function,
                                                              self.target_distribution)

        # rejection_resample yields (class, data) pairs; drop the class element
        # to return to the same Dataset shape as the original input
        balanced_dataset = balanced_dataset.map(lambda _, data: data)
    

    One note on the iterator kind. As @mrry explained here, when using stateful objects within the pipeline one should use the initializable iterator and not the one-shot iterator. Note that when using the initializable iterator you should add the init_op to the TABLE_INITIALIZERS collection, or you will receive this error: "GetNext() failed because the iterator has not been initialized."

    Code example:

    # Creating the iterator, that allows to access elements from the dataset
    if self.use_balancing:
        # For balancing function, we use stateful variables in the sense that they hold current dataset distribution
        # and calculate next distribution according to incoming examples.
        # For dataset pipeline that have state, one_shot iterator will not work, and we are forced to use
        # initializable iterator
        # This should be relaxed in the future.
        # https://stackoverflow.com/questions/44374083/tensorflow-cannot-capture-a-stateful-node-by-value-in-tf-contrib-data-api
        iterator = dataset.make_initializable_iterator()
        tf.add_to_collection(tf.GraphKeys.TABLE_INITIALIZERS, iterator.initializer)
    
    else:
        iterator = dataset.make_one_shot_iterator()
    
    image_batch, label_batch = iterator.get_next()
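
    For completeness, here is a minimal sketch (an assumption, not part of the original answer) of how the pipeline is then driven inside a session; tf.tables_initializer() runs everything registered in the TABLE_INITIALIZERS collection, including the iterator initializer added above:

        with tf.Session() as sess:
            # Runs every op in the TABLE_INITIALIZERS collection,
            # including iterator.initializer from the snippet above
            sess.run(tf.tables_initializer())
            images, labels = sess.run([image_batch, label_batch])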
    

    Does it work?

    Yes. Here are two images from TensorBoard, after collecting a histogram of the input pipeline labels. The original input labels were uniformly distributed. Scenario A: trying to achieve the following 10-class distribution: [0.1, 0.4, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.1, 0.1]

    And the result:

    Scenario B: Trying to achieve the following 10-class distribution: [0.1,0.1,0.05,0.05,0.05,0.05,0.05,0.05,0.4,0.1]

    And the result:
