Parallel threads with TensorFlow Dataset API and flat_map

伪装坚强ぢ 2021-02-07 03:38

I'm changing my TensorFlow code from the old queue interface to the new Dataset API. With the old interface I could specify the num_threads argument to the t…

1 Answer
  • 2021-02-07 04:15

    To the best of my knowledge, at the moment flat_map does not offer parallelism options. Given that the bulk of the computation is done in pre_processing_func, a workaround is to use a parallel map call, follow it with some buffering, and then use a flat_map call with an identity-like lambda that takes care of flattening the output.

    In code:

    import tensorflow as tf

    NUM_THREADS = 5
    BUFFER_SIZE = 1000
    
    def pre_processing_func(data_):
        # data-augmentation here
        # generate new samples starting from the sample `data_`
        artificial_samples = generate_from_sample(data_)
        return artificial_samples
    
    dataset_source = (tf.data.Dataset.from_tensor_slices(input_tensors).
                      map(pre_processing_func, num_parallel_calls=NUM_THREADS).
                      prefetch(BUFFER_SIZE).
                      flat_map(lambda *x: tf.data.Dataset.from_tensor_slices(x)).
                      shuffle(BUFFER_SIZE))  # my addition; probably necessary (see the note below)
    
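    A minimal end-to-end check of the pipeline above (TF 2.x, eager mode). The pre_processing_func here is a made-up stand-in that just stacks three noisy copies of each sample, and input_tensors is random data; both are placeholders for the real augmentation and inputs:

    import numpy as np
    import tensorflow as tf

    NUM_THREADS = 5
    BUFFER_SIZE = 1000

    def pre_processing_func(data_):
        # stand-in augmentation: three jittered copies of `data_` -> shape (3, 512)
        return tf.stack([data_ + 0.01 * tf.random.normal(tf.shape(data_))
                         for _ in range(3)])

    input_tensors = np.random.rand(100, 512).astype(np.float32)

    dataset_source = (tf.data.Dataset.from_tensor_slices(input_tensors).
                      map(pre_processing_func, num_parallel_calls=NUM_THREADS).
                      prefetch(BUFFER_SIZE).
                      flat_map(lambda *x: tf.data.Dataset.from_tensor_slices(x)).
                      shuffle(BUFFER_SIZE))

    for batch in dataset_source.batch(32).take(1):
        # elements are 1-tuples here, because the lambda packs its arguments into a tuple
        print(batch[0].shape)  # -> (32, 512): individual augmented samples, shuffled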

    Note (to myself and whoever will try to understand the pipeline):

    Since pre_processing_func generates an arbitrary number of new samples starting from the initial sample (organised in matrices of shape (?, 512)), the flat_map call is necessary to turn all the generated matrices into Datasets containing single samples (hence the tf.data.Dataset.from_tensor_slices(x) in the lambda) and then flatten all these datasets into one big Dataset containing individual samples.
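    A tiny illustration of just that flattening step, with made-up shapes (five matrices of four generated samples each):

    import tensorflow as tf

    # dataset whose elements are (4, 512) matrices of generated samples
    mats = tf.data.Dataset.from_tensor_slices(tf.zeros([5, 4, 512]))
    flat = mats.flat_map(tf.data.Dataset.from_tensor_slices)

    print(mats.element_spec)  # TensorSpec(shape=(4, 512), dtype=tf.float32, name=None)
    print(flat.element_spec)  # TensorSpec(shape=(512,), dtype=tf.float32, name=None)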

    It's probably a good idea to .shuffle() that dataset, otherwise the samples generated from the same input will come out packed together.
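    As a side note, more recent TensorFlow versions let you express a parallel flat_map directly with Dataset.interleave, which accepts num_parallel_calls. A sketch under the same assumptions as above (same pre_processing_func, input_tensors, and constants):

    dataset_parallel = (tf.data.Dataset.from_tensor_slices(input_tensors).
                        interleave(
                            lambda *x: tf.data.Dataset.from_tensor_slices(pre_processing_func(*x)),
                            cycle_length=NUM_THREADS,
                            num_parallel_calls=NUM_THREADS).
                        shuffle(BUFFER_SIZE))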
