Python: flatten nested lists with indices

被撕碎了的回忆 2021-02-13 16:01

Given a list of arbitrarily deep nested lists of arbitrary size, I would like a flat, depth-first iterator over all elements in the tree, but with path indices as well, such that …

3 Answers
  •  有刺的猬
    2021-02-13 16:37

    I think your own solution is alright: there's nothing really simpler, and nothing in Python's standard library would help. But here's another way anyway, which works iteratively instead of recursively, so it can handle very deeply nested lists.

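    def flatten(l):
        stack = [enumerate(l)]  # iterators of the currently "active" lists
        path = [None]           # index path; the last slot is filled by the for-loop
        while stack:
            # Take at most one element from the innermost active list.
            for path[-1], x in stack[-1]:
                if isinstance(x, list):
                    # Descend into the sublist and reserve a slot for its index.
                    stack.append(enumerate(x))
                    path.append(None)
                else:
                    yield x, tuple(path)
                break
            else:
                # Innermost list exhausted: go back up one level.
                stack.pop()
                path.pop()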
    

    I keep the currently "active" lists on a stack of enumerate iterators, and the current index path on another stack. Then, in a while-loop, I always try to take the next element from the current list and handle it appropriately:

    • If that next element is a list, I push its enumerate iterator onto the stack and make room for the deeper index on the index path stack.
    • If that next element is anything else (a number, in your example), I yield it along with its path.
    • If there is no next element in the current list, I remove it (or rather its iterator) and its index slot from the stacks.
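A quick way to check the claim about deep nesting: the sketch below repeats the function above and feeds it a list nested five times deeper than sys.getrecursionlimit(), wrapped around a single number. The helper make_deep and the factor of five are illustrative assumptions, not part of the original answer.

```python
import sys


def flatten(l):
    """Yield (element, index_path) pairs depth-first, without recursion."""
    stack = [enumerate(l)]
    path = [None]
    while stack:
        for path[-1], x in stack[-1]:
            if isinstance(x, list):
                stack.append(enumerate(x))
                path.append(None)
            else:
                yield x, tuple(path)
            break
        else:
            stack.pop()
            path.pop()


def make_deep(depth, leaf):
    """Hypothetical helper: wrap `leaf` in `depth` levels of one-element lists."""
    nested = leaf
    for _ in range(depth):
        nested = [nested]
    return nested


depth = sys.getrecursionlimit() * 5
result = list(flatten(make_deep(depth, 42)))
# One entry: 42, with an index path as long as the nesting depth.
print(len(result), result[0][0], len(result[0][1]))
```

Only the two explicit stacks grow with the depth, so no Python-level recursion is involved.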

    Demo:

    >>> L = [[[1, 2, 3], [4, 5]], [6], [7, [8, 9]], 10]
    >>> for entry in flatten(L):
    ...     print(entry)
    ...
    (1, (0, 0, 0))
    (2, (0, 0, 1))
    (3, (0, 0, 2))
    (4, (0, 1, 0))
    (5, (0, 1, 1))
    (6, (1, 0))
    (7, (2, 0))
    (8, (2, 1, 0))
    (9, (2, 1, 1))
    (10, (3,))
    

    Note that if you process the entries on the fly, as printing does, then you could just yield the path list itself instead of a tuple copy, i.e., use yield x, path. Demo:

    >>> for entry in flatten(L):
    ...     print(entry)
    ...
    (1, [0, 0, 0])
    (2, [0, 0, 1])
    (3, [0, 0, 2])
    (4, [0, 1, 0])
    (5, [0, 1, 1])
    (6, [1, 0])
    (7, [2, 0])
    (8, [2, 1, 0])
    (9, [2, 1, 1])
    (10, [3])
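For reference, here is that variant written out in full, under the hypothetical name flatten_shared so it can sit next to the tuple version. It is a sketch that changes only the yield line; since the same path list is yielded every time, a consumer that wants to keep the paths has to copy them while they are current, for example with list(p).

```python
def flatten_shared(l):
    """Like flatten, but yields the live `path` list itself (no per-element copy)."""
    stack = [enumerate(l)]
    path = [None]
    while stack:
        for path[-1], x in stack[-1]:
            if isinstance(x, list):
                stack.append(enumerate(x))
                path.append(None)
            else:
                yield x, path  # the same list object every time; it keeps changing
            break
        else:
            stack.pop()
            path.pop()


L = [[[1, 2, 3], [4, 5]], [6], [7, [8, 9]], 10]

# Copy each path while it is still current...
copied = [(x, list(p)) for x, p in flatten_shared(L)]
# ...whereas collecting first leaves every entry pointing at the final, empty list.
collected = list(flatten_shared(L))

print(copied[7])     # (8, [2, 1, 0])
print(collected[7])  # (8, [])
```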
    

    This way, the iterator takes only O(n) time for the whole iteration, where n is the total number of objects in the structure (both lists and numbers). Of course, printing adds to the cost, just as creating the tuples does, but that happens outside the generator and is the "fault" of the printing or whatever else you do with each path. If, for example, you only look at each path's length instead of its contents, which takes O(1), then the whole thing really is O(n).
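To illustrate the O(1)-per-element case, the sketch below computes the deepest leaf's nesting level by reading only len(path) from the shared-path variant, so no path is ever copied and the total cost stays O(n). The names flatten_shared and max_leaf_depth are hypothetical, and the generator is repeated so the block runs on its own.

```python
def flatten_shared(l):
    """Yield (element, live_path) pairs; live_path is one list mutated in place."""
    stack = [enumerate(l)]
    path = [None]
    while stack:
        for path[-1], x in stack[-1]:
            if isinstance(x, list):
                stack.append(enumerate(x))
                path.append(None)
            else:
                yield x, path
            break
        else:
            stack.pop()
            path.pop()


def max_leaf_depth(l):
    """Hypothetical helper: longest index path of any non-list element (0 if none)."""
    # len(path) is read as each pair is produced, while the path is still current.
    return max((len(path) for _, path in flatten_shared(l)), default=0)


print(max_leaf_depth([[[1, 2, 3], [4, 5]], [6], [7, [8, 9]], 10]))  # 3
```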

    All that said, again, I think your own solution is alright, and clearly simpler than this. And, as I commented under @naomik's answer, I think it's rather irrelevant that your solution can't handle lists nested around 1000 levels deep or more (Python's default recursion limit is 1000). One shouldn't even have such a list in the first place; if one does, that's a mistake that should be fixed instead. If the list can also go wide, as in your case, and is balanced, then even with a branching factor of just 2 you'd run out of memory at a depth well under 100 (a balanced binary nested list of depth 100 would have 2**100 leaves), so you'd never get anywhere near 1000. If the list can't go wide, then nested lists are the wrong choice of data structure, plus you wouldn't be interested in the index path in the first place. If it can go wide but doesn't, then I'd say the creation algorithm should be improved (for example, if it represents a sorted tree, add balancing).
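For comparison, a recursive generator might look like the sketch below. This is an assumption about what a recursive solution could look like, not the asker's code (which isn't shown here), and flatten_recursive is a hypothetical name. Each nesting level keeps one more generator frame active, so the depth it can reach is roughly bounded by the recursion limit mentioned above.

```python
def flatten_recursive(l, path=()):
    """Hypothetical recursive counterpart: yields (element, index_path) depth-first."""
    for i, x in enumerate(l):
        if isinstance(x, list):
            # Each nested list adds another active generator frame.
            yield from flatten_recursive(x, path + (i,))
        else:
            yield x, path + (i,)


L = [[[1, 2, 3], [4, 5]], [6], [7, [8, 9]], 10]
print(list(flatten_recursive(L))[-1])  # (10, (3,))
```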

    About my solution again: besides its ability to handle arbitrarily deep lists and its efficiency, I find some of its details worth noting:

    • You rarely see enumerate objects stored anywhere. Usually they're just used directly in loops and the like, as in for i, x in enumerate(l):.
    • Having the path[-1] slot ready and writing into it directly as the loop target, with for path[-1], x in ....
    • Using a for-loop with an immediate break and an else branch to take just the next value and handle the end gracefully, without try/except and without next and a default.
    • If you do yield x, path, i.e., don't turn each path into a tuple, then you really need to process it directly during the iteration. For example, if you do list(flatten(L)), you get [(1, []), (2, []), (3, []), (4, []), (5, []), (6, []), (7, []), (8, []), (9, []), (10, [])]. That is, "all" index paths will be empty. That's because there really is only one path object, which I update and yield over and over again, and at the end it's empty. This is very similar to itertools.groupby, where for example [list(g) for _, g in list(groupby('aaabbbb'))] gives you [[], ['b']]. And it's not a bad thing. I recently wrote about that extensively.
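The two loop idioms can be tried in isolation. The sketch below shows that a subscription such as slot[-1] works as a for-loop target, and that for/break/else takes at most one item from an iterator, much like next() with a sentinel default. The names slot, take_one, take_one_next and _MISSING are hypothetical.

```python
# (a) A subscription is a valid for-loop target: each item is assigned into it.
slot = ["a", None]
seen = []
for slot[-1] in range(3):
    seen.append(list(slot))
print(seen)  # [['a', 0], ['a', 1], ['a', 2]]


# (b) Take at most one item with for/break/else...
def take_one(it):
    """Return (True, item), or (False, None) if the iterator is exhausted."""
    for item in it:
        break
    else:
        return False, None
    return True, item


# ...which behaves like next() with a sentinel default.
_MISSING = object()


def take_one_next(it):
    item = next(it, _MISSING)
    if item is _MISSING:
        return False, None
    return True, item
```

Both helpers expect an iterator (such as an enumerate object); passing a plain list would restart from its first element on every call.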

    Shorter version, with a single stack holding indexes and enumerate objects alternately:

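    def flatten(l):
        # Alternating layout: [index, iterator, index, iterator, ...]
        stack = [None, enumerate(l)]
        while stack:
            for stack[-2], x in stack[-1]:
                if isinstance(x, list):
                    stack += None, enumerate(x)  # push a new (index slot, iterator) pair
                else:
                    yield x, stack[::2]  # the even positions form the index path (a new list)
                break
            else:
                del stack[-2:]  # innermost list exhausted: drop its pair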
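A quick usage check of the shorter version, repeated so the sketch runs on its own: because stack[::2] builds a new list on every yield, each path is an independent copy, so collecting everything with list(...) keeps the correct paths.

```python
def flatten(l):
    # Alternating layout: [index, iterator, index, iterator, ...]
    stack = [None, enumerate(l)]
    while stack:
        for stack[-2], x in stack[-1]:
            if isinstance(x, list):
                stack += None, enumerate(x)
            else:
                yield x, stack[::2]  # slicing creates a fresh list each time
            break
        else:
            del stack[-2:]


L = [[[1, 2, 3], [4, 5]], [6], [7, [8, 9]], 10]
entries = list(flatten(L))
print(entries[7])  # (8, [2, 1, 0])
```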
    
