Is the class generator (inheriting Sequence) thread safe in Keras/Tensorflow?

Asked by 面向向阳花, 2021-02-06 03:11

To make the training of a model faster, it seems to be good practice to populate/generate batches on the CPU and run the training of the model on the GPU in parallel. For this purpose, …

2 Answers
  •  Answered by 臣服心动, 2021-02-06 04:01

    Among those who have seen this post, no one seems to have the ultimate answer, so I wanted to share the one that worked out for me. Because of the lack of documentation in this area, my answer might be missing some relevant details. Please feel free to add any information I do not mention here.

    Seemingly, writing a generator class in Python that inherits the Sequence class is just not supported on Windows (you can apparently make it work on Linux). To benefit from it, you need to set the parameter use_multiprocessing=True (with the class approach); since that does not work on Windows, you have to set use_multiprocessing to False there. Nevertheless, that does not mean parallel data loading is impossible on Windows. Even with use_multiprocessing=False, Keras can still load data in parallel: in the following setup it uses multiple worker threads as soon as you set the workers parameter to any value bigger than 1.

    Example:

    history = \
       merged_model.fit_generator(generator=train_generator,
                                  steps_per_epoch=trainset_steps_per_epoch,
                                  epochs=300,
                                  verbose=1,
                                  use_multiprocessing=False,
                                  workers=3,
                                  max_queue_size=4)
    

    At this point, let's remember the Keras documentation again:

    The use of keras.utils.Sequence guarantees the ordering and guarantees the single use of every input per epoch when using use_multiprocessing=True.

    To my understanding, if use_multiprocessing=False, these guarantees no longer hold, so the generator itself must be thread safe; this is what makes writing a generator class that inherits Sequence difficult in this setup.
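    For context, what makes the Sequence approach attractive in the first place is that batches are fetched by index rather than by advancing shared iterator state. Below is a minimal sketch of that `__len__`/`__getitem__` protocol, written as a plain class (so it runs without Keras installed; a real implementation would subclass keras.utils.Sequence, and the class and attribute names here are my own invention):

```python
import math

# A minimal sketch of the interface that keras.utils.Sequence defines:
# batches are produced by index, so concurrent workers can each call
# __getitem__ with different indices without sharing any cursor state.
# (Plain class so it runs without Keras; subclass keras.utils.Sequence
# in real code.)
class BatchedDataset:
    def __init__(self, samples, labels, batch_size):
        self.samples = samples
        self.labels = labels
        self.batch_size = batch_size

    def __len__(self):
        # Number of batches per epoch.
        return math.ceil(len(self.samples) / self.batch_size)

    def __getitem__(self, idx):
        # Return batch number idx; no internal cursor is advanced,
        # which is why index-based access needs no locking.
        lo = idx * self.batch_size
        hi = lo + self.batch_size
        return self.samples[lo:hi], self.labels[lo:hi]

seq = BatchedDataset(list(range(10)), list(range(10)), batch_size=4)
print(len(seq))  # 3 batches
print(seq[2])    # last, partial batch: ([8, 9], [8, 9])
```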

    To come around this problem, I have written a generator myself which I have made thread safe manually. Here is an example pseudocode:

    import tensorflow as tf
    import threading
    
    class threadsafe_iter:
        """Takes an iterator/generator and makes it thread-safe by
        serializing calls to the `next` method of the given iterator/generator.
        """
        def __init__(self, it):
            self.it = it
            self.lock = threading.Lock()
    
        def __iter__(self):
            return self
    
        def __next__(self):
            # The lock is what makes this wrapper thread safe: only one
            # worker at a time can advance the underlying generator.
            with self.lock:
                return next(self.it)
    
    def threadsafe_generator(f):
        """A decorator that takes a generator function and makes it thread-safe.
        """
        def g(*a, **kw):
            return threadsafe_iter(f(*a, **kw))
        return g
    
    
    @threadsafe_generator
    def generate_data(tfrecord_file_path_list, batch_size, ...):
    
        dataset = tf.data.TFRecordDataset(tfrecord_file_path_list)
    
        # Example proto decode
        def _parse_function(example_proto):
            ...
            return batch_data
    
        # Parse the records into tensors
        dataset = dataset.map(_parse_function)
    
        dataset = dataset.shuffle(buffer_size=100000)
    
        # Repeat the input indefinitely
        dataset = dataset.repeat()
    
        # Generate batches
        dataset = dataset.batch(batch_size)
    
        # Create an initializable iterator (TF 1.x graph-mode API)
        iterator = dataset.make_initializable_iterator()
    
        # Tensor that yields the next batch when evaluated
        next_batch = iterator.get_next()
    
        with tf.Session() as sess:
    
            sess.run(iterator.initializer)
    
            while True:
                try:
                    # Fetch into a fresh variable; reassigning the
                    # next_batch tensor itself would break the loop.
                    yield sess.run(next_batch)
                except tf.errors.OutOfRangeError:
                    break
    

    Well, it is debatable whether doing it this way is really elegant, but it seems to work pretty well.
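    As a quick sanity check of the lock-wrapped iterator idea, here is a self-contained demo (plain Python, no TensorFlow needed; the wrapper is restated so this runs on its own): several threads drain one shared generator, and every item should come out exactly once.

```python
import threading

class threadsafe_iter:
    """Serializes access to a generator so multiple threads can share it."""
    def __init__(self, it):
        self.it = it
        self.lock = threading.Lock()

    def __iter__(self):
        return self

    def __next__(self):
        with self.lock:
            return next(self.it)

def counter(n):
    yield from range(n)

shared = threadsafe_iter(counter(10000))
seen = []
seen_lock = threading.Lock()

def drain():
    # Each thread's for-loop ends cleanly when the shared
    # generator raises StopIteration.
    for item in shared:
        with seen_lock:
            seen.append(item)

threads = [threading.Thread(target=drain) for _ in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()

# Every value was delivered exactly once across all threads.
assert sorted(seen) == list(range(10000))
print(len(seen))  # 10000
```

Without the lock in `__next__`, two threads could interleave inside the generator and raise "ValueError: generator already executing" or drop/duplicate items.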

    To summarize:

    • If writing your program on Windows, set use_multiprocessing to False.
    • (As of today, to my knowledge) writing a generator class that inherits Sequence is not supported on Windows. (It is a Tensorflow/Keras problem, I guess.)
    • To work around the problem, write an ordinary generator, make it thread safe, and set workers to a number greater than 1.

    Important note: in this setup, the generator runs on the CPU while the training runs on the GPU. One problem I could observe is that if the model you are training is shallow enough, GPU utilization stays very low while CPU utilization gets high. If the model is shallow and the dataset small enough, it can be a good option to store all the data in memory and run everything on the GPU; that should speed up the training significantly. If, for any reason, you would like to use the CPU and GPU simultaneously, my modest recommendation is to try TensorFlow's tf.data API, which significantly speeds up data preprocessing and batch preparation. With a generator written in Python alone, the GPU keeps waiting for data before it can continue training. One can say whatever one likes about the TensorFlow/Keras documentation, but tf.data really is efficient code!
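    The overlap that tf.data's prefetching buys you can be sketched in plain Python with a background thread and a bounded queue (a simplified illustration of the idea, not the tf.data implementation; all names here are mine): while the consumer (the GPU in the real setup) processes batch N, the producer thread is already preparing batch N+1.

```python
import queue
import threading
import time

def prefetch(generator, buffer_size=4):
    """Run `generator` in a background thread, keeping up to
    `buffer_size` items ready ahead of the consumer."""
    q = queue.Queue(maxsize=buffer_size)
    _DONE = object()  # sentinel marking the end of the stream

    def producer():
        for item in generator:
            q.put(item)  # blocks when the buffer is full
        q.put(_DONE)

    threading.Thread(target=producer, daemon=True).start()

    while True:
        item = q.get()
        if item is _DONE:
            return
        yield item

def make_batches(n):
    for i in range(n):
        time.sleep(0.001)  # stands in for expensive batch preparation
        yield [i] * 3

# Batch i is being prepared while batch i-1 is being consumed.
batches = list(prefetch(make_batches(5), buffer_size=2))
print(batches)  # [[0, 0, 0], [1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4]]
```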

    Anyone with more complete knowledge of the API who sees this post, please feel free to correct me in case I misunderstand anything or the API has been updated to solve these problems even on Windows.
