Question
I have some code that loops through a large set of itertools.combinations, which is now a performance bottleneck. I'm trying to turn to numba's @jit(nopython=True) to speed it up, but I'm running into some issues.
First, it seems numba can't handle itertools.combinations itself, per this small example:
import itertools
import numpy as np
from numba import jit

arr = [1, 2, 3]
c = 2

@jit(nopython=True)
def using_it(arr, c):
    return itertools.combinations(arr, c)

for i in using_it(arr, c):
    print(i)
This throws the error:
numba.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Unknown attribute 'combinations' of type Module(<module 'itertools' (built-in)>)
After some googling, I found a GitHub issue where the poster proposed this numba-safe function for calculating permutations:
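@jit(nopython=True)
def permutations(A, k):
    # An empty list of ints, written this way so numba can infer its element type
    r = [[i for i in range(0)]]
    for i in range(k):
        # Extend every partial permutation with each element of A it does not yet contain
        r = [[a] + b for a in A for b in r if (a in b)==False]
    return r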
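To make the layer-by-layer construction explicit, here is the same logic as plain Python (no Numba), checked against itertools.permutations. It is only a sketch: the name `permutations_by_extension` is hypothetical, and, like the function above, it assumes the values in A are distinct, because the membership test compares values rather than positions.

```python
import itertools

def permutations_by_extension(A, k):
    # Start from a single empty partial permutation.
    r = [[]]
    for _ in range(k):
        # Prepend every element of A not already used in the partial list.
        # This compares values, so it only works when A has no duplicates.
        r = [[a] + b for a in A for b in r if a not in b]
    return r

perms = permutations_by_extension([1, 2, 3], 2)
print(perms)  # [[1, 2], [1, 3], [2, 1], [2, 3], [3, 1], [3, 2]]
# Same set of orderings as itertools.permutations
print(sorted(map(tuple, perms)) == sorted(itertools.permutations([1, 2, 3], 2)))  # True
```

In plain Python the starting value can simply be `[[]]`; the `[i for i in range(0)]` form in the Numba version appears to exist only to give the compiler a typed empty list.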
Leveraging that, I can then easily filter down to combinations:
@jit(nopython=True)
def combinations(A, k):
    return [item for item in permutations(A, k) if sorted(item) == item]
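The filter works because, when the values in A are distinct, each k-subset has exactly one ascending ordering among its k! permutations. A plain-Python sketch of the same idea (the helper name `combinations_by_filter` is hypothetical, and itertools.permutations stands in for the permutations function) that checks the result against itertools.combinations for a sorted input:

```python
import itertools

def combinations_by_filter(A, k):
    # Build all k-permutations, then keep only those already in ascending order.
    # With distinct values, each k-subset has exactly one ascending ordering.
    perms = [list(p) for p in itertools.permutations(A, k)]
    return [p for p in perms if sorted(p) == p]

A = list(range(5))
print(combinations_by_filter(A, 2) == [list(c) for c in itertools.combinations(A, 2)])  # True
```

For an unsorted A the two still select the same subsets, but the filtered version returns each subset in value order rather than in the positional order itertools uses.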
Now I can run that combinations function without errors and get the correct result. However, it is now dramatically slower with @jit(nopython=True) than without it. Running this timing test:
import pandas as pd

A = list(range(20))  # numba throws 'cannot determine numba type of range' w/o list
k = 2

start = pd.Timestamp.utcnow()
print(combinations(A, k))
print(f"took {pd.Timestamp.utcnow() - start}")
clocks in at 2.6 seconds with the numba @jit(nopython=True) decorators, and under 1/1000 of a second with them commented out. So that's not really a workable solution for me either.
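One thing worth ruling out before reading much into a single measurement: Numba compiles a @jit function the first time it is called with a given set of argument types, so a one-shot timing includes compilation. A minimal sketch that times the first call and later calls separately, using time.perf_counter instead of pd.Timestamp. The helper `time_first_and_repeat` and the stand-in workload `pairs` are hypothetical, and plain Python is used so the sketch runs without Numba installed.

```python
import time

def time_first_and_repeat(func, *args, repeats=5):
    """Return (seconds for the first call, best seconds over later calls).

    For a Numba @jit function the first call includes compilation, so the
    second number is the one to compare with the undecorated version.
    """
    start = time.perf_counter()
    func(*args)
    first = time.perf_counter() - start

    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        func(*args)
        best = min(best, time.perf_counter() - start)
    return first, best

def pairs(A):
    # Stand-in workload: all ascending pairs from A.
    return [(a, b) for i, a in enumerate(A) for b in A[i + 1:]]

first, best = time_first_and_repeat(pairs, list(range(20)))
print(f"first call: {first:.6f}s, best repeat: {best:.6f}s")
```

With Numba available, calling `time_first_and_repeat(combinations, A, k)` on the decorated function would separate compile time from run time; whether that alone closes the gap here is not established by the measurement above.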
Answer 1:
There is not much to gain from Numba in this case, as itertools.combinations is written in C.
If you want to benchmark it, here is a Numba/Python implementation of what itertools.combinations does:
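from numba import jit

@jit(nopython=True)
def using_numba(pool, r):
    n = len(pool)
    indices = list(range(r))  # positions of the current combination in pool
    empty = not(n and (0 < r <= n))

    if not empty:
        result = [pool[i] for i in indices]
        yield result

    while not empty:
        # Find the rightmost index that has not reached its maximum value
        i = r - 1
        while i >= 0 and indices[i] == i + n - r:
            i -= 1
        if i < 0:
            empty = True  # every combination has been produced
        else:
            # Advance that index and reset the ones after it to consecutive positions
            indices[i] += 1
            for j in range(i+1, r):
                indices[j] = indices[j-1] + 1
            result = [pool[i] for i in indices]
            yield result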
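Without the decorator, the same index-advancing algorithm runs as an ordinary Python generator, which makes it easy to check against itertools.combinations. This is a sketch (the name `combinations_py` is hypothetical). Unlike itertools, it yields lists rather than tuples, and it yields nothing for r == 0, whereas itertools.combinations yields one empty tuple.

```python
import itertools

def combinations_py(pool, r):
    n = len(pool)
    indices = list(range(r))  # positions of the current combination
    empty = not (n and (0 < r <= n))
    if not empty:
        yield [pool[i] for i in indices]
    while not empty:
        # Find the rightmost index that can still move to the right.
        i = r - 1
        while i >= 0 and indices[i] == i + n - r:
            i -= 1
        if i < 0:
            empty = True  # every index is at its maximum: done
        else:
            # Advance it and reset all later indices to consecutive positions.
            indices[i] += 1
            for j in range(i + 1, r):
                indices[j] = indices[j - 1] + 1
            yield [pool[i] for i in indices]

print(list(combinations_py([1, 2, 3], 2)))  # [[1, 2], [1, 3], [2, 3]]
print([tuple(c) for c in combinations_py("abcd", 3)] == list(itertools.combinations("abcd", 3)))  # True
```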
On my machine, this is about 15 times slower than itertools.combinations. Getting the permutations and filtering the combinations would certainly be even slower.
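The extra work in the permutation-then-filter route can be counted directly: it builds n!/(n-k)! ordered candidates to keep n!/(k!(n-k)!) combinations, so it throws away all but one in every k! candidates, before even counting the sort applied to each. A quick check with math.perm and math.comb (Python 3.8+); `filter_overhead` is a hypothetical helper name:

```python
import math

def filter_overhead(n, k):
    # Candidates built by the permutation approach vs. combinations kept.
    built = math.perm(n, k)
    kept = math.comb(n, k)
    return built, kept, built // kept  # the ratio is k!

for k in (2, 3, 4):
    print(k, filter_overhead(20, k))
# 2 (380, 190, 2)
# 3 (6840, 1140, 6)
# 4 (116280, 4845, 24)
```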
Source: https://stackoverflow.com/questions/61262188/numba-safe-version-of-itertools-combinations