TensorFlow: Load data in multiple threads on CPU

我在风中等你 2021-02-06 02:42

I have a Python class SceneGenerator which has multiple member functions for preprocessing and a generator function generate_data(). How can I feed this generator into TensorFlow and run the data loading and preprocessing in multiple threads on the CPU?
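
A minimal sketch of what such a class might look like (the preprocess() member and the file names below are only illustrative stand-ins for the real preprocessing functions):

    import numpy as np

    class SceneGenerator(object):
        def __init__(self):
            # some inits, e.g. storing file names and preprocessing parameters
            self.filenames = ["scene_0.dat", "scene_1.dat"]  # illustrative

        def preprocess(self, filename):
            # stand-in for the real preprocessing member functions
            X = np.random.rand(32).astype(np.float32)
            y = np.int32(2)
            return X, y

        def generate_data(self):
            """Generator. Yield data X and labels y after some preprocessing."""
            while True:
                for filename in self.filenames:
                    # opening files, selecting data
                    X, y = self.preprocess(filename)
                    yield X, y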

2 Answers
  •  名媛妹妹
    2021-02-06 03:29

    Assuming you're using the latest TensorFlow (1.4 at the time of this writing), you can keep the generator and use the tf.data.* API as follows (I chose arbitrary values for the number of threads, prefetch buffer size, batch size and output data types):

    import tensorflow as tf

    NUM_THREADS = 5
    sceneGen = SceneGenerator()
    dataset = tf.data.Dataset.from_generator(sceneGen.generate_data, output_types=(tf.float32, tf.int32))
    # The identity map is only there so num_parallel_calls can spread the work over
    # NUM_THREADS threads; prefetch keeps a buffer of elements ready for the graph.
    dataset = dataset.map(lambda x, y: (x, y), num_parallel_calls=NUM_THREADS).prefetch(buffer_size=1000)
    dataset = dataset.batch(42)
    X, y = dataset.make_one_shot_iterator().get_next()
    

    To show that it's actually multiple threads extracting from the generator, I modified your class as follows:

    import threading

    class SceneGenerator(object):
        def __init__(self):
            # some inits
            pass

        def generate_data(self):
            """
            Generator. Yield data X and labels y after some preprocessing
            """
            while True:
                # opening files, selecting data
                # The real preprocessing is replaced by the current thread id so we
                # can see which thread produced each element of the batch.
                X, y = threading.get_ident(), 2  # self.preprocess(some_params, filenames, ...)
                yield X, y
    

    This way, creating a TensorFlow session and getting one batch shows the thread IDs of the threads pulling data from the generator. On my PC, running:

    sess = tf.Session()
    print(sess.run([X, y]))
    

    prints

    [array([  8460.,   8460.,   8460.,  15912.,  16200.,  16200.,   8460.,
             15912.,  16200.,   8460.,  15912.,  16200.,  16200.,   8460.,
             15912.,  15912.,   8460.,   8460.,   6552.,  15912.,  15912.,
              8460.,   8460.,  15912.,   9956.,  16200.,   9956.,  16200.,
             15912.,  15912.,   9956.,  16200.,  15912.,  16200.,  16200.,
             16200.,   6552.,  16200.,  16200.,   9956.,   6552.,   6552.], dtype=float32),
     array([2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
            2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])]
    

    Note: you might want to experiment with removing the map call (which is only there to get multiple threads) and checking whether the prefetch buffer alone is enough to remove the bottleneck in your input pipeline. Even with a single thread, the input preprocessing is often faster than the actual graph execution, so the buffer lets the preprocessing run as far ahead as it needs to.
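
    As a rough sketch of that experiment (same arbitrary buffer size, batch size and output types as above), the pipeline without the map would look like this:

    import tensorflow as tf

    sceneGen = SceneGenerator()
    dataset = tf.data.Dataset.from_generator(sceneGen.generate_data, output_types=(tf.float32, tf.int32))
    # Single-threaded generator; the prefetch buffer lets preprocessing run
    # ahead of the training step even without parallel map calls.
    dataset = dataset.prefetch(buffer_size=1000)
    dataset = dataset.batch(42)
    X, y = dataset.make_one_shot_iterator().get_next()

    If the session then still spends most of its time waiting in get_next(), put the map back with num_parallel_calls (or move the heavy preprocessing into the mapped function) so that several threads share the work.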
