permutations with unique values

前端 未结 19 1466
爱一瞬间的悲伤
爱一瞬间的悲伤 2020-11-22 01:53

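`itertools.permutations` generates permutations in which elements are treated as unique based on their position, not on their value. So I want to avoid duplicate permutations; for example, `list(itertools.permutations([1, 1, 2]))` yields `(1, 1, 2)` twice, but I want each distinct ordering only once.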

19条回答
  •  青春惊慌失措
    2020-11-22 02:34

    Adapted to remove recursion and to use a dictionary and numba for high performance. It does not use the yield/generator style, so memory usage is unbounded (every permutation is held in memory at once):

    import numba
    
    @numba.njit
    def perm_unique_fast(elements): # collects every permutation in memory; too large for long inputs
        """Return all distinct permutations of *elements*, computed without recursion.

        Equal values are counted together (``dictunique``), so each distinct
        ordering is produced once rather than once per positional arrangement.
        Every permutation is accumulated in ``results`` before returning, so
        memory grows with the number of distinct permutations.

        Args:
            elements: presumably a ``numba.typed.List`` of hashable scalars (the
                benchmark snippet builds one via ``numba.typed.List``) -- confirm
                for other input types.

        Returns:
            A ``numba.typed.List`` holding one copy of the working list
            ``result_list`` per permutation.

        NOTE(review): for an empty ``elements`` nothing is ever pushed onto ``s``,
        so the first ``s.pop()`` fails; callers must not pass an empty list.
        """
        eset = set(elements)  # distinct values
        dictunique = dict()
        # Remaining multiplicity of each distinct value.
        for i in eset: dictunique[i] = elements.count(i)
        # Working permutation; slot d is filled when the search is at depth d.
        result_list = numba.typed.List()
        u = len(elements)
        for _ in range(u): result_list.append(0)
        # Explicit DFS stack of (value, depth) entries. depth == -1 marks a
        # "restore" entry that gives the value back when backtracking.
        s = numba.typed.List()
        results = numba.typed.List()
        d = u  # start one level above the last slot (index u - 1)
        while True:
            if d > 0:
                # Expand: push one child per value that still has copies left.
                for i in dictunique:
                    if dictunique[i] > 0: s.append((i, d - 1))
            i, d = s.pop()
            if d == -1:
                # Backtrack: return the value to the pool; stop once the stack is drained.
                dictunique[i] += 1
                if len(s) == 0: break
                continue
            result_list[d] = i
            # Slot 0 completes a permutation; store a copy because result_list is reused.
            if d == 0: results.append(result_list[:])
            dictunique[i] -= 1
            s.append((i, -1))  # schedule the matching restore for this choice
        return results
    
    # Benchmark transcript from an IPython/Jupyter session: `%timeit` is an
    # IPython magic, so this snippet does not run as a plain .py file (and the
    # `timeit` module imported here is not used by the magic).
    # NOTE(review): `perm_unique` is not defined in this snippet -- presumably the
    # recursive generator from another answer on the page; confirm before running.
    import timeit
    l = [2, 2, 3, 3, 4, 4, 5, 5, 6, 6]
    %timeit list(perm_unique(l))
    #377 ms ± 26 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    
    # perm_unique_fast is compiled with numba.njit, so the input is first copied
    # into a numba.typed.List.
    ltyp = numba.typed.List()
    for x in l: ltyp.append(x)
    %timeit perm_unique_fast(ltyp)
    #293 ms ± 3.37 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    
    # Cross-check: both implementations give the same permutations once the
    # typed lists returned by perm_unique_fast are converted to tuples.
    assert list(sorted(perm_unique(l))) == list(sorted([tuple(x) for x in perm_unique_fast(ltyp)]))
    

    About 30% faster (293 ms vs. 377 ms), but it still suffers somewhat from list-copying and list-management overhead.

    Alternatively, here is a version without numba that is still non-recursive and uses a generator to avoid memory issues:

    def perm_unique_fast_gen(elements):
        """Yield each distinct permutation of *elements*, without recursion.

        Equal values are treated as indistinguishable, so every distinct ordering
        is produced exactly once (unlike ``itertools.permutations``, which
        distinguishes elements by position). Permutations are generated lazily,
        so memory use stays proportional to ``len(elements)``.

        Args:
            elements: A sized collection (e.g. a list, tuple or str) of hashable
                values. It is iterated more than once, so it must not be a
                one-shot iterator.

        Yields:
            A new list of length ``len(elements)`` for each permutation. Each
            yielded list is an independent copy, so callers may store or mutate
            it. An empty input yields a single empty list, matching
            ``itertools.permutations([])``.
        """
        from collections import Counter

        size = len(elements)
        if size == 0:
            # There is exactly one permutation of nothing; without this guard the
            # loop below would pop from an empty stack and raise IndexError.
            yield []
            return

        counts = Counter(elements)  # one O(n) pass instead of list.count per value
        # Remaining multiplicity of each distinct value (keys in set-iteration order).
        remaining = {value: counts[value] for value in set(elements)}
        current = [0] * size  # slot `depth` holds the value chosen at that depth
        # Explicit DFS stack of (value, depth) entries; depth == -1 marks a
        # "restore" entry that returns `value` to the pool when backtracking.
        stack = []
        depth = size
        while True:
            if depth > 0:
                # Expand: one child per value that still has copies left.
                for value, left in remaining.items():
                    if left > 0:
                        stack.append((value, depth - 1))
            value, depth = stack.pop()
            if depth == -1:
                remaining[value] += 1
                if not stack:
                    break
                continue
            current[depth] = value
            if depth == 0:
                # Copy: `current` is overwritten as the search continues, so
                # yielding it directly would hand every caller the same list.
                yield current[:]
            remaining[value] -= 1
            stack.append((value, -1))
    

提交回复
热议问题