Caching a generator

忘了有多久 2021-01-05 04:02

A recent similar question (isinstance(foo, types.GeneratorType) or inspect.isgenerator(foo)?) got me curious about how to implement this generically.

It seems like

2 answers
  • 2021-01-05 04:36

    Based on this comment:

    my intention here is that this would only be used if the user knows he wants to iterate multiple times over the 'iterable', but doesn't know if the input is a generator or iterable. this lets you ignore that distinction, while not losing (much) performance.

    This simple solution does exactly that:

    def ensure_list(it):
        if isinstance(it, (list, tuple, dict)):
            return it
        else:
            return list(it)
    

    Now ensure_list(a_list) is practically a no-op (two function calls), while ensure_list(a_generator) turns the generator into a list and returns it. This turned out to be faster than any other approach.
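
A short usage sketch of the idea above: a consumer that needs two passes over its input can call ensure_list first and stop caring whether it was handed a list or a generator. The helper name `total_and_max` is hypothetical and only serves as an example; ensure_list is repeated so the snippet runs on its own.

```python
def ensure_list(it):
    # Same logic as the answer's function, repeated for self-containment.
    if isinstance(it, (list, tuple, dict)):
        return it
    return list(it)


def total_and_max(values):
    # Hypothetical consumer that iterates over its input twice.
    values = ensure_list(values)
    return sum(values), max(values)


a_list = [1, 2, 3]
same_object = ensure_list(a_list) is a_list   # lists are passed through untouched
from_generator = total_and_max(x * x for x in range(4))

print(same_object)      # True
print(from_generator)   # (14, 9)
```

Without the ensure_list call, the generator would be exhausted by sum() and max() would then fail on an empty sequence.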

  • 2021-01-05 04:42

    What you want is not an iterator but an iterable. An iterator can pass through its contents only once. You want something that wraps an iterator and can itself be iterated multiple times, producing the same values each time even when the underlying iterator, such as a generator, does not remember them. After that, it's just a matter of special-casing the inputs that don't need caching. Here's a non-thread-safe example (EDIT: updated for efficiency):

    import itertools
    class AsYouGoCachingIterable(object):
        def __init__(self, iterable):
            self.iterable = iterable
            self.iter = iter(iterable)
            self.done = False
            self.vals = []
    
        def __iter__(self):
            if self.done:
                # Source exhausted: everything is cached, so iterate the list directly.
                return iter(self.vals)
            # Chain the values cached so far, then generate the rest.
            return itertools.chain(self.vals, self._gen_iter())
    
        def _gen_iter(self):
            # Pull new values from the source, caching each one as it is yielded.
            for new_val in self.iter:
                self.vals.append(new_val)
                yield new_val
            self.done = True
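
One caveat beyond thread safety, as far as I can tell from reading the class above: if two iterators over the same instance are advanced in an interleaved fashion, one of them can skip values that the other pulled from the source in the meantime. A minimal sketch of an index-based variant that avoids this is shown below. The class name `InterleaveSafeCachingIterable` is hypothetical, this is not the answer author's code, and it is still not thread-safe.

```python
class InterleaveSafeCachingIterable:
    """Hypothetical variant: each iterator reads the shared cache by index."""

    def __init__(self, iterable):
        self._source = iter(iterable)
        self._vals = []

    def __iter__(self):
        index = 0
        while True:
            if index < len(self._vals):
                # Value already cached, possibly by another iterator.
                yield self._vals[index]
            else:
                try:
                    value = next(self._source)
                except StopIteration:
                    return  # source exhausted, nothing more to yield
                self._vals.append(value)
                yield value
            index += 1


cached = InterleaveSafeCachingIterable(x * x for x in range(5))
a = iter(cached)
b = iter(cached)
first_a = [next(a), next(a)]            # a pulls 0 and 1 from the source
first_b = [next(b), next(b), next(b)]   # b replays 0 and 1, then pulls 4
rest_a = list(a)                        # a still sees 4 via the cache
rest_b = list(b)
```

The per-item index check makes each step a little more expensive than the itertools.chain approach, so this trades some speed for correctness under interleaving.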
    

    And some timings:

    class ListCachingIterable(object):
        def __init__(self, obj):
            self.vals = list(obj)
    
        def __iter__(self):
            return iter(self.vals)
    
    def cube_generator(limit=1000):
        # "limit" rather than "max", to avoid shadowing the built-in
        i = 0
        while i < limit:
            yield i*i*i
            i += 1
    
    def runit(iterable_factory):
        for i in range(5):  # xrange in the original Python 2 code
            for what in iterable_factory():
                pass
    
    def puregen():
        runit(lambda: cube_generator())
    def listtheniter():
        res = list(cube_generator())
        runit(lambda: res)
    def listcachingiterable():
        res = ListCachingIterable(cube_generator())
        runit(lambda: res)
    def asyougocachingiterable():
        res = AsYouGoCachingIterable(cube_generator())
        runit(lambda: res)
    

    Results are:

    In [59]: %timeit puregen()
    1000 loops, best of 3: 774 us per loop
    
    In [60]: %timeit listtheniter()
    1000 loops, best of 3: 345 us per loop
    
    In [61]: %timeit listcachingiterable()
    1000 loops, best of 3: 348 us per loop
    
    In [62]: %timeit asyougocachingiterable()
    1000 loops, best of 3: 630 us per loop
    

    So the simplest class-based approach, ListCachingIterable, performs about as well as building the list manually. The "as-you-go" variant is almost twice as slow (630 us vs. 345 us), but it has an advantage when you don't consume the entire sequence. Say you're only looking for the first cube over 100:

    def first_cube_past_100(cubes):
        for cube in cubes:
            if cube > 100:
                return cube
        raise ValueError("No cube > 100 in this iterable")
    

    Then:

    In [76]: %timeit first_cube_past_100(cube_generator())
    100000 loops, best of 3: 2.92 us per loop
    
    In [77]: %timeit first_cube_past_100(ListCachingIterable(cube_generator()))
    1000 loops, best of 3: 255 us per loop
    
    In [78]: %timeit first_cube_past_100(AsYouGoCachingIterable(cube_generator()))
    100000 loops, best of 3: 10.2 us per loop
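
The gap between these timings comes from how much of the generator each wrapper consumes before the search returns. A minimal sketch that counts this directly instead of timing it follows. The names `LazyCache` and `counting_cubes`, and the one-element list used as a counter, are hypothetical and exist only for illustration; `LazyCache` is a condensed copy of AsYouGoCachingIterable.

```python
import itertools


class LazyCache:
    """Condensed copy of AsYouGoCachingIterable (hypothetical name)."""

    def __init__(self, iterable):
        self._iter = iter(iterable)
        self._vals = []
        self._done = False

    def __iter__(self):
        if self._done:
            return iter(self._vals)
        return itertools.chain(self._vals, self._fill())

    def _fill(self):
        for value in self._iter:
            self._vals.append(value)
            yield value
        self._done = True


def counting_cubes(counter, limit=1000):
    """Yield cubes, bumping counter[0] each time a value is produced."""
    for i in range(limit):
        counter[0] += 1
        yield i ** 3


def first_cube_past_100(cubes):
    for cube in cubes:
        if cube > 100:
            return cube
    raise ValueError("No cube > 100 in this iterable")


eager_count = [0]
eager_result = first_cube_past_100(list(counting_cubes(eager_count)))

lazy_count = [0]
lazy_result = first_cube_past_100(LazyCache(counting_cubes(lazy_count)))

print(eager_result, eager_count[0])  # 125 1000
print(lazy_result, lazy_count[0])    # 125 6
```

Building the list up front computes all 1000 cubes, while the as-you-go cache stops after the sixth, which is why it stays close to the raw generator in the timings.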
    