Question
I'd like to speed up my training routine, which uses the Estimator API with an input_fn written using tf.data.Dataset.
My implementation takes 2 seconds to prepare a batch of data, then runs training on the GPU for 1 second, and then starts over preparing the next batch, which is really inefficient.
I'm looking for a way to prepare the batches asynchronously and upload them to the GPU to speed up training, or alternatively for a way to cache datasets between invocations of input_fn (dataset.cache() doesn't seem to be a good choice, as the dataset has to be recreated on each input_fn invocation).
Here is a simplified version of my code:
    import tensorflow as tf

    def input_fn(filenames, labels, epochs):
        dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
        dataset = dataset.map(_read_wav, num_parallel_calls=num_map_threads)
        if shuffle:
            dataset = dataset.shuffle(buffer_size=len(labels))
        dataset = dataset.map(_post_process, num_parallel_calls=num_map_threads)
        dataset = dataset.map(lambda wav, label: ({'wav': wav}, label))
        dataset = dataset.batch(128)
        dataset = dataset.repeat(epochs)  # epochs=None iterates over the training set forever
        iterator = dataset.make_one_shot_iterator()
        features, labels = iterator.get_next()
        return features, labels

    train_input_fn = lambda: input_fn(train_files, train_labels, None)
    eval_input_fn = lambda: input_fn(eval_files, eval_labels, 1)

    train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn, max_steps=45000)
    eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn)
    tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
I've noticed that the Estimator API is under active development and in the master branch of tensorflow the input_fn can return datasets already, so maybe I'm asking too early and this feature isn't ready yet. But if so, please provide a ticket where this implementation can be tracked.
Answer 1:
Using tf.data.Dataset.cache() is indeed not a good choice, since it will cache the whole dataset in memory, which takes time and might exhaust your memory.
The way to go is to use tf.data.Dataset.prefetch() at the end of your pipeline, which prepares up to buffer_size elements in the background while the current batch is being consumed. It is usually enough to have buffer_size = 1 at the end:
    dataset = ...
    dataset = dataset.batch(128)
    dataset = dataset.prefetch(1)  # prefetch one batch
As explained by @mrry in this answer, you can also try to increase the number of prefetched batches a bit.
"Typically it is most useful to add a small prefetch buffer (with perhaps just a single element) at the very end of the pipeline, but more complex pipelines can benefit from additional prefetching, especially when the time to produce a single element can vary."
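To make the idea of a prefetch buffer concrete, here is a minimal pure-Python sketch of the same producer/consumer pattern. It is only an illustration and not how tf.data is implemented: the `prefetch` generator and the `_DONE` sentinel are hypothetical names, and a background thread with a bounded queue stands in for the input pipeline.

```python
import queue
import threading

_DONE = object()  # hypothetical sentinel marking the end of the source


def prefetch(source, buffer_size=1):
    """Yield items from `source` while a background thread prepares the next ones.

    The queue holds at most `buffer_size` finished items, so the producer
    runs ahead of the consumer by a bounded amount.
    """
    buf = queue.Queue(maxsize=buffer_size)

    def producer():
        try:
            for item in source:
                buf.put(item)  # blocks while the buffer is full
        finally:
            buf.put(_DONE)     # always signal the end, even after an error

    threading.Thread(target=producer, daemon=True).start()

    while True:
        item = buf.get()       # blocks while the buffer is empty
        if item is _DONE:
            return
        yield item


# Items come out in the original order; only their preparation is overlapped.
batches = list(prefetch(iter(range(5)), buffer_size=1))
print(batches)  # [0, 1, 2, 3, 4]
```

While the consumer is busy with one item (the training step), the producer thread is already preparing the next one, which is the overlap that prefetch(1) provides at the end of the pipeline.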
If you still have a slow input pipeline compared to your GPU computations, you need to increase the number of threads working in parallel using the num_parallel_calls argument of tf.data.Dataset.map().
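A rough back-of-the-envelope model shows why both pieces of advice matter for the numbers in the question (2 seconds of preparation, 1 second of training). This is an idealized sketch: `seconds_per_step` and `parallel_speedup` are hypothetical names, and it assumes that prefetching overlaps the two stages perfectly and that extra map threads scale linearly, which real pipelines rarely achieve.

```python
def seconds_per_step(prep, train, prefetch=False, parallel_speedup=1):
    """Idealized wall-clock time per training step.

    prep:             seconds to prepare one batch with the current settings
    train:            seconds the GPU spends on one batch
    prefetch:         whether preparation overlaps with training
    parallel_speedup: assumed factor by which extra map threads speed up prep
    """
    effective_prep = prep / parallel_speedup
    if prefetch:
        # Stages overlap, so the slower one sets the pace.
        return max(effective_prep, train)
    # Stages run back to back.
    return effective_prep + train


print(seconds_per_step(2.0, 1.0))                                     # 3.0
print(seconds_per_step(2.0, 1.0, prefetch=True))                      # 2.0
print(seconds_per_step(2.0, 1.0, prefetch=True, parallel_speedup=2))  # 1.0
```

Under these assumptions, prefetching alone still leaves the GPU idle half of the time, because preparation remains the bottleneck. Only once preparation takes no longer than the training step does the GPU stay busy.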
Answer 2:
A few points to add to Olivier's answer, mostly from this post:
- repeat before shuffle is slightly faster, at the downside of blurred epoch boundaries. This may be significant in rare cases, but I doubt it.
- shuffle before mapping: this reduces the memory footprint of your shuffle buffer, since it only needs to buffer the filenames rather than the file contents.
- It makes more sense to me to apply the third map transform to the output of get_next() rather than to the dataset. I'm not sure that affects speed much. You could also consider putting both other map calls in the same one to reduce scheduling issues.
- Experiment with repeat before batching. It probably won't make a difference, or only a minor one. If you repeat before shuffle as mentioned above, you'll have to.
- As mentioned by Olivier, use prefetch.
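The "blurred epoch boundaries" in the first point can be illustrated with a small pure-Python model of a shuffle buffer. This is a sketch under assumptions, not TensorFlow code: `shuffle_buffer` and `repeat` are hypothetical helpers that mimic the documented behaviour of filling a fixed-size buffer and emitting a randomly chosen element from it, and integers stand in for filenames.

```python
import random


def shuffle_buffer(source, buffer_size, rng):
    """Yield items of `source` in random order using a bounded buffer."""
    buf = []
    for item in source:
        buf.append(item)
        if len(buf) >= buffer_size:
            # Buffer is full: emit one randomly chosen element.
            yield buf.pop(rng.randrange(len(buf)))
    # Source exhausted: drain the remaining elements in random order.
    rng.shuffle(buf)
    yield from buf


def repeat(items, epochs):
    """Yield the items of a list `epochs` times in a row."""
    for _ in range(epochs):
        yield from items


files = list(range(6))  # stand-ins for filenames
rng = random.Random(0)

# shuffle, then repeat: each epoch is a complete permutation of the files.
shuffle_then_repeat = [
    x for _ in range(2) for x in shuffle_buffer(iter(files), len(files), rng)
]

# repeat, then shuffle: items from epoch 2 enter the buffer while items from
# epoch 1 are still waiting in it, so the two epochs can interleave.
repeat_then_shuffle = list(shuffle_buffer(repeat(files, 2), len(files), rng))

print(shuffle_then_repeat)
print(repeat_then_shuffle)
```

In both orders every file is seen exactly twice overall. The difference is that with repeat first, a file may occur twice before another file has occurred once. The buffer in this sketch holds only the small stand-in values, which mirrors the second point: shuffling before mapping buffers filenames rather than decoded audio.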
Code with modifications:
    def input_fn(filenames, labels, epochs):
        dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
        dataset = dataset.repeat(epochs)
        if shuffle:
            # Shuffling here only buffers (filename, label) pairs, not audio data.
            dataset = dataset.shuffle(buffer_size=len(labels))

        def combined_map_fn(*args):
            # Assumes _read_wav returns a (wav, label) tuple, as in the question.
            return _post_process(*_read_wav(*args))

        dataset = dataset.map(combined_map_fn, num_parallel_calls=num_map_threads)
        dataset = dataset.batch(128)
        dataset = dataset.prefetch(1)  # prepare the next batch during training

        iterator = dataset.make_one_shot_iterator()
        wavs, labels = iterator.get_next()
        features = {'wav': wavs}
        return features, labels
Source: https://stackoverflow.com/questions/48068013/how-to-speed-up-batch-preparation-when-using-estimators-api-combined-with-tf-dat