I'm changing my TensorFlow code from the old queue interface to the new Dataset API. With the old interface I could specify the num_threads argument to the t
To the best of my knowledge, at the moment flat_map does not offer parallelism options. Given that the bulk of the computation is done in pre_processing_func, what you might use as a workaround is a parallel map call followed by some buffering, and then a flat_map call with an identity lambda function that takes care of flattening the output.
In code:
import tensorflow as tf

NUM_THREADS = 5
BUFFER_SIZE = 1000

def pre_processing_func(data_):
    # data augmentation here:
    # generate new samples starting from the sample `data_`
    artificial_samples = generate_from_sample(data_)
    return artificial_samples

dataset_source = (tf.data.Dataset.from_tensor_slices(input_tensors)
                  .map(pre_processing_func, num_parallel_calls=NUM_THREADS)
                  .prefetch(BUFFER_SIZE)
                  .flat_map(lambda *x: tf.data.Dataset.from_tensor_slices(x))
                  .shuffle(BUFFER_SIZE))  # my addition, probably necessary though
Since pre_processing_func generates an arbitrary number of new samples starting from the initial sample (organised in matrices of shape (?, 512)), the flat_map call is necessary to turn each generated matrix into a Dataset containing single samples (hence the tf.data.Dataset.from_tensor_slices(x) in the lambda) and then flatten all these datasets into one big Dataset containing individual samples.
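To see the flattening step in isolation, here is a minimal toy sketch with made-up data; fake_augment is a hypothetical stand-in for generate_from_sample, and the shapes and counts are arbitrary choices:

import numpy as np
import tensorflow as tf

def fake_augment(sample):
    # Stand-in for `generate_from_sample`: make 3 copies of each sample,
    # so one (512,) element becomes a (3, 512) matrix of "new" samples.
    return tf.tile(tf.expand_dims(sample, 0), [3, 1])

toy_inputs = np.random.rand(4, 512).astype(np.float32)  # 4 original samples

flat = (tf.data.Dataset.from_tensor_slices(toy_inputs)
        .map(fake_augment, num_parallel_calls=2)
        .flat_map(lambda x: tf.data.Dataset.from_tensor_slices(x)))

# Each element of `flat` now has shape (512,); there are 4 * 3 = 12 of them.

Without the final flat_map, iterating the dataset would yield whole (3, 512) matrices instead of individual samples.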
It's probably a good idea to .shuffle() the resulting dataset, otherwise the generated samples will be packed together.
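As a usage sketch, assuming TensorFlow 2.x eager execution (under TF 1.x you would build an iterator instead, e.g. with tf.compat.v1.data.make_one_shot_iterator), the flattened and shuffled dataset can then be batched and consumed directly; the batch size of 32 below is an arbitrary choice:

# Hypothetical consumption loop: each `batch` stacks 32 of the
# individual samples produced by the flat_map above.
for batch in dataset_source.batch(32).take(2):
    print(batch)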