How to write a caffe python data layer with preload?

醉话见心 2021-01-03 08:49

How to write an asynchronous data layer that preloads batches while other processing is performed? Is there any example code? Thanks

1 Answer
  •  清酒与你
    2021-01-03 09:00

    There are several ways you can achieve what you want. I'll try to sketch one option here.

    The overall view of the system is: you have n Loaders asynchronously loading data and feeding a queue. The layer then reads batch_size items from the queue and feeds the net in its forward() function.

    import caffe
    import multiprocessing

    class Loader(multiprocessing.Process):
      def __init__(self, outq, *args, **kwargs):
        super(Loader, self).__init__()
        self.daemon = True  # die together with the main process
        self.outq = outq
        self.start()  # start working

      def run(self):
        while True:  # read and never stop at all!
          try:
            # do your magic here: load a single x,y pair of numpy arrays,
            # e.g. x, y = load_next_sample()  # placeholder, supply your own
            self.outq.put((x[None, ...], y[None, ...]))  # add singleton "batch" dimension
          except Exception as e:
            # handle errors?
            pass

    class MultiProcessInputLayer(caffe.Layer):
      def setup(self, bottom, top):
        # verify no bottoms, right number of tops etc.
        # one common way to get n and batch_size is via param_str:
        params = eval(self.param_str)  # e.g. "{'n': 4, 'batch_size': 32}"
        self.n = params['n']
        self.batch_size = params['batch_size']
        self.dataQ = multiprocessing.Queue()
        for _ in range(self.n):
          Loader(self.dataQ)  # start n Loaders
        # some other stuff here...

      def reshape(self, bottom, top):
        # reshape the tops to the right sizes, e.g.
        # top[0].reshape(self.batch_size, c, h, w)
        # top[1].reshape(self.batch_size, 1)
        pass

      def forward(self, bottom, top):
        for i in range(self.batch_size):
          item = self.dataQ.get()  # blocks until a Loader delivers a sample
          top[0].data[i, ...] = item[0]
          top[1].data[i, ...] = item[1]

      def backward(self, top, propagate_down, bottom):
        pass  # no backward for a data layer
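
    For reference, a Python layer like this is wired into the net definition roughly as below. The module name, layer name and param_str contents are assumptions matching the sketch above, and Caffe must be built with WITH_PYTHON_LAYER := 1:

    layer {
      name: "data"
      type: "Python"
      top: "data"
      top: "label"
      python_param {
        module: "my_data_layer"  # the .py file that defines the layer class
        layer: "MultiProcessInputLayer"
        param_str: "{'n': 4, 'batch_size': 32}"
      }
    }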
    

    Some tips and tricks I learned the hard way:
    1. Use the multiprocessing package and not threading, because of the GIL.
    2. Sometimes (e.g. if batch_size is very large) it will take forward() a long time to read item by item from the queue to assemble each batch. In that case you can add another multiprocessing.Process that asynchronously reads batch_size items from self.dataQ and writes whole batches to a second queue, self.batchQ; forward() then waits for only a single item from self.batchQ on each call (see the sketch after this list).
    3. Take care not to copy the data around too much. Working with large images/labels can turn all this copying into a bottleneck.
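
    Here is a minimal sketch of the batch-assembler idea from tip 2. The BatchAssembler name and the use of np.concatenate are my own assumptions, not part of the original recipe:

    import multiprocessing
    import numpy as np

    class BatchAssembler(multiprocessing.Process):
      # hypothetical helper: drains batch_size single samples from inq
      # and puts one ready-made (x, y) batch on outq
      def __init__(self, inq, outq, batch_size):
        super(BatchAssembler, self).__init__()
        self.daemon = True
        self.inq = inq
        self.outq = outq
        self.batch_size = batch_size
        self.start()

      def run(self):
        while True:
          xs, ys = [], []
          for _ in range(self.batch_size):
            x, y = self.inq.get()  # each sample already has a singleton batch dim
            xs.append(x)
            ys.append(y)
          # stack along the batch dimension and ship the whole batch at once
          self.outq.put((np.concatenate(xs, axis=0), np.concatenate(ys, axis=0)))

    With this in place, forward() does a single self.batchQ.get() per call and copies the two arrays into top[0].data and top[1].data.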
