Given a list of arbitrarily deep nested lists of arbitrary size, I would like a flat, depth-first iterator over all elements in the tree, but with path indices as well such that…
I think your own solution is alright, that there's nothing really simpler, and that Python's standard library stuff wouldn't help. But here's another way anyway, which works iteratively instead of recursively so it can handle very deeply nested lists.
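```python
def flatten(l):
    stack = [enumerate(l)]  # iterators of the lists currently being walked
    path = [None]           # index path; the last slot is filled by the for-loop
    while stack:
        for path[-1], x in stack[-1]:
            if isinstance(x, list):
                # descend: walk the sublist next, reserve a slot for its index
                stack.append(enumerate(x))
                path.append(None)
            else:
                yield x, tuple(path)
            break  # handle only one element per while-iteration
        else:
            # current list is exhausted: go back up one level
            stack.pop()
            path.pop()
```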
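I keep the currently "active" lists on a stack of `enumerate` iterators, and the current index path as another stack. Then in a while-loop I always try to take the next element from the current list and handle it appropriately:

- If it's a list, push its `enumerate` iterator on the stack and make room for the deeper index on the index path stack.
- If it's anything else, yield it along with its index path.
- If the current list has no next element, pop its iterator off the stack and drop its index from the index path stack.

Demo: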
```python
>>> L = [[[1, 2, 3], [4, 5]], [6], [7,[8,9]], 10]
>>> for entry in flatten(L):
...     print(entry)
...
(1, (0, 0, 0))
(2, (0, 0, 1))
(3, (0, 0, 2))
(4, (0, 1, 0))
(5, (0, 1, 1))
(6, (1, 0))
(7, (2, 0))
(8, (2, 1, 0))
(9, (2, 1, 1))
(10, (3,))
```
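For comparison, here is a sketch of the same traversal written the more conventional way, advancing each iterator with `next()` and a sentinel default instead of the `for`/`break`/`else` trick. The name `flatten_with_next` and the `done` sentinel are hypothetical, chosen only for illustration; the behavior should match `flatten` above.

```python
def flatten_with_next(l):
    """Same depth-first traversal as flatten(), but advancing each
    enumerate iterator with next() and a sentinel default."""
    done = object()  # sentinel meaning "this iterator is exhausted"
    stack = [enumerate(l)]
    path = [None]
    while stack:
        item = next(stack[-1], done)
        if item is done:
            # current list finished: drop its iterator and its index slot
            stack.pop()
            path.pop()
            continue
        path[-1], x = item
        if isinstance(x, list):
            # descend: push the sublist's iterator, reserve a deeper index slot
            stack.append(enumerate(x))
            path.append(None)
        else:
            yield x, tuple(path)


L = [[[1, 2, 3], [4, 5]], [6], [7, [8, 9]], 10]
for entry in flatten_with_next(L):
    print(entry)
```

It needs an explicit sentinel check and a `continue`, which is the bookkeeping the `for`/`break`/`else` version avoids.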
Note that if you process the entries on the fly, like printing does, then you could just yield the path as the list it is, i.e., use `yield x, path`. Demo:
```python
>>> for entry in flatten(L):
...     print(entry)
...
(1, [0, 0, 0])
(2, [0, 0, 1])
(3, [0, 0, 2])
(4, [0, 1, 0])
(5, [0, 1, 1])
(6, [1, 0])
(7, [2, 0])
(8, [2, 1, 0])
(9, [2, 1, 1])
(10, [3])
```
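A sketch of that list-yielding variant, under the hypothetical name `flatten_live`, with the only change being the `yield` line. It also shows what "process the entries on the fly" means in practice: copying each path before advancing works, while collecting the pairs with `list()` keeps many references to the one shared path object.

```python
def flatten_live(l):
    """Variant of flatten() that yields the one shared `path` list itself
    instead of a tuple copy. Each yielded path is only valid until the
    generator is advanced again."""
    stack = [enumerate(l)]
    path = [None]
    while stack:
        for path[-1], x in stack[-1]:
            if isinstance(x, list):
                stack.append(enumerate(x))
                path.append(None)
            else:
                yield x, path  # no copy: the same list object every time
            break
        else:
            stack.pop()
            path.pop()


L = [[[1, 2, 3], [4, 5]], [6], [7, [8, 9]], 10]

# Fine: each path is used (here, copied) before the generator moves on.
snapshots = [(x, list(path)) for x, path in flatten_live(L)]

# Stale: list() stores references to the same path object, which is
# empty once the traversal has finished.
stale = list(flatten_live(L))
print(snapshots)
print(stale)
```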
This way, the iterator only takes O(n) time for the whole iteration, where n is the total number of objects in the structure (both lists and numbers). Of course, printing increases the complexity, just like creating the tuples does, since each costs O(d) for a path of length d. But that's then outside of the generator and the "fault" of the printing or whatever you're doing with each path. If, for example, you only look at each path's length instead of its contents, which takes O(1), then the whole thing actually is O(n).
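One way to see the O(n) bound is to count while-loop iterations, since each one does only amortized O(1) work (one `next` on an `enumerate`, plus at most one push or pop). The sketch below instruments the list-yielding version with a hypothetical `steps` counter and looks only at `len(path)`, as suggested, to get the maximum depth.

```python
def flatten_counted(l, steps):
    """flatten() with `yield x, path`, instrumented to count while-loop
    iterations in steps[0] (hypothetical helper for illustration)."""
    stack = [enumerate(l)]
    path = [None]
    while stack:
        steps[0] += 1  # each iteration: one next(), at most one push or pop
        for path[-1], x in stack[-1]:
            if isinstance(x, list):
                stack.append(enumerate(x))
                path.append(None)
            else:
                yield x, path
            break
        else:
            stack.pop()
            path.pop()


L = [[[1, 2, 3], [4, 5]], [6], [7, [8, 9]], 10]
steps = [0]
# len(path) is O(1), so computing the maximum depth this way is O(n) overall.
max_depth = max(len(path) for _, path in flatten_counted(L, steps))
print(max_depth, steps[0])
```

For this `L`, `steps[0]` ends at 23: one step per number (10), two per inner list (push and pop, 6 lists), and one final step to pop the outer list. `max_depth` is 3.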
All that said, again, I think your own solution is alright. And clearly simpler than this. And like I commented under @naomik's answer, I think your solution not being able to handle lists of depth around 1000 or more is rather irrelevant. One should not even have such a list in the first place. If one does, that's a mistake that should be fixed instead. If the list can also go wide, as in your case, and is balanced, then even with a branching factor of just 2 you'd run out of memory at a depth well under 100 (a balanced binary tree of depth d has 2^d leaves), so you wouldn't get anywhere near 1000. If the list can't go wide, then nested lists are the wrong data structure, plus you wouldn't be interested in the index path in the first place. If it can go wide but doesn't, then I'd say the creation algorithm should be improved (for example, if it represents a sorted tree, add balancing).
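To make the depth limitation concrete, here is a sketch that builds a chain of single-element lists 10,000 levels deep, well past CPython's default recursion limit of 1000, and runs both the iterative `flatten` from above and a straightforward recursive generator over it. `flatten_recursive` is a hypothetical stand-in for a recursive solution, not the asker's actual code.

```python
import sys


def flatten(l):
    """Iterative version from above (tuple paths)."""
    stack = [enumerate(l)]
    path = [None]
    while stack:
        for path[-1], x in stack[-1]:
            if isinstance(x, list):
                stack.append(enumerate(x))
                path.append(None)
            else:
                yield x, tuple(path)
            break
        else:
            stack.pop()
            path.pop()


def flatten_recursive(l, path=()):
    """Hypothetical plain recursive version, for comparison only."""
    for i, x in enumerate(l):
        if isinstance(x, list):
            yield from flatten_recursive(x, path + (i,))
        else:
            yield x, path + (i,)


depth = 10_000  # far beyond the default recursion limit (usually 1000)
deep = 42
for _ in range(depth):
    deep = [deep]  # [[[...[42]...]]], nested `depth` levels

iterative_result = list(flatten(deep))  # works: only heap-allocated stacks

try:
    list(flatten_recursive(deep))
    recursive_failed = False
except RecursionError:
    recursive_failed = True  # one Python frame per nesting level

print(sys.getrecursionlimit(), recursive_failed, len(iterative_result[0][1]))
```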
About my solution again: Besides its ability to handle arbitrarily deep lists and its efficiency, I find some of its details interesting to note:
- The `enumerate` objects being stored somewhere. Usually they're just used in loops & co. directly, like `for i, x in enumerate(l):`.
- Having the `path[-1]` spot ready and writing into it with `for path[-1], x in ...`.
- Using a `for`-loop with an immediate `break` and an `else` branch to iterate over the next single value and handle ends gracefully, without `try`/`except` and without `next` and some default.
- If you use `yield x, path`, i.e., don't turn each path into a tuple, then you really need to process each path directly during the iteration. For example, if you do `list(flatten(L))`, then you get `[(1, []), (2, []), (3, []), (4, []), (5, []), (6, []), (7, []), (8, []), (9, []), (10, [])]`. That is, "all" index paths will be empty. Of course that's because there really is only one path object, which I update and yield over and over again, and in the end it's empty. This is very similar to `itertools.groupby`, where for example `[list(g) for _, g in list(groupby('aaabbbb'))]` gives you `[[], ['b']]`. And it's not a bad thing. I recently wrote about that extensively.

Shorter version with one stack holding both indexes and `enumerate` objects alternatingly:
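```python
def flatten(l):
    # One stack alternating index slots and enumerate iterators:
    # [index, enumerate, index, enumerate, ...]
    stack = [None, enumerate(l)]
    while stack:
        for stack[-2], x in stack[-1]:
            if isinstance(x, list):
                stack += None, enumerate(x)  # push a new index slot and iterator
            else:
                yield x, stack[::2]  # the index slots form the path
            break
        else:
            del stack[-2:]  # list exhausted: drop its slot and iterator
```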
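A quick way to gain confidence that the two versions agree is a randomized cross-check. The sketch below restates both under the hypothetical names `flatten_two_stacks` and `flatten_one_stack`, generates random nested lists (including empty ones) with a hypothetical helper `random_nested`, and checks that the paths match and that following each path leads back to its element.

```python
import random


def flatten_two_stacks(l):
    """First version: separate stacks for iterators and the index path."""
    stack = [enumerate(l)]
    path = [None]
    while stack:
        for path[-1], x in stack[-1]:
            if isinstance(x, list):
                stack.append(enumerate(x))
                path.append(None)
            else:
                yield x, tuple(path)
            break
        else:
            stack.pop()
            path.pop()


def flatten_one_stack(l):
    """Shorter version: one stack alternating index, enumerate, ..."""
    stack = [None, enumerate(l)]
    while stack:
        for stack[-2], x in stack[-1]:
            if isinstance(x, list):
                stack += None, enumerate(x)
            else:
                yield x, stack[::2]  # slicing builds a fresh list each time
            break
        else:
            del stack[-2:]


def random_nested(rng, depth=0, max_depth=6):
    """Hypothetical helper: random nested list of ints, possibly empty."""
    result = []
    for _ in range(rng.randint(0, 4)):
        if depth < max_depth and rng.random() < 0.4:
            result.append(random_nested(rng, depth + 1, max_depth))
        else:
            result.append(rng.randint(0, 99))
    return result


def lookup(tree, path):
    """Follow an index path down into the nested list."""
    for i in path:
        tree = tree[i]
    return tree


rng = random.Random(0)  # fixed seed so the check is reproducible
for _ in range(1000):
    tree = random_nested(rng)
    short = [(x, tuple(p)) for x, p in flatten_one_stack(tree)]
    assert short == list(flatten_two_stacks(tree))
    assert all(lookup(tree, p) == x for x, p in short)
print("both versions agree")
```

Because `stack[::2]` creates a new list on every `yield`, the shorter version can be collected with `list()` directly without the shared-path effect described above, at the cost of O(d) per element for the slice.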