Here is a seemingly simple problem: given a list of iterators that yield sequences of integers in ascending order, write a concise generator that yields only the integers that appear in every sequence.
After reading a few papers last night, I decided to hack up a completely minimal full text indexer in Python, as seen here (though that version is quite old now).
My problem is with the search() function, which must iterate over each posting list and yield only the document IDs that appear on every list. As you can see from the link above, my current non-recursive 'working' attempt is terrible.
Example:
postings = [[1, 100, 142, 322, 12312],
            [2, 100, 101, 322, 1221],
            [100, 142, 322, 956, 1222]]
Should yield:
[100, 322]
There is at least one elegant recursive function solution to this, but I'd like to avoid that if possible. However, a solution involving nested generator expressions, itertools abuse, or any other kind of code golf is more than welcome. :-)
It should be possible to arrange for the function to require only as many steps as there are items in the smallest list, without sucking the entire set of integers into memory. In the future, these lists may be read from disk and may be larger than available RAM.
For the past 30 minutes I've had an idea on the tip of my tongue, but I can't quite get it into code. Remember, this is just for fun!
import heapq, itertools

def intersect(*its):
    for key, values in itertools.groupby(heapq.merge(*its)):
        if len(list(values)) == len(its):
            yield key

>>> list(intersect(*postings))
[100, 322]
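One caveat with the merge-and-count approach: it relies on each input holding a given value at most once. If a posting list could repeat an ID, the count for that ID could reach the number of lists even though one list lacks it. A minimal sketch of a guard, assuming duplicates are possible (the name intersect_unique is made up for this example):

```python
import heapq
from itertools import groupby

def intersect_unique(*its):
    # Collapse runs of equal values inside each input first, so that the
    # count below equals the number of inputs that contain the key.
    unique = [(key for key, _ in groupby(it)) for it in its]
    for key, values in groupby(heapq.merge(*unique)):
        if sum(1 for _ in values) == len(its):
            yield key

# 1 occurs three times overall but is missing from the third list.
print(list(intersect_unique([1, 1, 2], [1, 2, 2], [2, 3])))  # prints [2]
```

Everything stays lazy: both heapq.merge and groupby consume their inputs one item at a time.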
from functools import reduce  # a builtin on Python 2, must be imported on Python 3

def postings(posts):
    sets = (set(l) for l in posts)
    return sorted(reduce(set.intersection, sets))
... you could try to take advantage of the fact that the lists are ordered, but since reduce, generator expressions and set are all implemented in C, you'll probably have a hard time doing better than the above with logic implemented in Python.
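For comparison, here is a rough sketch of what "taking advantage of the ordering" could look like when the lists do not fit in memory: a lazy two-way intersection chained with reduce. The names intersect_two and intersect_sorted are invented for this example, and I have not benchmarked it against the set version.

```python
from functools import reduce

def intersect_two(a, b):
    """Lazily yield values common to two ascending iterables."""
    a, b = iter(a), iter(b)
    try:
        x, y = next(a), next(b)
        while True:
            if x < y:
                x = next(a)
            elif x > y:
                y = next(b)
            else:
                yield x
                x, y = next(a), next(b)
    except StopIteration:
        return  # one input ran out, so no further matches are possible

def intersect_sorted(posts):
    """Chain the two-way intersection across all posting lists."""
    # Note: reduce raises TypeError if posts is empty.
    return reduce(intersect_two, posts)

postings = [[1, 100, 142, 322, 12312],
            [2, 100, 101, 322, 1221],
            [100, 142, 322, 956, 1222]]
print(list(intersect_sorted(postings)))  # prints [100, 322]
```

Only one value per input is held at a time, at the cost of doing the comparisons in Python rather than in C.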
The following solution computes the intersection of your iterators. It works by advancing the iterators one step at a time and looking for the same value in all of them. When found, such values are yielded, which makes the intersect function a generator itself.
def intersect(sequences):
    """Compute intersection of sequences of increasing integers.

    >>> list(intersect([[1, 100, 142, 322, 12312],
    ...                 [2, 100, 101, 322, 1221],
    ...                 [100, 142, 322, 956, 1222]]))
    [100, 322]
    """
    iterators = [iter(seq) for seq in sequences]
    try:
        last = [next(iterator) for iterator in iterators]
        indices = range(len(iterators) - 1)
        while True:
            # The while loop stops when any iterator raises StopIteration.
            # It is caught below because, since Python 3.7 (PEP 479), it
            # may no longer escape from a generator.
            if all(l == last[0] for l in last):
                # All iterators contain last[0]
                yield last[0]
                last = [next(iterator) for iterator in iterators]
            # Now go over the iterators once and advance them as
            # necessary. To stop as soon as the smallest iterator is
            # exhausted we advance each iterator only once per iteration
            # in the while loop.
            for i in indices:
                if last[i] < last[i+1]:
                    last[i] = next(iterators[i])
                if last[i] > last[i+1]:
                    last[i+1] = next(iterators[i+1])
    except StopIteration:
        return
If these are really long (or even infinite) sequences, and you don't want to load everything into a set in advance, you can implement this with a 1-item lookahead on each iterator.
EndOfIter = object()  # Sentinel value

class PeekableIterator(object):
    def __init__(self, it):
        self.it = it
        self._peek = None
        next(self)  # pump iterator to get first value

    def __iter__(self):
        return self

    def __next__(self):  # Python 3 spelling of the old next() method
        cur = self._peek
        if cur is EndOfIter:
            raise StopIteration()
        try:
            self._peek = next(self.it)
        except StopIteration:
            self._peek = EndOfIter
        return cur

    def peek(self):
        return self._peek

def contained_in_all(seqs):
    if not seqs: return  # No items
    iterators = [PeekableIterator(iter(seq)) for seq in seqs]
    first, rest = iterators[0], iterators[1:]
    for item in first:
        candidates = list(rest)
        while candidates:
            if any(c.peek() is EndOfIter for c in candidates): return  # Exhausted an iterator
            candidates = [c for c in candidates if c.peek() < item]
            for c in candidates: next(c)
        # Out of loop once the next item of every remaining iterator is >= item.
        if all(it.peek() == item for it in rest):
            yield item
Usage:
>>> print(list(contained_in_all(postings)))
[100, 322]
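To illustrate the point about infinite sequences, here is a compact, self-contained sketch of the same lookahead idea that keeps one buffered head value per iterator in a plain list instead of a wrapper class. The function name common_values and the test data are invented for this example.

```python
from itertools import count, islice

def common_values(iterables):
    """Yield values present in every ascending iterable, lazily."""
    its = [iter(it) for it in iterables]
    if not its:
        return
    try:
        heads = [next(it) for it in its]  # one-item lookahead per iterator
        while True:
            target = max(heads)
            for i, it in enumerate(its):
                # Advance any iterator that is behind the current maximum.
                while heads[i] < target:
                    heads[i] = next(it)
            if all(h == target for h in heads):
                yield target
                heads = [next(it) for it in its]
    except StopIteration:
        return  # some iterator is exhausted, nothing more can match

# Three infinite streams: multiples of 2, 3 and 5.
multiples = [count(0, 2), count(0, 3), count(0, 5)]
print(list(islice(common_values(multiples), 4)))  # prints [0, 30, 60, 90]
```

Because nothing is ever collected into a set, islice can stop the computation after as many results as you need.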
What about this:
import heapq

def inalliters(iterators):
    try:
        # The index i breaks ties between equal values, so the heap never
        # has to compare two iterator objects (a TypeError on Python 3).
        heap = [(next(it), i, it) for i, it in enumerate(iterators)]
        if not heap:
            return
        heapq.heapify(heap)
        maximal = max(heap)[0]
        while True:
            value, i, it = heapq.heappop(heap)
            if maximal == value:
                yield value
            nextvalue = next(it)
            heapq.heappush(heap, (nextvalue, i, it))
            maximal = max(maximal, nextvalue)
    except StopIteration:
        return  # an exhausted iterator ends the intersection

postings = [iter([1, 100, 142, 322, 12312]),
            iter([2, 100, 101, 322, 1221]),
            iter([100, 142, 322, 956, 1222])]
print(list(inalliters(postings)))
I haven't tested it very thoroughly (just ran your example), but I believe the basic idea is sound.
I want to show that there's an elegant solution that only iterates forward once. Sorry, I don't know Python well enough, so I use fictional classes. This one reads input, an array of iterators, and writes to output on the fly without ever going back or using any array function.
def intersect (input, output)
    do:
        min = input[0]
        bingo = True
        for i in input:
            if (i.cur != min.cur):
                bingo = False    # not all iterators are at the same value
            if (i.cur < min.cur):
                min = i
        if bingo:
            output.push(min.cur)
    while (min.step())
This one runs in O(n*m), where n is the sum of all iterator lengths and m is the number of lists. It can be made O(n*log m) by using a heap to find the minimum instead of the linear scan in the for loop.
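For anyone who wants to run it, here is one possible Python rendering of the pseudocode above. The Cursor class is a hypothetical stand-in for the fictional iterator class (with a cur attribute and a step() method), and a plain list plays the role of output.

```python
class Cursor:
    """Hypothetical stand-in for the fictional iterator class."""
    def __init__(self, iterable):
        self._it = iter(iterable)
        self.cur = None
        self.ok = self.step()  # load the first value

    def step(self):
        """Advance to the next value; return False when exhausted."""
        try:
            self.cur = next(self._it)
            return True
        except StopIteration:
            return False

def intersect(inputs, output):
    cursors = [Cursor(seq) for seq in inputs]
    if not cursors or not all(c.ok for c in cursors):
        return  # no inputs, or at least one input is empty
    while True:
        lowest = min(cursors, key=lambda c: c.cur)  # O(m) linear scan
        if all(c.cur == lowest.cur for c in cursors):
            output.append(lowest.cur)
        if not lowest.step():  # emulates the do ... while loop
            break

postings = [[1, 100, 142, 322, 12312],
            [2, 100, 101, 322, 1221],
            [100, 142, 322, 956, 1222]]
result = []
intersect(postings, result)
print(result)  # prints [100, 322]
```

The min() call is the linear scan that a heap would replace to get the O(n*log m) bound.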
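def intersection(its):
    if not its: return
    try:
        vs = [next(it) for it in its]
        m = max(vs)
        while True:
            v, i = min((v, i) for i, v in enumerate(vs))
            if v == m:
                yield m
            vs[i] = next(its[i])
            m = max(m, vs[i])
    except StopIteration:
        # An exhausted iterator ends the intersection. Since Python 3.7
        # (PEP 479) StopIteration may not escape from a generator.
        return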
Source: https://stackoverflow.com/questions/969709/joining-a-set-of-ordered-integer-yielding-python-iterators