parallelising tf.data.Dataset.from_generator

后端 未结 3 792
孤独总比滥情好
孤独总比滥情好 2020-12-01 01:24

I have a non-trivial input pipeline that from_generator is perfect for...

dataset = tf.data.Dataset.from         


        
相关标签:
3条回答
  • 2020-12-01 01:56

    I am working on a from_indexable for tf.data.Dataset https://github.com/tensorflow/tensorflow/issues/14448

    The advantage of from_indexable is that it can be parallelized, whereas a Python generator cannot.

    The function from_indexable creates a tf.data.Dataset.range, wraps the indexable in a generalized tf.py_func, and calls map.

    For those who want a from_indexable right now, here is the library code:

    import tensorflow as tf
    import numpy as np
    
    from tensorflow.python.framework import tensor_shape
    from tensorflow.python.util import nest
    
    def py_func_decorator(output_types=None, output_shapes=None, stateful=True, name=None):
        """Turn a plain Python function into one that runs on tensors via tf.py_func.

        Args:
            output_types: tf.DType, or a nested structure of them, describing what
                the wrapped function returns. It is flattened to build ``Tout``.
            output_shapes: optional structure matching ``output_types``. The
                shapes are attached to the py_func outputs with ``set_shape`` so
                downstream ops see static shape information.
            stateful: forwarded to tf.py_func.
            name: forwarded to tf.py_func.

        Returns:
            A decorator. The decorated callable takes tensors as positional
            arguments and returns tensors packed into the structure of
            ``output_types``.

        Note: ``func`` has to return its results in the order of
        ``nest.flatten(output_types)``, since that flat list is what is passed
        to tf.py_func as ``Tout``.
        """
        flat_output_types = nest.flatten(output_types)
        flattened_shapes = None
        if output_shapes is not None:
            # Normalise the user's shapes (lists, tuples, None entries) to
            # TensorShape once, here, instead of rebinding a closure variable
            # via `nonlocal` on every call.
            normalized_shapes = nest.map_structure_up_to(
                output_types, tensor_shape.as_shape, output_shapes)
            flattened_shapes = nest.flatten_up_to(output_types, normalized_shapes)

        def decorator(func):
            def call(*args):
                flat_values = tf.py_func(
                    func,
                    inp=args,
                    Tout=flat_output_types,
                    stateful=stateful, name=name
                )
                if flattened_shapes is not None:
                    # Re-attach static shapes to each flat output tensor.
                    for ret_t, shape in zip(flat_values, flattened_shapes):
                        ret_t.set_shape(shape)
                return nest.pack_sequence_as(output_types, flat_values)
            return call
        return decorator
    
    def from_indexable(iterator, output_types, output_shapes=None, num_parallel_calls=None, stateful=True, name=None):
        """Build a tf.data.Dataset whose i-th element is ``iterator[i]``.

        Despite its name, ``iterator`` is indexed, not iterated: it only needs to
        support ``len()`` and ``[]``. Every lookup goes through tf.py_func (via
        ``py_func_decorator``) inside ``Dataset.map``.

        Args:
            iterator: indexable Python object supporting ``len()`` and ``[]``.
            output_types, output_shapes, stateful, name: forwarded to
                ``py_func_decorator``.
            num_parallel_calls: forwarded to ``Dataset.map``.
        """
        indices = tf.data.Dataset.range(len(iterator))
        to_tensors = py_func_decorator(output_types, output_shapes, stateful=stateful, name=name)

        def lookup(position):
            # `position` is the value py_func hands over for one element of
            # Dataset.range (presumably a NumPy int64 scalar).
            return iterator[position]

        return indices.map(to_tensors(lookup), num_parallel_calls=num_parallel_calls)
    

    And here is an example (note: from_indexable has a num_parallel_calls argument):

    class PyDataSet:
        """Toy indexable data source with 20 entries.

        Entry ``item`` is a freshly drawn standard-normal float64 array of shape
        ``(item + 1, 10)``, so the first dimension varies between entries.
        """

        def __len__(self):
            entry_count = 20
            return entry_count

        def __getitem__(self, item):
            num_rows = item + 1
            return np.random.normal(size=(num_rows, 10))
    
    # Wrap the toy indexable. Entries have a varying first dimension, hence the
    # None in output_shapes.
    ds = from_indexable(PyDataSet(), output_types=tf.float64, output_shapes=[None, 10])
    it = ds.make_one_shot_iterator()
    entry = it.get_next()
    with tf.Session() as sess:
        # Each run of `entry` pulls the next element; print the first two shapes.
        for _ in range(2):
            print(sess.run(entry).shape)
        print(sess.run(entry).shape)
    

    Update June 10, 2018: Now that https://github.com/tensorflow/tensorflow/pull/15121 has been merged, the code for from_indexable simplifies to:

    import tensorflow as tf
    
    def py_func_decorator(output_types=None, output_shapes=None, stateful=True, name=None):
        """Decorator factory: run a Python function on tensors through
        ``tf.contrib.framework.py_func``.

        That helper takes ``args``/``kwargs`` and ``output_types``/``output_shapes``
        directly, so this wrapper only forwards the decorated function's call
        arguments together with the settings captured here.
        """
        def decorator(func):
            def call(*args, **kwargs):
                contrib_py_func = tf.contrib.framework.py_func
                return contrib_py_func(
                    func=func,
                    args=args,
                    kwargs=kwargs,
                    output_types=output_types,
                    output_shapes=output_shapes,
                    stateful=stateful,
                    name=name,
                )
            return call
        return decorator
    
    def from_indexable(iterator, output_types, output_shapes=None, num_parallel_calls=None, stateful=True, name=None):
        """Expose an indexable Python object as a tf.data.Dataset.

        Element i of the result is ``iterator[i]`` for i in ``range(len(iterator))``.
        The lookup runs as a py_func inside ``Dataset.map``, and
        ``num_parallel_calls`` is forwarded to that map.

        Args:
            iterator: object supporting ``len()`` and ``[]`` (it is never iterated).
            output_types, output_shapes, stateful, name: forwarded to
                ``py_func_decorator``.
            num_parallel_calls: forwarded to ``Dataset.map``.
        """
        index_ds = tf.data.Dataset.range(len(iterator))
        as_py_func = py_func_decorator(output_types, output_shapes, stateful=stateful, name=name)

        def fetch_entry(idx):
            return iterator[idx]

        return index_ds.map(as_py_func(fetch_entry), num_parallel_calls=num_parallel_calls)
    
    0 讨论(0)
  • 2020-12-01 02:09

    It turns out I can use Dataset.map if I make the generator super lightweight (only generating metadata) and then move the actual heavy lifting into a stateless function. This way I can parallelise just the heavy-lifting part with .map using a py_func.

    It works, but feels a tad clumsy... It would be great to be able to just add num_parallel_calls to from_generator :)

    def pure_numpy_and_pil_complex_calculation(metadata, label):
      """Placeholder for the expensive per-example work, done with PIL/NumPy only.

      It is called through tf.py_func in wrapped_complex_calulation, which
      declares its outputs as (tf.uint8 image, tf.string label). The inputs
      presumably arrive as bytes, since they come from tf.string tensors; verify
      against the real implementation. As written, the body is only a stub.
      """
      # Some complex PIL and NumPy work that has nothing to do with TensorFlow.
      ...
    
    # The generator stays cheap: it only yields (metadata, label) strings. The
    # heavy per-example work happens later in a parallel map.
    dataset = tf.data.Dataset.from_generator(
        lightweight_generator,
        output_types=(
            tf.string,  # metadata
            tf.string,  # label
        ),
    )
    
    def wrapped_complex_calulation(metadata, label):
      """Run pure_numpy_and_pil_complex_calculation on tensors via tf.py_func.

      Args:
        metadata: tf.string tensor yielded by the generator.
        label: tf.string tensor yielded by the generator.

      Returns:
        Tuple ``(img, label)``: a tf.uint8 image tensor with static shape
        (None, None, 3) and a tf.string label tensor.
      """
      img, out_label = tf.py_func(func=pure_numpy_and_pil_complex_calculation,
                                  inp=(metadata, label),
                                  Tout=(tf.uint8,    # (H,W,3) img
                                        tf.string))  # label
      # tf.py_func returns tensors of unknown static shape; re-attach the
      # (H, W, 3) layout the image is documented to have.
      img.set_shape([None, None, 3])
      # Return a tuple rather than py_func's list so tf.data treats the result
      # as an (img, label) pair instead of something to stack into one tensor.
      return img, out_label
    # Do the heavy Python work in map, which (unlike from_generator) accepts
    # num_parallel_calls: up to 8 elements are processed in parallel.
    dataset = dataset.map(wrapped_complex_calulation,
                          num_parallel_calls=8)

    dataset = dataset.batch(64)
    # Named `iterator`, not `iter`, to avoid shadowing the builtin iter().
    iterator = dataset.make_one_shot_iterator()
    imgs, labels = iterator.get_next()
    
    0 讨论(0)
  • 2020-12-01 02:16

    Limiting the work done in the generator to a minimum and parallelizing the expensive processing using a map is sensible.

    Alternatively, you can "join" multiple generators using parallel_interleave as follows:

    def generator(n):
      """Return the n-th generator function (a zero-argument callable).

      Stub: replace the body with code that builds the n-th data source. The
      original body consisted only of a comment, which is a SyntaxError.
      """
      raise NotImplementedError("generator(n) must return the n-th generator function")
    
    def dataset(n, output_types=None, output_shapes=None):
      """Return a Dataset yielding the items of the n-th generator.

      Args:
        n: the element passed in by Dataset.range(N) inside parallel_interleave,
          i.e. a scalar int64 *tensor*, not a Python int.
        output_types: structure of tf.DType the generators yield. It is required
          by Dataset.from_generator; bind it with e.g. functools.partial.
        output_shapes: optional structure of shapes matching output_types.

      Raises:
        ValueError: if output_types is not given.
      """
      if output_types is None:
        raise ValueError("dataset(n) needs output_types for Dataset.from_generator")

      def _nth_generator(index):
        # from_generator evaluates `args` and passes them in as NumPy values, so
        # `index` is a concrete integer here and can select a Python generator.
        # `n` itself is symbolic and cannot.
        return generator(int(index))()

      return tf.data.Dataset.from_generator(
          _nth_generator,
          output_types=output_types,
          output_shapes=output_shapes,
          args=(n,))
    
    # Interleave N generator-backed datasets. parallel_interleave processes up
    # to cycle_length input elements (here: all N) concurrently.
    ds = tf.data.Dataset.range(N).apply(
        tf.contrib.data.parallel_interleave(dataset, cycle_length=N))

    # where N is the number of generators you use
    
    0 讨论(0)
提交回复
热议问题