Parallel threads with TensorFlow Dataset API and flat_map

伪装坚强ぢ 2021-02-07 03:38

I'm changing my TensorFlow code from the old queue interface to the new Dataset API. With the old interface I could specify the num_threads argument to the t

  • 2021-02-07 04:15

    To the best of my knowledge, at the moment flat_map does not offer parallelism options. Given that the bulk of the computation is done in pre_processing_func, what you might use as a workaround is a parallel map call followed by some buffering, and then using a flat_map call with an identity lambda function that takes care of flattening the output.

    In code:

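    import tensorflow as tf

    NUM_THREADS = 4     # assumed value; not defined in the original snippet
    BUFFER_SIZE = 1000

    def pre_processing_func(data_):
        # data-augmentation here
        # generate new samples starting from the sample `data_`
        # (generate_from_sample is a user-supplied placeholder)
        artificial_samples = generate_from_sample(data_)
        return artificial_samples

    # `input_tensors` is a placeholder for the source data (assumption).
    dataset_source = (
        tf.data.Dataset.from_tensor_slices(input_tensors)
        .map(pre_processing_func, num_parallel_calls=NUM_THREADS)  # parallel step
        .prefetch(BUFFER_SIZE)  # the buffering mentioned above
        .flat_map(lambda *x: tf.data.Dataset.from_tensor_slices(x))
        .shuffle(BUFFER_SIZE))  # my addition, probably necessary though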

    Note (to myself and whoever will try to understand the pipeline):

    Since pre_processing_func generates an arbitrary number of new samples from each initial sample (organised in matrices of shape (?, 512)), the flat_map call is necessary to turn each generated matrix into a Dataset containing single samples (hence the tf.data.Dataset.from_tensor_slices(x) in the lambda) and then flatten all these datasets into one big Dataset containing individual samples.

    It's probably a good idea to .shuffle() that dataset; otherwise the samples generated from the same input will sit next to each other.
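
    To see the data flow without TensorFlow, here is a minimal sketch in plain Python. It is an analogy, not tf.data code: ThreadPoolExecutor stands in for the parallel map, itertools.chain.from_iterable for the flattening flat_map, and random.shuffle for .shuffle(). NUM_THREADS, run_pipeline and the toy pre_processing_func (which returns k + 1 variants for input k) are assumptions made up for illustration.

```python
import random
from concurrent.futures import ThreadPoolExecutor
from itertools import chain

NUM_THREADS = 4  # assumed worker count, for illustration only


def pre_processing_func(sample):
    # Hypothetical augmentation: input k yields k + 1 tagged variants,
    # so the number of outputs differs from one input to the next.
    return [(sample, i) for i in range(sample + 1)]


def run_pipeline(samples, seed=0):
    # Parallel "map": each input is expanded on a worker thread.
    # pool.map returns results in input order.
    with ThreadPoolExecutor(max_workers=NUM_THREADS) as pool:
        groups = list(pool.map(pre_processing_func, samples))
    # "flat_map": unpack every group into one stream of single samples.
    flat = list(chain.from_iterable(groups))
    # Shuffle so variants of the same input are no longer adjacent.
    random.Random(seed).shuffle(flat)
    return groups, flat


groups, flat = run_pipeline([0, 1, 2])
```

    Before the shuffle, the variants of each input come out grouped together, which is the "packed together" effect described above. Note that tf.data's shuffle works with a bounded buffer, whereas this sketch shuffles the whole list at once.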
