Question
I am trying to speed up code that uses NumPy's where() function. There are two calls to where(), each returning an array of the indices where the condition evaluates to True. These index arrays are then compared for overlap with NumPy's intersect1d() function, and the length of the intersection is returned.
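A tiny worked example of this pipeline, using made-up arrays rather than the question's data, shows that np.where returns a tuple of index arrays and that np.intersect1d flattens that tuple before intersecting:

```python
import numpy as np

# Small illustrative arrays (hypothetical data, not from the question)
x = np.array([1, 2, 3, 2])
y = np.array([1, 0, 3, 2])
z = np.array([1, 2, 3, 0])

A = np.where(x == z)           # tuple holding one index array: indices 0, 1, 2
B = np.where(y == z)           # tuple holding one index array: indices 0, 2
common = np.intersect1d(A, B)  # sorted unique indices present in both: 0, 2
count = len(common)
print(count)                   # 2
```

Both index arrays have to be built in full before the intersection (which sorts them) can run, which is where the time goes for large N.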
import numpy as np
def find_match(x, y, z):
    A = np.where(x == z)
    B = np.where(y == z)
    #A = True
    #B = True
    return len(np.intersect1d(A, B))
N = np.power(10, 8)
M = 10
X = np.random.randint(M, size=N)
Y = np.random.randint(M, size=N)
Z = np.random.randint(M, size=N)
#print(X,Y,Z)
print(find_match(X,Y,Z))
Timing:
This code takes about 8 seconds on my laptop. If I replace both np.where() calls with A=True and B=True, it takes about 5 seconds; if I replace only one of the np.where() calls, it takes about 6 seconds. Scaling up by switching to N = np.power(10, 9), the code takes 87 seconds. Replacing both np.where() statements brings that down to 51 seconds, and replacing just one of them takes about 61 seconds.
My question: how can I merge the two np.where statements to speed up the code?
What I've tried: this iteration of the code is already about 4x faster than an earlier version, which used a slower lookup based on for-loops. Multiprocessing will be used at a higher level in this code, so I can't also apply it here.
For the record, the actual data are strings, so doing integer math won't help.
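Since == on NumPy string arrays is also elementwise, a boolean-mask approach does not depend on integer math. A minimal check with made-up string arrays, assuming the real data are same-length NumPy string arrays:

```python
import numpy as np

# Hypothetical string data; == compares NumPy string arrays element by element
x = np.array(["a", "b", "c", "d"])
y = np.array(["a", "x", "c", "d"])
z = np.array(["a", "b", "c", "e"])

mask = (x == z) & (y == z)      # True only where all three entries agree
n_matches = np.count_nonzero(mask)
print(n_matches)                # 2
```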
Version info:
Python 3.9.1 (default, Jan 8 2021, 17:17:43)
[Clang 12.0.0 (clang-1200.0.32.28)] on darwin
>>> import numpy
>>> print(numpy.__version__)
1.19.5
Answer 1:
Does this help?
def find_match2(x, y, z):
    # Combine both conditions into one mask, then count the True positions
    return len(np.nonzero(np.logical_and(x == z, y == z))[0])
Sample run:
In [227]: print(find_match(X,Y,Z))
1000896
In [228]: print(find_match2(X,Y,Z))
1000896
In [229]: %timeit find_match(X,Y,Z)
2.37 s ± 70.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
In [230]: %timeit find_match2(X,Y,Z)
272 ms ± 9.64 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
I've added np.random.seed(210) before creating the arrays, for the sake of reproducibility.
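A small variation on this answer (not part of it as posted): np.count_nonzero counts the True entries of the combined mask directly, so no index array has to be built at all. A sketch, checked against the question's find_match on small random integers; the function name find_match_count is hypothetical:

```python
import numpy as np

def find_match(x, y, z):
    # Original approach from the question: two index arrays, then intersect
    A = np.where(x == z)
    B = np.where(y == z)
    return len(np.intersect1d(A, B))

def find_match_count(x, y, z):
    # Hypothetical variant: one combined mask, counted without materializing indices
    return np.count_nonzero((x == z) & (y == z))

rng = np.random.default_rng(0)
X, Y, Z = (rng.integers(10, size=10_000) for _ in range(3))
print(find_match(X, Y, Z) == find_match_count(X, Y, Z))  # True
```

Whether this is measurably faster than the np.nonzero version has not been timed here; it mainly skips one allocation.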
Answer 2:
Two versions that scale differently depending on size:
def find_match1(x, y, z):
    return (x == y).astype(int) @ (y == z).astype(int)  # equality and summation in one step
def find_match2(x, y, z):
    out = np.zeros_like(x)
    np.equal(x, y, out=out, where=np.equal(y, z))  # only calculates x == y where y == z
    return out.sum()
Testing different data sizes:
N = np.power(10, 7)
...
%timeit find_match(X,Y,Z)
206 ms ± 12.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit find_match1(X,Y,Z)
70.7 ms ± 1.67 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit find_match2(X,Y,Z)
74.7 ms ± 3.66 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
N = np.power(10, 8)
...
%timeit find_match(X,Y,Z)
2.51 s ± 168 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit find_match1(X,Y,Z)
886 ms ± 154 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit find_match2(X,Y,Z)
776 ms ± 26.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
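Both functions test x == y and y == z, which together are equivalent to the question's x == z and y == z (all three values equal). A hand-checkable sketch with small made-up arrays, confirming that both versions count the same positions:

```python
import numpy as np

def find_match1(x, y, z):
    # Cast the masks to int so that @ sums the 0/1 products into a count
    return (x == y).astype(int) @ (y == z).astype(int)

def find_match2(x, y, z):
    out = np.zeros_like(x)
    # Writes x == y (as 0/1) only where y == z; every other entry keeps its 0
    np.equal(x, y, out=out, where=np.equal(y, z))
    return out.sum()

# Small made-up arrays: only positions 0 and 3 have x == y == z
x = np.array([1, 2, 3, 4])
y = np.array([1, 2, 0, 4])
z = np.array([1, 5, 0, 4])
print(find_match1(x, y, z), find_match2(x, y, z))  # 2 2
```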
EDIT: since @Tonechas's answer is faster than both, here's a Numba method:
from numba import njit
@njit
def find_match_jit(x, y, z):
    out = 0
    # Single pass over the three arrays; no intermediate boolean arrays
    for i, j, k in zip(x, y, z):
        if i == j and j == k:
            out += 1
    return out
find_match_jit(X,Y,Z) #run it once to compile
Out[]: 1001426
%timeit find_match_jit(X,Y,Z) # N = 10**8
204 ms ± 13.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
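A sketch for checking the loop-based version against the vectorized count on a small input. The try/except fallback is an addition (not part of the answer) so the loop still runs as plain Python where Numba isn't installed; it is much slower that way, so keep the test input small:

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Assumption: stand-in decorator so the loop runs as plain Python without Numba
    def njit(func):
        return func

@njit
def find_match_jit(x, y, z):
    out = 0
    for i, j, k in zip(x, y, z):
        if i == j and j == k:
            out += 1
    return out

rng = np.random.default_rng(210)
X, Y, Z = (rng.integers(10, size=1_000) for _ in range(3))
expected = np.count_nonzero((X == Y) & (Y == Z))
print(find_match_jit(X, Y, Z) == expected)  # True
```

With Numba installed, the first call includes compilation time, which is why the answer calls the function once before timing it.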
If threading is allowed:
@njit(parallel = True)
def find_match_jit_p(x, y, z):
    # Array expressions inside a parallel=True function can be split across threads
    xy = x == y
    yz = y == z
    return np.logical_and(xy, yz).sum()
find_match_jit_p(X,Y,Z)
Out[]: 1001426
%timeit find_match_jit_p(X,Y,Z)
84.6 ms ± 2.31 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
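An alternative, not from the answer and not benchmarked here: keep the explicit loop and let Numba parallelize it with prange, which avoids allocating the two intermediate boolean arrays. This relies on Numba inferring out += 1 as a sum reduction across threads. A sketch, again with a plain-Python fallback (an assumption, for environments without Numba); the name find_match_prange is hypothetical:

```python
import numpy as np

try:
    from numba import njit, prange
except ImportError:
    # Assumption: plain-Python stand-ins when Numba is not installed
    prange = range
    def njit(*args, **kwargs):
        return lambda func: func

@njit(parallel=True)
def find_match_prange(x, y, z):
    out = 0
    # prange splits the iterations across threads; out += 1 is treated as a reduction
    for i in prange(x.shape[0]):
        if x[i] == y[i] and y[i] == z[i]:
            out += 1
    return out

rng = np.random.default_rng(0)
X, Y, Z = (rng.integers(10, size=1_000) for _ in range(3))
print(find_match_prange(X, Y, Z) == np.count_nonzero((X == Y) & (Y == Z)))  # True
```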
Source: https://stackoverflow.com/questions/65823425/combine-numpy-where-statements