Question
A recent similar question (isinstance(foo, types.GeneratorType) or inspect.isgenerator(foo)?) got me curious about how to implement this generically.
It seems like a generally useful thing to have: a generator-type object that caches its items the first time through (like itertools.cycle), reports StopIteration, and then returns items from the cache on the next pass. If the object isn't a generator (e.g. a list or dict that inherently supports O(1) lookup), it should not cache, and should behave the same way on the original object.
Possibilities:
1) Modify itertools.cycle. It looks like this:
def cycle(iterable):
    saved = []
    try:
        saved.append(iterable.next())
        yield saved[-1]
        isiter = True
    except:
        saved = iterable
        isiter = False
    # cycle('ABCD') --> A B C D A B C D A B C D ...
    for element in iterable:
        yield element
        if isiter:
            saved.append(element)
    # ??? What next?
If I could restart the generator, that would be perfect: I could send back a StopIteration, and then on the next gen.next() return entry 0, i.e. `A B C D StopIteration A B C D StopIteration`, but it doesn't look like that's actually possible.
The second-best option would be to use saved as the cache once StopIteration is hit, but there doesn't seem to be any way to get at the generator's internal saved list. Maybe a class version of this?
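The restart limitation is easy to check. Here is a minimal Python 3 sketch (the code in this question is Python 2, so next(gen) stands in for gen.next(); letters is a hypothetical generator used only for illustration). Once a generator has raised StopIteration it keeps raising it, so replaying the values means either calling the generator function again or keeping a copy:

```python
def letters():
    # Hypothetical stand-in for any one-shot generator.
    for ch in "ABCD":
        yield ch

gen = letters()
first_pass = list(gen)    # consumes the generator completely
second_pass = list(gen)   # the same generator object yields nothing more

try:
    next(gen)
    restarted = True
except StopIteration:
    restarted = False     # an exhausted generator keeps raising StopIteration

fresh_pass = list(letters())  # calling the function again builds a new generator
```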
2) Or I could pass in the list directly:
def cycle(iterable, saved=[]):
    saved.clear()
    try:
        saved.append(iterable.next())
        yield saved[-1]
        isiter = True
    except:
        saved = iterable
        isiter = False
    # cycle('ABCD') --> A B C D A B C D A B C D ...
    for element in iterable:
        yield element
        if isiter:
            saved.append(element)

mysaved = []
myiter = cycle(someiter, mysaved)
But that just looks nasty. In C/C++ I could pass in a reference and change what saved refers to so that it points at iterable; you can't actually do that in Python, so this doesn't even work.
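To make the reference problem concrete, here is a small sketch (rebind and mutate are hypothetical helper names, not part of the code above). Assigning to a parameter only rebinds the local name, so the caller's list is untouched; mutating the object in place is what the caller can actually observe:

```python
def rebind(saved, iterable):
    # Assignment only changes which object the local name refers to.
    saved = iterable

def mutate(saved, iterable):
    # Slice assignment changes the caller's list object in place.
    saved[:] = iterable

mysaved = []
rebind(mysaved, "ABCD")
after_rebind = list(mysaved)   # still empty: the caller never sees the rebinding

mutate(mysaved, "ABCD")
after_mutate = list(mysaved)   # the caller's list now holds the items
```

So `saved = iterable` inside cycle() can never make mysaved point at the original iterable; only in-place changes such as append or slice assignment reach the caller.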
Other options?
Edit: More data. The CachingIterable method appears to be too slow to be effective, but it did push me in a direction that might work. It's slightly slower than the naive method (converting to a list myself), but appears not to take the hit if the input is already a list, tuple or dict.
Some code and data:
def cube_generator(max=100):
    i = 0
    while i < max:
        yield i*i*i
        i += 1
# Base case: use generator each time
%%timeit
cg = cube_generator(); [x for x in cg]
cg = cube_generator(); [x for x in cg]
cg = cube_generator(); [x for x in cg]
10000 loops, best of 3: 55.4 us per loop
# Fastest case: flatten to list, then iterate
%%timeit
cg = cube_generator()
cl = list(cg)
[x for x in cl]
[x for x in cl]
[x for x in cl]
10000 loops, best of 3: 27.4 us per loop
%%timeit
cg = cube_generator()
ci2 = CachingIterable(cg)
[x for x in ci2]
[x for x in ci2]
[x for x in ci2]
1000 loops, best of 3: 239 us per loop
# Another attempt, which is closer to the above
# Not exactly the original solution using next, but close enough i guess
class CacheGen(object):
    def __init__(self, iterable):
        if isinstance(iterable, (list, tuple, dict)):
            self._myiter = iterable
        else:
            self._myiter = list(iterable)
    def __iter__(self):
        return self._myiter.__iter__()
    def __contains__(self, key):
        return self._myiter.__contains__(key)
    def __getitem__(self, key):
        return self._myiter.__getitem__(key)
%%timeit
cg = cube_generator()
ci = CacheGen(cg)
[x for x in ci]
[x for x in ci]
[x for x in ci]
10000 loops, best of 3: 30.5 us per loop
# But if you start with a list, it is faster
cg = cube_generator()
cl = list(cg)
%%timeit
[x for x in cl]
[x for x in cl]
[x for x in cl]
100000 loops, best of 3: 11.6 us per loop
%%timeit
ci = CacheGen(cl)
[x for x in ci]
[x for x in ci]
[x for x in ci]
100000 loops, best of 3: 13.5 us per loop
Any faster recipes that can get closer to the 'pure' loop?
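For anyone repeating these measurements outside IPython, the %%timeit cells can be approximated with the standard-library timeit module. This is a minimal Python 3 sketch assuming the same cube_generator as above. The helper names regenerate_each_time and list_then_iterate are made up for illustration, and absolute numbers will differ by machine and Python version:

```python
import timeit

def cube_generator(max=100):
    i = 0
    while i < max:
        yield i * i * i
        i += 1

def regenerate_each_time():
    # Base case: build a new generator for each of the three passes.
    return [[x for x in cube_generator()] for _ in range(3)]

def list_then_iterate():
    # Materialize once, then iterate over the list three times.
    cl = list(cube_generator())
    return [[x for x in cl] for _ in range(3)]

if __name__ == "__main__":
    calls = 1000
    for fn in (regenerate_each_time, list_then_iterate):
        # timeit.repeat returns total seconds per batch of `calls` runs;
        # take the best batch, as %%timeit does.
        best = min(timeit.repeat(fn, number=calls, repeat=3))
        print("%s: %.1f us per call" % (fn.__name__, best / calls * 1e6))
```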
Answer 1:
Based on this comment:
my intention here is that this would only be used if the user knows he wants to iterate multiple times over the 'iterable', but doesn't know if the input is a generator or iterable. this lets you ignore that distinction, while not losing (much) performance.
This simple solution does exactly that:
def ensure_list(it):
    if isinstance(it, (list, tuple, dict)):
        return it
    else:
        return list(it)
Now ensure_list(a_list) is practically a no-op (two function calls), while ensure_list(a_generator) turns the generator into a list and returns it, which turned out to be faster than any other approach.
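One thing ensure_list assumes is that only list, tuple and dict are safe to iterate more than once. A possible variation, not part of the answer above and offered only as a sketch: an iterator returns itself from iter(), so an identity check can separate one-shot inputs from re-iterable containers without naming types. The name ensure_reiterable is hypothetical, and the snippet is written for Python 3:

```python
def ensure_reiterable(it):
    # Iterators (including generators) return themselves from iter();
    # containers such as list, tuple, dict, set or range return a new iterator.
    if iter(it) is it:
        return list(it)   # one-shot input: materialize it once
    return it             # already re-iterable: hand it back unchanged

a_list = [1, 2, 3]
same = ensure_reiterable(a_list)                    # returned as-is, no copy
cached = ensure_reiterable(x * x for x in a_list)   # generator becomes a list
a_range = range(3)
same_range = ensure_reiterable(a_range)             # range is re-iterable too
```

The trade-off is one extra iter() call per input, in exchange for not copying sets, ranges and other containers that the isinstance check would have converted to a list.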
Answer 2:
What you want is not an iterator, but an iterable. An iterator can only iterate once through its contents. You want something which takes an iterator and over which you can then iterate multiple times, producing the same values from the iterator, even if the iterator doesn't remember them, like a generator. Then it's just a matter of special-casing those inputs which don't need caching. Here's a non-thread-safe example (EDIT: updated for efficiency):
import itertools

class AsYouGoCachingIterable(object):
    def __init__(self, iterable):
        self.iterable = iterable
        self.iter = iter(iterable)
        self.done = False
        self.vals = []

    def __iter__(self):
        if self.done:
            return iter(self.vals)
        #chain vals so far & then gen the rest
        return itertools.chain(self.vals, self._gen_iter())

    def _gen_iter(self):
        #gen new vals, appending as it goes
        for new_val in self.iter:
            self.vals.append(new_val)
            yield new_val
        self.done = True
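To see the as-you-go behaviour in action, here is a self-contained Python 3 sketch. The class body repeats the logic above so the snippet runs on its own; counting_cubes and the pulled list are hypothetical helpers that record how many values the source generator actually produced:

```python
import itertools

class AsYouGoCachingIterable(object):
    # Same logic as the class above, repeated so the sketch runs on its own.
    def __init__(self, iterable):
        self.iter = iter(iterable)
        self.done = False
        self.vals = []

    def __iter__(self):
        if self.done:
            return iter(self.vals)
        # Replay what is cached so far, then keep pulling from the source.
        return itertools.chain(self.vals, self._gen_iter())

    def _gen_iter(self):
        for new_val in self.iter:
            self.vals.append(new_val)
            yield new_val
        self.done = True

pulled = []  # records every value the source actually produces

def counting_cubes(limit=10):
    # Hypothetical source generator that logs each value it yields.
    for i in range(limit):
        pulled.append(i ** 3)
        yield i ** 3

cubes = AsYouGoCachingIterable(counting_cubes())

first_over_100 = next(c for c in cubes if c > 100)  # stops early
pulled_after_search = len(pulled)                   # only part of the source was used

full_pass = list(cubes)    # replays the cache, then drains the rest of the source
second_pass = list(cubes)  # served entirely from the cache
```

Note that this only holds for passes that run one after another. Two iterators over the same instance that are advanced in an interleaved fashion can each miss values the other pulled, which is part of why the answer calls it non-thread-safe.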
And some timings:
class ListCachingIterable(object):
    def __init__(self, obj):
        self.vals = list(obj)

    def __iter__(self):
        return iter(self.vals)

def cube_generator(max=1000):
    i = 0
    while i < max:
        yield i*i*i
        i += 1

def runit(iterable_factory):
    for i in xrange(5):
        for what in iterable_factory():
            pass

def puregen():
    runit(lambda: cube_generator())

def listtheniter():
    res = list(cube_generator())
    runit(lambda: res)

def listcachingiterable():
    res = ListCachingIterable(cube_generator())
    runit(lambda: res)

def asyougocachingiterable():
    res = AsYouGoCachingIterable(cube_generator())
    runit(lambda: res)
Results are:
In [59]: %timeit puregen()
1000 loops, best of 3: 774 us per loop
In [60]: %timeit listtheniter()
1000 loops, best of 3: 345 us per loop
In [61]: %timeit listcachingiterable()
1000 loops, best of 3: 348 us per loop
In [62]: %timeit asyougocachingiterable()
1000 loops, best of 3: 630 us per loop
So the simplest approach in terms of a class, ListCachingIterable, works just about as well as doing the list manually. The "as-you-go" variant is almost twice as slow, but has advantages if you don't consume the entire list, e.g. say you're only looking for the first cube over 100:
def first_cube_past_100(cubes):
    for cube in cubes:
        if cube > 100:
            return cube
    # ValueError replaces the undefined name Error from the original snippet
    raise ValueError("No cube > 100 in this iterable")
Then:
In [76]: %timeit first_cube_past_100(cube_generator())
100000 loops, best of 3: 2.92 us per loop
In [77]: %timeit first_cube_past_100(ListCachingIterable(cube_generator()))
1000 loops, best of 3: 255 us per loop
In [78]: %timeit first_cube_past_100(AsYouGoCachingIterable(cube_generator()))
100000 loops, best of 3: 10.2 us per loop
Source: https://stackoverflow.com/questions/19503455/caching-a-generator