What does batch, repeat, and shuffle do with TensorFlow Dataset?


I'm currently learning TensorFlow, but I came across some confusion within this code:

dataset = dataset.shuffle(buffer_size = 10 * batch_size) 
dataset = dataset.repeat(num_epochs).batch(batch_size)


        
3 Answers
  • 2020-12-12 15:40

    Imagine you have a dataset: [1, 2, 3, 4, 5, 6]. Then:

    How ds.shuffle() works

    dataset.shuffle(buffer_size=3) will allocate a buffer of size 3 for picking random entries. This buffer will be connected to the source dataset. We could picture it like this:

    Random buffer
       |
       |   Source dataset where all other elements live
       |         |
       ↓         ↓
    [1,2,3] <= [4,5,6]
    

    Let's assume that the entry 2 was taken from the random buffer. The free space is filled by the next element from the source dataset, that is 4:

    2 <= [1,3,4] <= [5,6]
    

    We continue reading till nothing is left:

    1 <= [3,4,5] <= [6]
    5 <= [3,4,6] <= []
    3 <= [4,6]   <= []
    6 <= [4]     <= []
    4 <= []      <= []
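
    A minimal sketch of this behaviour (assuming TensorFlow 2.x eager execution): the exact order differs between runs, but with buffer_size=3 the element 6 can never be one of the first three elements drawn.

    import tensorflow as tf

    ds = tf.data.Dataset.range(1, 7)   # [1, 2, 3, 4, 5, 6]
    ds = ds.shuffle(buffer_size=3)     # buffer holds at most 3 elements at a time

    # Each draw picks uniformly from the buffer, then the buffer is refilled
    # from the source dataset until the source is exhausted.
    print([int(x) for x in ds])        # e.g. [2, 1, 5, 3, 6, 4]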
    

    How ds.repeat() works

    As soon as all the entries are read from the dataset and you try to read the next element, the dataset will throw an error. That's where ds.repeat() comes into play. It will re-initialize the dataset, making it again like this:

    [1,2,3] <= [4,5,6]
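
    A small sketch (again assuming eager execution): with repeat(2) the six elements are produced twice; without repeat, iteration simply stops after one pass. Note that when shuffle comes before repeat, the data is reshuffled on every pass by default (reshuffle_each_iteration=True).

    ds = tf.data.Dataset.range(1, 7).repeat(2)
    print([int(x) for x in ds])   # [1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6]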
    

    What will ds.batch() produce

    The ds.batch() will take the first batch_size entries and make a batch out of them. So, a batch size of 3 for our example dataset will produce two batch records:

    [2,1,5]
    [3,6,4]
    

    As we have a ds.repeat() before the batch, the generation of the data will continue. But the order of the elements will be different, due to the ds.shuffle(). What should be taken into account is that 6 will never be present in the first batch, due to the size of the random buffer.
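
    Putting the three together (a sketch of the idea, not the exact code from the question): shuffle feeds repeat, and batch groups the resulting stream into fixed-size records, so each pass over the data is reshuffled before being batched.

    ds = (tf.data.Dataset.range(1, 7)
            .shuffle(buffer_size=3)
            .repeat(2)        # two passes over the data -> 12 elements
            .batch(3))        # grouped into 4 batches of 3

    for batch in ds:
        print(batch.numpy())  # e.g. [2 1 5], [3 6 4], [1 4 2], ...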

  • 2020-12-12 15:53

    The following methods of tf.data.Dataset (a short example follows the list):

    1. repeat(count=None) Repeats the dataset count times. With the default count=None (or -1) the dataset is repeated indefinitely.
    2. shuffle(buffer_size, seed=None, reshuffle_each_iteration=None) Shuffles the samples in the dataset. buffer_size is the number of elements kept in the shuffle buffer, from which the next element is picked uniformly at random; a larger buffer gives a more thorough shuffle.
    3. batch(batch_size, drop_remainder=False) Combines consecutive elements into batches of length batch_size. When drop_remainder is False (the default), the final batch may contain fewer elements.
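
    A short illustration of the three methods on a toy dataset (a sketch, assuming TensorFlow 2.x eager execution):

    import tensorflow as tf

    ds = tf.data.Dataset.range(10)
    ds = ds.shuffle(buffer_size=10)         # buffer covers the whole dataset -> full shuffle
    ds = ds.repeat(2)                       # two passes -> 20 elements in total
    ds = ds.batch(4, drop_remainder=True)   # 20 elements -> 5 batches of 4

    for batch in ds:
        print(batch.numpy())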
  • 2020-12-12 15:54

    An example that shows looping over epochs. Upon running this script, notice the difference between:

    • dataset_gen1 - the shuffle operation produces more random outputs (this may be more useful while running machine learning experiments)
    • dataset_gen2 - without a shuffle operation, the elements are produced in sequence

    Other additions in this script

    • tf.data.experimental.sample_from_datasets - used to combine two datasets. By default it draws from both datasets with equal probability; the shuffle buffer then mixes the combined stream further.
    import os
    
    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" # to avoid all those prints
    os.environ["TF_GPU_THREAD_MODE"] = "gpu_private" # to avoid large "Kernel Launch Time"
    
    import tensorflow as tf
    if len(tf.config.list_physical_devices('GPU')):
        tf.config.experimental.set_memory_growth(tf.config.list_physical_devices('GPU')[0], True)
    
    class Augmentations:
    
        def __init__(self):
            pass
    
        @tf.function
        def filter_even(self, x):
            # Returns False for even values, so Dataset.filter() drops them and keeps only odd elements
            if x % 2 == 0:
                return False
            else:
                return True
    
    class Dataset:
    
        def __init__(self, aug, range_min=0, range_max=100):
            self.range_min = range_min
            self.range_max = range_max
            self.aug = aug
    
        def generator(self):
            dataset = tf.data.Dataset.from_generator(self._generator
                            , output_types=(tf.float32), args=())
    
            dataset = dataset.filter(self.aug.filter_even)
    
            return dataset
        
        def _generator(self):
            for item in range(self.range_min, self.range_max):
                yield(item)
    
    # Can be used when you have multiple datasets that you wish to combine
    class ZipDataset:
    
        def __init__(self, datasets):
            self.datasets = datasets
            self.datasets_generators = []
        
        def generator(self):
            for dataset in self.datasets:
                self.datasets_generators.append(dataset.generator())
            return tf.data.experimental.sample_from_datasets(self.datasets_generators)
    
    if __name__ == "__main__":
        aug = Augmentations()
        dataset1 = Dataset(aug, 0, 100)
        dataset2 = Dataset(aug, 100, 200)
        dataset = ZipDataset([dataset1, dataset2])
    
        epochs = 2
        shuffle_buffer = 10
        batch_size = 4
        prefetch_buffer = 5
    
        dataset_gen1 = dataset.generator().shuffle(shuffle_buffer).batch(batch_size).prefetch(prefetch_buffer)
        # dataset_gen2 = dataset.generator().batch(batch_size).prefetch(prefetch_buffer) # this will output odd elements in sequence 
    
        for epoch in range(epochs):
            print ('\n ------------------ Epoch: {} ------------------'.format(epoch))
            for X in dataset_gen1.repeat(1): # adding .repeat() in the loop allows you to easily control the end of the loop
                print (X)
            
            # Do some stuff at end of loop
    