Question
I am trying to check whether a number is in a NumPy array of int8s. I tried this, but it does not work.
from numba import njit
import numpy as np

@njit
def c(b):
    return 9 in b

a = np.array((9, 10, 11), 'int8')
print(c(a))
The error I get is
Invalid use of Function(<built-in function contains>) with argument(s) of type(s): (array(int8, 1d, C), Literal[int](9))
* parameterized
In definition 0:
All templates rejected with literals.
In definition 1:
All templates rejected without literals.
In definition 2:
All templates rejected with literals.
In definition 3:
All templates rejected without literals.
In definition 4:
All templates rejected with literals.
In definition 5:
All templates rejected without literals.
This error is usually caused by passing an argument of a type that is unsupported by the named function.
[1] During: typing of intrinsic-call at .\emptyList.py (6)
How can I fix this while still maintaining performance? The arrays will be checked for two values, 1 and -1, and are 32 items long. They are not sorted.
Answer 1:
Checking if two values are in an array
To check only whether either of two values occurs in an array, I would recommend a simple brute-force algorithm.
Code
import numba as nb
import numpy as np

@nb.njit(fastmath=True)
def isin(b):
    # Initialise once, before the loop, so an earlier match is not overwritten
    res = False
    for i in range(b.shape[0]):
        if b[i] == -1:
            res = True
        if b[i] == 1:
            res = True
    return res

# Parallelized call to isin if the data is an array of shape (n, m)
@nb.njit(fastmath=True, parallel=True)
def isin_arr(b):
    res = np.empty(b.shape[0], dtype=nb.boolean)
    for i in nb.prange(b.shape[0]):
        res[i] = isin(b[i, :])
    return res
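As a cross-check for a compiled kernel like the one above, the same row-wise test can be written in plain NumPy. This is only a minimal sketch: the name `isin_rows_reference` and the sample data are made up for illustration and are not part of the original answer. The sample deliberately includes rows where the match is not in the last position.

```python
import numpy as np

def isin_rows_reference(b):
    """Return a boolean array that is True where a row of b contains 1 or -1."""
    # Compare every element against both values, then reduce along each row.
    return ((b == 1) | (b == -1)).any(axis=1)

# Small hand-checkable sample (hypothetical data, int8 like in the question)
sample = np.array(
    [[0, 2, 3],    # no match
     [0, -1, 5],   # -1 in the middle
     [1, 0, 0],    # 1 at the start
     [4, 4, 1]],   # 1 at the end
    dtype=np.int8,
)
reference = isin_rows_reference(sample)
print(reference.tolist())
```

The reference is slower than a compiled loop because it allocates temporary boolean arrays, but its output can be compared with the result of `isin_arr` on the same input using `np.array_equal`.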
Performance
# Create some data (320 MB)
A = (np.random.randn(10000000, 32) - 0.5) * 5
A = A.astype(np.int8)
res = isin_arr(A)  # 11 ms per call
So with this method I get a throughput of about 29 GB/s, which isn't far from the memory bandwidth. You can also reduce the size of the test data so that it fits in the L3 cache, which avoids the memory-bandwidth limit. With 3.2 MB of test data I get a throughput of 100 GB/s (far beyond my memory bandwidth), which is a clear indicator that this implementation is memory-bandwidth limited.
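The throughput figure follows from bytes processed divided by time per call. A rough sketch of that arithmetic and of a simple timing helper is shown here. The helper names `throughput_gb_per_s` and `time_call` are hypothetical, and the helper assumes one untimed warm-up call so that JIT compilation time is not included in the measurement.

```python
import time

def throughput_gb_per_s(n_bytes, seconds):
    """Bytes processed per second, expressed in GB/s (1 GB = 1e9 bytes)."""
    return n_bytes / seconds / 1e9

def time_call(func, arg, repeats=5):
    """Best wall-clock time in seconds over `repeats` calls of func(arg)."""
    func(arg)  # warm-up call, so a JIT-compiled function is compiled before timing
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        func(arg)
        best = min(best, time.perf_counter() - start)
    return best

# Figures quoted in the answer: 10,000,000 x 32 int8 values (1 byte each), 11 ms per call
n_bytes = 10_000_000 * 32 * 1
print(round(throughput_gb_per_s(n_bytes, 11e-3), 1))  # 29.1
```

To measure on your own machine, `time_call(isin_arr, A)` would give the per-call time to plug into `throughput_gb_per_s(A.nbytes, ...)`.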
Source: https://stackoverflow.com/questions/54930852/python-numba-value-in-array