TensorFlow: Load data in multiple threads on CPU

我在风中等你 2021-02-06 02:42

I have a python class SceneGenerator which has multiple member functions for preprocessing and a generator function generate_data(). The basic structur

2 answers
  •  无人及你
    2021-02-06 03:34

    Running a session with a feed_dict is indeed pretty slow:

    Feed_dict does a single-threaded memcpy of contents from Python runtime into TensorFlow runtime.

    A faster way to feed the data is to use tf.train.string_input_producer + *Reader + tf.train.Coordinator (the TensorFlow 1.x queue-based input pipeline), which reads and batches the data in multiple threads. To do that, you read the data directly into tensors. For example, here is a way to read and process a CSV file:

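    import tensorflow as tf  # TensorFlow 1.x API

    # record_defaults, N_FEATURES and BATCH_SIZE are not shown here;
    # define them to match your own dataset.
    def batch_generator(filenames):
      # Queue of input file names, filled by a background queue runner.
      filename_queue = tf.train.string_input_producer(filenames)
      reader = tf.TextLineReader(skip_header_lines=1)
      _, value = reader.read(filename_queue)  # one CSV line per read

      # Split the line into one tensor per column.
      content = tf.decode_csv(value, record_defaults=record_defaults)
      # Encode the categorical column as a float: 'Present' -> 1.0, else 0.0.
      content[4] = tf.cond(tf.equal(content[4], tf.constant('Present')),
                           lambda: tf.constant(1.0),
                           lambda: tf.constant(0.0))

      features = tf.stack(content[:N_FEATURES])
      label = content[-1]

      # Shuffle and batch in background threads.
      data_batch, label_batch = tf.train.shuffle_batch([features, label],
                                                       batch_size=BATCH_SIZE,
                                                       capacity=20*BATCH_SIZE,
                                                       min_after_dequeue=10*BATCH_SIZE)
      return data_batch, label_batch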
    

    This function takes the list of input files, builds the reader and the data transformations, and returns tensors that evaluate to batches of the files' contents. Your scene generator will likely apply different transformations, but the idea is the same.
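
To make the per-record transformation easier to follow, here is a minimal plain-Python sketch of what the graph ops above do to a single CSV line. It does not use TensorFlow and is not part of the original answer. `N_FEATURES`, `CATEGORICAL_COL`, `parse_record` and the sample row are hypothetical values chosen only for illustration.

```python
import csv
import io

# Hypothetical stand-ins for constants the answer leaves undefined.
N_FEATURES = 9        # assumed number of feature columns
CATEGORICAL_COL = 4   # column holding 'Present' / 'Absent', as in content[4]


def parse_record(line):
    """Parse one CSV line into (features, label), mirroring the graph ops."""
    fields = next(csv.reader(io.StringIO(line)))
    # Same mapping as the tf.cond call: 'Present' -> 1.0, anything else -> 0.0.
    fields[CATEGORICAL_COL] = 1.0 if fields[CATEGORICAL_COL] == "Present" else 0.0
    values = [float(v) for v in fields]
    features = values[:N_FEATURES]  # like tf.stack(content[:N_FEATURES])
    label = values[-1]              # like content[-1]
    return features, label


# Made-up sample row, for illustration only.
features, label = parse_record("160,12.0,5.73,23.11,Present,49,25.3,97.2,52,1")
print(features, label)
```

In the TensorFlow version the same steps run as graph ops inside the input pipeline, so they are executed by the background threads rather than by the Python training loop.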

    Next, start the queue runners under a tf.train.Coordinator so that the pipeline runs in background threads:

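    # Build the input pipeline first; `filenames` is your list of input files.
    data_batch, label_batch = batch_generator(filenames)

    with tf.Session() as sess:
        coord = tf.train.Coordinator()
        # Start the threads that fill the filename and batch queues.
        threads = tf.train.start_queue_runners(coord=coord)
        for _ in range(10):  # generate 10 batches
            features, labels = sess.run([data_batch, label_batch])
            print(features)
        # Ask the threads to stop, then wait for them to finish.
        coord.request_stop()
        coord.join(threads)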
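
The coordination pattern behind queue runners can be sketched in plain Python: worker threads fill a bounded queue while the consumer drains it, and a shared stop signal shuts everything down. This is only an illustration of the idea under assumed names (`make_example`, `worker`, `run` and the constants are hypothetical), not TensorFlow's implementation.

```python
import queue
import random
import threading

# Hypothetical values, chosen only for the demo.
BATCH_SIZE = 4
NUM_WORKERS = 3
QUEUE_CAPACITY = 8


def make_example(rng):
    """Stand-in for the CPU preprocessing of one example."""
    features = [rng.random() for _ in range(3)]
    label = float(sum(features) > 1.5)
    return features, label


def worker(batch_queue, stop_event, seed):
    """Build batches until asked to stop, like one queue-runner thread."""
    rng = random.Random(seed)
    while not stop_event.is_set():
        batch = [make_example(rng) for _ in range(BATCH_SIZE)]
        # Retry with a timeout so a full queue cannot block shutdown forever.
        while not stop_event.is_set():
            try:
                batch_queue.put(batch, timeout=0.1)
                break
            except queue.Full:
                continue


def run(num_batches):
    batch_queue = queue.Queue(maxsize=QUEUE_CAPACITY)
    stop_event = threading.Event()  # plays the role of coord.request_stop()
    threads = [
        threading.Thread(target=worker, args=(batch_queue, stop_event, i))
        for i in range(NUM_WORKERS)
    ]
    for t in threads:
        t.start()
    try:
        # The "training loop": block until each batch is ready.
        batches = [batch_queue.get() for _ in range(num_batches)]
    finally:
        stop_event.set()
        for t in threads:
            t.join()  # like coord.join(threads)
    return batches


batches = run(10)
print(len(batches), len(batches[0]))
```

Note that pure-Python threads share the GIL, so this sketch shows the coordination, not a CPU speedup. TensorFlow's queue runners execute graph ops in native code, which is where the parallelism comes from.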
    

    In my experience, this approach feeds the data much faster and makes it possible to use all of the available GPU power. A complete working example can be found here.
