Multi-threading in the Dataset API

Submitted by 寵の児 on 2019-12-20 03:15:57

Question


TL;DR: how do I ensure that data is loaded in a multi-threaded manner when using the Dataset API in TensorFlow 1.4?

Previously I did something like this with my images on disk:

filename_queue = tf.train.string_input_producer(filenames)
image_reader = tf.WholeFileReader()
_, image_file = image_reader.read(filename_queue)
imsize = 120
image = tf.image.decode_jpeg(image_file, channels=3)
image = tf.image.convert_image_dtype(image, dtype=tf.float32)
image_r = tf.image.resize_images(image, [imsize, imsize])
images = tf.train.shuffle_batch([image_r],
                                batch_size=20,
                                num_threads=30,
                                capacity=200,
                                min_after_dequeue=0)

This ensures that there will be 30 threads (num_threads=30) getting data ready for the next learning iterations.
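The threading model behind tf.train.shuffle_batch can be sketched in plain Python (not TensorFlow, just the producer/consumer pattern it relies on): several worker threads decode items into a bounded buffer, and the training loop dequeues from it. The `decode` function below is a stand-in for the real image reading and decoding.

```python
import queue
import threading

def run_pipeline(filenames, num_threads=4, capacity=200, batch_size=20):
    """Mimic tf.train.shuffle_batch's threading model with stdlib primitives."""
    buffer = queue.Queue(maxsize=capacity)  # analogous to the queue's `capacity`

    def decode(fn):
        # Stand-in for reading and decoding an image file.
        return fn.upper()

    def worker(items):
        for fn in items:
            buffer.put(decode(fn))  # blocks when the buffer is full

    # Split the work across `num_threads` producer threads.
    chunks = [filenames[i::num_threads] for i in range(num_threads)]
    threads = [threading.Thread(target=worker, args=(c,)) for c in chunks]
    for t in threads:
        t.start()

    # Consumer: pull one batch while the producers keep the buffer filled.
    batch = [buffer.get() for _ in range(min(batch_size, len(filenames)))]
    for t in threads:
        t.join()
    return batch

print(sorted(run_pipeline(["a.jpg", "b.jpg", "c.jpg"], batch_size=3)))
```

In TensorFlow the queue runners hide all of this; the point is only that decoding happens concurrently in producer threads while the consumer sees ready-made batches.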

Now with the Dataset API I do something like:

dataset = tf.data.Dataset.from_tensor_slices((filenames, filenames_up, filenames_blacked))
dataset = dataset.map(parse_upscaler_corrector_batch)

After this I create an iterator:

sess = tf.Session()
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
sess.run(iterator.initializer)
value = sess.run(next_element)

The variable value is then passed on for further processing.

So how do I ensure that data is being prepared in a multi-threaded manner here? Where can I read about the Dataset API and multi-threaded data loading?


Answer 1:


So it appears that the way to achieve this is as follows:

dataset = dataset.map(parse_upscaler_corrector_batch, num_parallel_calls=12) \
                 .prefetch(32) \
                 .batch(self.ex_config.batch_size)

If you vary num_parallel_calls, you can see both the network/HDD load and the CPU load rise or fall accordingly.
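Conceptually, num_parallel_calls runs the map function on several elements at once, while prefetch keeps a bounded buffer of finished elements ahead of the consumer so the training step never waits on I/O. A rough stdlib analogy of those two stages (not TensorFlow's actual implementation) looks like this:

```python
import queue
import threading
from concurrent.futures import ThreadPoolExecutor

def parallel_map_prefetch(items, fn, num_parallel_calls=4, prefetch=8):
    """Yield fn(item) for each item: a parallel map feeding a bounded prefetch buffer."""
    buf = queue.Queue(maxsize=prefetch)  # analogous to .prefetch(buffer_size)
    sentinel = object()

    def producer():
        with ThreadPoolExecutor(max_workers=num_parallel_calls) as pool:
            # pool.map applies `fn` with up to num_parallel_calls threads,
            # like dataset.map(fn, num_parallel_calls=N).
            for result in pool.map(fn, items):
                buf.put(result)  # blocks once `prefetch` results are waiting
        buf.put(sentinel)

    threading.Thread(target=producer, daemon=True).start()
    while True:
        item = buf.get()
        if item is sentinel:
            return
        yield item

squares = list(parallel_map_prefetch(range(5), lambda x: x * x))
print(squares)  # [0, 1, 4, 9, 16] — pool.map preserves input order
```

In real code you would also usually call .batch() before .prefetch() so that whole batches, not single elements, are buffered ahead of the training loop.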



Source: https://stackoverflow.com/questions/47653644/multi-threading-in-dataset-api
