Given a list of arbitrarily deep nested lists of arbitrary size, I would like a flat, depth-first iterator over all elements in the tree, but with path indices as well such th
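For reference, the demos on this page run against the following sample input, copied from the iterative answer's demo. The expected pairs are the ones every answer prints. The small loop at the end only illustrates what an index path means: following its indices from the outer list leads to the leaf.

```python
# Sample input used by the demos below (copied from the iterative answer's demo)
L = [[[1, 2, 3], [4, 5]], [6], [7, [8, 9]], 10]

# Desired output: each leaf, in depth-first order, paired with its index path
expected = [
    (1, (0, 0, 0)), (2, (0, 0, 1)), (3, (0, 0, 2)),
    (4, (0, 1, 0)), (5, (0, 1, 1)), (6, (1, 0)),
    (7, (2, 0)), (8, (2, 1, 0)), (9, (2, 1, 1)), (10, (3,)),
]

# An index path locates its leaf: follow the indices one nesting level at a time
for value, path in expected:
    node = L
    for i in path:
        node = node[i]
    assert node == value
```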
Starting with direct recursion and state variables with default values:

def flatten (l, i = 0, path = (), acc = []):
  # acc is never mutated in place (acc + [...] builds a new list), so the [] default is safe
  if not l:
    return acc
  else:
    first, *rest = l
    if isinstance (first, list):
      return flatten (first, 0, path + (i,), acc) + flatten (rest, i + 1, path, [])
    else:
      return flatten (rest, i + 1, path, acc + [ (first, path + (i,)) ])

print (flatten (L))
# [ (1, (0, 0, 0))
# , (2, (0, 0, 1))
# , (3, (0, 0, 2))
# , (4, (0, 1, 0))
# , (5, (0, 1, 1))
# , (6, (1, 0))
# , (7, (2, 0))
# , (8, (2, 1, 0))
# , (9, (2, 1, 1))
# , (10, (3,))
# ]
The program above shares the same weakness as yours: it is not safe for deeply nested lists. We can use continuation-passing style to make it tail recursive:
def identity (x):
  return x

# tail-recursive, but still not stack-safe, yet
def flatten (l, i = 0, path = (), acc = [], cont = identity):
  if not l:
    return cont (acc)
  else:
    first, *rest = l
    if isinstance (first, list):
      return flatten (first, 0, path + (i,), acc, lambda left:
        flatten (rest, i + 1, path, [], lambda right:
          cont (left + right)))
    else:
      return flatten (rest, i + 1, path, acc + [ (first, path + (i,)) ], cont)

print (flatten (L))
# [ (1, (0, 0, 0))
# , (2, (0, 0, 1))
# , (3, (0, 0, 2))
# , (4, (0, 1, 0))
# , (5, (0, 1, 1))
# , (6, (1, 0))
# , (7, (2, 0))
# , (8, (2, 1, 0))
# , (9, (2, 1, 1))
# , (10, (3,))
# ]
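To see the continuation-passing idea on its own, here is a minimal sketch on a much simpler function than flatten. The names sum_direct and sum_cps are illustrative and not part of the answer. In the CPS version the pending "first + ..." work moves into the callback, so the recursive call is the last thing each frame does. Like the flatten above, it is tail recursive but still not stack-safe, because CPython does not eliminate tail calls.

```python
def identity(x):
    return x

# Direct recursion: the addition happens *after* the recursive call returns
def sum_direct(xs):
    if not xs:
        return 0
    first, *rest = xs
    return first + sum_direct(rest)

# Continuation-passing style: the pending "first + total" is moved into a
# callback, so the recursive call is in tail position
def sum_cps(xs, cont=identity):
    if not xs:
        return cont(0)
    first, *rest = xs
    return sum_cps(rest, lambda total: cont(first + total))

print(sum_direct([1, 2, 3, 4]))  # 10
print(sum_cps([1, 2, 3, 4]))     # 10
```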
Finally, we replace the recursive calls with our own call mechanism (the call class below). This effectively sequences the recursive calls in a loop, so now it works for data of any size and any level of nesting. This technique is called a trampoline:
def identity (x):
  return x

def flatten (l):
  def loop (l, i = 0, path = (), acc = [], cont = identity):
    if not l:
      return cont (acc)
    else:
      first, *rest = l
      if isinstance (first, list):
        # return deferred calls instead of recursing, so the stack never grows
        return call (loop, first, 0, path + (i,), acc, lambda left:
          call (loop, rest, i + 1, path, [], lambda right:
            cont (left + right)))
      else:
        return call (loop, rest, i + 1, path, acc + [ (first, path + (i,)) ], cont)
  # start from a call object so an empty input also works (a plain [] has no .run)
  return call (loop, l) .run ()

class call:
  # a deferred function call: f applied to xs, executed later by run
  def __init__ (self, f, *xs):
    self.f = f
    self.xs = xs
  def run (self):
    # keep evaluating until the result is no longer a call object
    acc = self
    while (isinstance (acc, call)):
      acc = acc.f (*acc.xs)
    return acc

print (flatten (L))
# [ (1, (0, 0, 0))
# , (2, (0, 0, 1))
# , (3, (0, 0, 2))
# , (4, (0, 1, 0))
# , (5, (0, 1, 1))
# , (6, (1, 0))
# , (7, (2, 0))
# , (8, (2, 1, 0))
# , (9, (2, 1, 1))
# , (10, (3,))
# ]
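The trampoline can also be seen in isolation. The sketch below reuses the shape of the call class above on a hypothetical pair of mutually recursive functions, is_even and is_odd (not part of the answer). Each step returns a call object instead of calling the other function directly, so the run loop does all the work and a chain of 100000 steps never comes near Python's recursion limit.

```python
class call:
    """A deferred function call: the trampoline's unit of work."""
    def __init__(self, f, *xs):
        self.f = f
        self.xs = xs

    def run(self):
        # Keep executing deferred calls until a plain value comes back
        acc = self
        while isinstance(acc, call):
            acc = acc.f(*acc.xs)
        return acc

def is_even(n):
    # Instead of calling is_odd directly (growing the stack), return a call object
    return True if n == 0 else call(is_odd, n - 1)

def is_odd(n):
    return False if n == 0 else call(is_even, n - 1)

print(call(is_even, 100000).run())  # True
print(call(is_odd, 7).run())        # True
```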
Why is it better? Objectively speaking, it's a more complete program. Just because it appears more complex doesn't mean it is less efficient.
The code provided in the question fails when the input list is nested more than 996 levels deep (in Python 3.x):
depth = 1000
L = [1]
while (depth > 0):
  L = [L]
  depth = depth - 1

for x in flatten (L):
  print (x)

# Bug in the question's code:
# the first value in the tuple is not completely flattened
# ([[[[[1]]]]], (0, 0, 0, ... ))
Worse, when the depth increases to around 2000, the code provided in the question generates a runtime error, GeneratorExitException.
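The question's own code isn't reproduced here, but the failure mode is easy to reproduce with a hypothetical plain recursive generator (naive_flatten below, written only for illustration). Every nesting level adds another generator frame, so CPython's recursion limit (1000 by default, as reported by sys.getrecursionlimit()) is reached long before memory runs out. The exact error message, and the depth at which it first appears, may vary between Python versions.

```python
import sys

def naive_flatten(l, path=()):
    # Plain recursive generator: one extra generator frame per nesting level
    for i, x in enumerate(l):
        if isinstance(x, list):
            yield from naive_flatten(x, path + (i,))
        else:
            yield x, path + (i,)

def nested(depth):
    # Wrap [1] in `depth` extra lists, like the examples above
    L = [1]
    for _ in range(depth):
        L = [L]
    return L

print(sys.getrecursionlimit())
print(list(naive_flatten(nested(3))))  # [(1, (0, 0, 0, 0))]

try:
    list(naive_flatten(nested(5000)))
    hit_limit = False
except RecursionError:
    hit_limit = True
print(hit_limit)
```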
My program, on the other hand, works for inputs of any size, nested to any depth, and always produces the correct output:
depth = 50000
L = [1]
while (depth > 0):
  L = [L]
  depth = depth - 1

print (flatten (L))
# [ (1, (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 49991 more...)) ]

print (flatten (range (50000)))
# [ (0, (0,))
# , (1, (1,))
# , (2, (2,))
# , ...
# , (49999, (49999,))
# ]
Who would have such a deep list anyway? One common case is the linked list, which creates deep, tree-like structures:
my_list = [ 1, [ 2, [ 3, [ 4, None ] ] ] ]
Such a structure is common because the outermost pair gives us easy access to the two semantic parts we care about: the first item, and the rest of the items. The linked list could be implemented using a tuple or a dict as well:
my_list = ( 1, ( 2, ( 3, ( 4, None ) ) ) )

my_list = { "first": 1
          , "rest": { "first": 2
                    , "rest": { "first": 3
                              , "rest": { "first": 4
                                        , "rest": None
                                        }
                              }
                    }
          }
Above, we can see that a sensible structure can easily create significant depth. In Python, [], (), and {} allow you to nest to any depth. Why should our generic flatten restrict that freedom?
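To make the depth concrete, here is a small sketch that builds the nested-pair linked list shown above from an ordinary Python list. The helpers to_linked and pair_depth are hypothetical, not from the answer, and both are written as loops so they don't hit the recursion limit themselves. The nesting depth grows by one per element, so a 5000-item linked list is already far deeper than the 996 levels the question's code can handle.

```python
def to_linked(xs):
    # Build [x0, [x1, [x2, ... [xn, None]]]] from the back, without recursion
    node = None
    for x in reversed(xs):
        node = [x, node]
    return node

def pair_depth(node):
    # Walk the "rest" slot iteratively, counting one level per pair
    depth = 0
    while node is not None:
        depth += 1
        node = node[1]
    return depth

print(to_linked([1, 2, 3, 4]))             # [1, [2, [3, [4, None]]]]
print(pair_depth(to_linked(range(5000))))  # 5000
```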
In my opinion, if you're going to design a generic function like flatten, you should choose the implementation that works in the most cases and has the fewest surprises. One that suddenly fails just because a certain (deep) structure is used is bad. The flatten used in my answer is not the fastest[1], but it doesn't surprise the programmer with strange answers or program crashes.
[1] I don't measure performance until it matters, so I haven't done anything to tune flatten above. Another understated advantage of my program is that you can tune it because we wrote it. On the other hand, if for, enumerate, and yield caused problems in your program, what would you do to "fix" it? How would we make it faster? How would we make it work for inputs of greater size or depth? What good is a Ferrari after it has wrapped itself around a tree?
I think your own solution is alright, that there's nothing really simpler, and that Python's standard library stuff wouldn't help. But here's another way anyway, which works iteratively instead of recursively so it can handle very deeply nested lists.
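def flatten(l):
    stack = [enumerate(l)]   # iterators over the currently "active" lists
    path = [None]            # current index path; the last slot is filled by the for
    while stack:
        # take at most one element from the innermost active list
        for path[-1], x in stack[-1]:
            if isinstance(x, list):
                stack.append(enumerate(x))
                path.append(None)
            else:
                yield x, tuple(path)
            break
        else:
            # the innermost list is exhausted: drop it and its index slot
            stack.pop()
            path.pop()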
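I keep the currently "active" lists on a stack of enumerate iterators, and the current index path as another stack. Then in a while-loop I always try to take the next element from the current list and handle it appropriately:

- If that next element is a list, then push its enumerate iterator on the stack and make room for the deeper index on the index path stack.
- If it is anything else (a number here), then yield it along with its path.
- If there is no next element in the current list, then remove its iterator and its index spot from the stacks.

Demo: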
>>> L = [[[1, 2, 3], [4, 5]], [6], [7, [8, 9]], 10]
>>> for entry in flatten(L):
...     print(entry)
...
(1, (0, 0, 0))
(2, (0, 0, 1))
(3, (0, 0, 2))
(4, (0, 1, 0))
(5, (0, 1, 1))
(6, (1, 0))
(7, (2, 0))
(8, (2, 1, 0))
(9, (2, 1, 1))
(10, (3,))
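To watch the two stacks at work, here is an instrumented copy of the generator above. The name flatten_traced is hypothetical, and the loop is unchanged apart from the recording: instead of yielding, it records every push, yield, and pop together with the index path at that moment, which makes the push/yield/pop cases easy to follow on a tiny input.

```python
def flatten_traced(l):
    # Same loop as the iterative answer, but records each step and the path
    stack = [enumerate(l)]
    path = [None]
    trace = []
    while stack:
        for path[-1], x in stack[-1]:
            if isinstance(x, list):
                stack.append(enumerate(x))
                path.append(None)
                trace.append(("push", list(path)))
            else:
                trace.append(("yield", x, tuple(path)))
            break
        else:
            stack.pop()
            path.pop()
            trace.append(("pop", list(path)))
    return trace

for step in flatten_traced([[1, 2], 3]):
    print(step)
# ('push', [0, None])
# ('yield', 1, (0, 0))
# ('yield', 2, (0, 1))
# ('pop', [0])
# ('yield', 3, (1,))
# ('pop', [])
```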
Note that if you process the entries on the fly, like printing does, then you could just yield the path as the list it is, i.e., use yield x, path. Demo:

>>> for entry in flatten(L):
...     print(entry)
...
(1, [0, 0, 0])
(2, [0, 0, 1])
(3, [0, 0, 2])
(4, [0, 1, 0])
(5, [0, 1, 1])
(6, [1, 0])
(7, [2, 0])
(8, [2, 1, 0])
(9, [2, 1, 1])
(10, [3])
This way, the iterator only takes O(n) time for the whole iteration, where n is the total number of objects in the structure (both lists and numbers). Of course the printing increases the complexity, just like creating the tuples does. But that's then outside of the generator and the "fault" of the printing or whatever you're doing with each path. If, for example, you only look at each path's length instead of its contents, which takes O(1), then the whole thing actually is O(n).
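As an illustration of the O(1)-per-path case, here is a sketch built on the list-yielding variant (called flatten_shared here, a hypothetical name). The consumer only inspects len(path), so no path is ever copied and each yielded entry costs constant work. Because the length is read before the generator resumes, the shared list is safe to use this way.

```python
def flatten_shared(l):
    # Iterative answer's loop, yielding the one shared path list (no tuple copy)
    stack = [enumerate(l)]
    path = [None]
    while stack:
        for path[-1], x in stack[-1]:
            if isinstance(x, list):
                stack.append(enumerate(x))
                path.append(None)
            else:
                yield x, path
            break
        else:
            stack.pop()
            path.pop()

L = [[[1, 2, 3], [4, 5]], [6], [7, [8, 9]], 10]

# Only len(path), an O(1) operation, is used for each entry
depths = [len(path) for _, path in flatten_shared(L)]
print(depths)       # [3, 3, 3, 3, 3, 2, 2, 3, 3, 1]
print(max(depths))  # 3
```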
All that said, again, I think your own solution is alright, and clearly simpler than this. And like I commented under @naomik's answer, I think your solution not being able to handle lists of depth around 1000 or more is rather irrelevant. One should not even have such a list in the first place. If one does, that's a mistake that should be fixed instead. If the list can also go wide, as in your case, and is balanced, then even with a branching factor of just 2 you'd run out of memory at a depth well under 100 and you wouldn't get anywhere near 1000. If the list can't go wide, then nested lists are the wrong data structure choice, plus you wouldn't be interested in the index path in the first place. If it can go wide but doesn't, then I'd say the creation algorithm should be improved (for example, if it represents a sorted tree, add balancing).
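The memory argument can be checked with rough arithmetic. The sketch below assumes 64-bit CPython, where each list element is stored as an 8-byte pointer, and counts only those pointer slots for the leaves of a full binary tree (min_leaf_bytes is a hypothetical helper). Inner lists, object headers, and the leaf objects themselves are ignored, so the real footprint would be larger still.

```python
def min_leaf_bytes(depth, branching=2, slot_bytes=8):
    # A full tree of this depth has branching**depth leaves; each leaf occupies
    # at least one pointer slot in its parent list (8 bytes on 64-bit CPython)
    return branching ** depth * slot_bytes

TIB = 2 ** 40  # bytes in one tebibyte
for depth in (30, 40, 50, 100):
    print(depth, f"{min_leaf_bytes(depth) / TIB:.3g} TiB")
```

At depth 40 this lower bound is already 8 TiB of pointers alone, and every further level doubles it.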
About my solution again: besides its ability to handle arbitrarily deep lists and its efficiency, I find some of its details interesting to note:

- The enumerate objects being stored somewhere. Usually they're just used in loops & co. directly, like for i, x in enumerate(l):
- Having the path[-1] spot ready and writing into it with for path[-1], x in ...
- Using a for-loop with an immediate break and an else branch, to iterate over the next single value and handle ends gracefully without try/except and without next and some default.
- If you do yield x, path, i.e., don't turn each path into a tuple, then you really need to process it directly during the iteration. For example, if you do list(flatten(L)), then you get [(1, []), (2, []), (3, []), (4, []), (5, []), (6, []), (7, []), (8, []), (9, []), (10, [])]. That is, "all" index paths will be empty. Of course that's because there really is only one path object, which I update and yield over and over again, and in the end it's empty. This is very similar to itertools.groupby, where for example [list(g) for _, g in list(groupby('aaabbbb'))] gives you [[], ['b']]. And it's not a bad thing. I recently wrote about that extensively.

Shorter version with one stack holding both indexes and enumerate objects alternatingly:
def flatten(l):
    # layout: [index slot, iterator, index slot, iterator, ...]
    stack = [None, enumerate(l)]
    while stack:
        for stack[-2], x in stack[-1]:
            if isinstance(x, list):
                stack += None, enumerate(x)
            else:
                yield x, stack[::2]   # the index slots form the path
            break
        else:
            del stack[-2:]
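The earlier point about yielding one shared path list can be reproduced in a few lines. counter_paths below is a hypothetical generator (not from the answer) that, like yield x, path, hands out the same list object every time and empties it at the end. Collecting its results shows the same "all empty" effect, while copying each path during iteration captures every state. The shorter version above avoids the issue anyway, since stack[::2] builds a new list on every yield.

```python
def counter_paths(n):
    # Reuses ONE list object for every yield, like `yield x, path` does
    path = []
    for i in range(n):
        path.append(i)
        yield path
    path.clear()  # runs when the consumer asks for an item after the last one

collected = list(counter_paths(3))
print(collected)                     # [[], [], []]
print(collected[0] is collected[2])  # True: the same object three times

# Copying during iteration captures each state before it changes
print([tuple(p) for p in counter_paths(3)])  # [(0,), (0, 1), (0, 1, 2)]
```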
Recursion is a good approach for flattening deeply nested lists. Your implementation is also well done. I would suggest modifying it with this similar recipe as follows:
Code
# collections.Iterable was removed in Python 3.10; the ABC lives in collections.abc
from collections.abc import Iterable

def indexed_flatten(items):
    """Yield (item, index_path) pairs from any nested iterable."""
    for i, item in enumerate(items):
        # strings and bytes are iterable too, but are treated as leaves
        if isinstance(item, Iterable) and not isinstance(item, (str, bytes)):
            for item_, idx in indexed_flatten(item):
                yield item_, (i,) + idx
        else:
            yield item, (i,)

lst = [[[1, 2, 3], [4, 5]], [6], [7, [8, 9]], 10]
list(indexed_flatten(lst))
Output
[(1, (0, 0, 0)),
(2, (0, 0, 1)),
(3, (0, 0, 2)),
(4, (0, 1, 0)),
(5, (0, 1, 1)),
(6, (1, 0)),
(7, (2, 0)),
(8, (2, 1, 0)),
(9, (2, 1, 1)),
(10, (3,))]
This robustly works with many item types, e.g. [[[1, 2, 3], {4, 5}], [6], (7, [8, "9"]), 10].
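To see which values the recipe's test treats as nested containers, here is a small check that mirrors its condition (the looks_nested helper is hypothetical). Strings and bytes must be excluded because they are iterable too, and a one-character string yields itself when iterated, so the recursion would never bottom out. Note also that a dict counts as Iterable and would be walked over its keys, and a set such as {4, 5} has no guaranteed iteration order, so the index paths of its members are not stable in general.

```python
from collections.abc import Iterable

def looks_nested(item):
    # Same test as indexed_flatten: iterable, but not a str/bytes object
    return isinstance(item, Iterable) and not isinstance(item, (str, bytes))

samples = [[1], (2,), {3}, {"k": 4}, range(2), "ab", b"ab", 5, None]
for s in samples:
    print(type(s).__name__, looks_nested(s))

# Why str must be excluded: a one-character string yields itself when iterated
print(next(iter("a")) == "a")  # True
```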