Consider the array `x` and the delta variable `d`:
# Reproducible sample data: ten integers drawn from [0, 100), plus the
# window half-width `d` passed to the pir* functions in the demo.
np.random.seed([3, 1415])
x = np.random.randint(0, 100, size=10)
d = 10
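Strategy

* Since `x` is not necessarily sorted, sort it and keep the sorting permutation from `argsort`, so the permutation can be reversed afterwards.
* Use `np.searchsorted` on the sorted `x` with `x - d` to find, for each element, the first position whose value is at least `x - d`.
* Do the same on the other side with `x + d`, passing the `np.searchsorted` parameter `side='right'`. The position then lands just past the last value that is at most `x + d`.
* The difference between the two positions is the number of values within `d` of each element.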
Define the method presented in the question as `pir1`:
def pir1(a, d):
    """Count, for each element of ``a``, how many elements lie within ``d`` of it.

    Brute-force baseline: forms every pairwise difference at once, so it
    needs memory proportional to ``len(a) ** 2``.

    Returns an integer array aligned with ``a``; the window
    ``[a[i] - d, a[i] + d]`` is inclusive and each element counts itself.
    """
    diffs = a[:, np.newaxis] - a
    within = np.abs(diffs) <= d
    return within.sum(axis=-1)
We'll define a new function, `pir2`:
def pir2(a, d):
    """Count, for each element of ``a``, how many elements lie within ``d`` of it.

    Sorts ``a`` once and uses two binary searches per element instead of
    building an n-by-n pairwise matrix.

    Parameters
    ----------
    a : numpy.ndarray
        1-D array of values (``searchsorted`` requires 1-D); need not be sorted.
    d : scalar
        Half-width of the inclusive window ``[a[i] - d, a[i] + d]``.

    Returns
    -------
    numpy.ndarray
        Integer counts aligned with the original order of ``a``; each
        element counts itself.
    """
    # Sort the argument itself, not a global array: the permutation must
    # describe `a` or the counts are mapped back to the wrong positions.
    s = a.argsort()
    a_ = a[s]
    # Index just past the last value <= a_ + d, minus the index of the first
    # value >= a_ - d, is the number of sorted values inside the window.
    counts = (
        a_.searchsorted(a_ + d, side='right')
        - a_.searchsorted(a_ - d, side='left')
    )
    # s.argsort() is the inverse permutation: restore the original order.
    return counts[s.argsort()]
Demo
# Both implementations should produce the same counts on the sample data.
print(pir1(x, d))
# [5 2 1 2 5 1 5 5 5 1]
print(pir2(x, d))
# [5 2 1 2 5 1 5 5 5 1]
Timing

`pir2` is the clear winner!

Code

Functions
def pir1(a, d):
    """Count, for each element of ``a``, how many elements lie within ``d`` of it.

    Brute-force baseline: forms every pairwise difference at once, so it
    needs memory proportional to ``len(a) ** 2``.

    Returns an integer array aligned with ``a``; the window
    ``[a[i] - d, a[i] + d]`` is inclusive and each element counts itself.
    """
    diffs = a[:, np.newaxis] - a
    within = np.abs(diffs) <= d
    return within.sum(axis=-1)
def pir2(a, d):
    """Count, for each element of ``a``, how many elements lie within ``d`` of it.

    Sorts ``a`` once and uses two binary searches per element instead of
    building an n-by-n pairwise matrix.

    Parameters
    ----------
    a : numpy.ndarray
        1-D array of values (``searchsorted`` requires 1-D); need not be sorted.
    d : scalar
        Half-width of the inclusive window ``[a[i] - d, a[i] + d]``.

    Returns
    -------
    numpy.ndarray
        Integer counts aligned with the original order of ``a``; each
        element counts itself.
    """
    # Sort the argument itself, not a global array: the permutation must
    # describe `a` or the counts are mapped back to the wrong positions.
    s = a.argsort()
    a_ = a[s]
    # Index just past the last value <= a_ + d, minus the index of the first
    # value >= a_ - d, is the number of sorted values inside the window.
    counts = (
        a_.searchsorted(a_ + d, side='right')
        - a_.searchsorted(a_ - d, side='left')
    )
    # s.argsort() is the inverse permutation: restore the original order.
    return counts[s.argsort()]
#######################
# From Divakar's post #
#######################
def pir3(a, d):  # Short & less efficient
    """Count, for each element of ``a``, how many elements lie within ``d`` of it.

    No sorted copy is built: ``searchsorted`` receives the sorting
    permutation through ``sorter`` and searches ``a`` as if it were sorted.
    The query values ``a + d`` / ``a - d`` are in ``a``'s original order, so
    the resulting counts already line up with ``a``.
    """
    perm = a.argsort()
    past_upper = a.searchsorted(a + d, side='right', sorter=perm)
    at_lower = a.searchsorted(a - d, side='left', sorter=perm)
    return past_upper - at_lower
def pir4(a, d):  # Long & more efficient
    """Count, for each element of ``a``, how many elements lie within ``d`` of it.

    Sorts ``a`` once and counts with two binary searches per element. The
    sort is undone with an inverse permutation built by a single scatter
    assignment, rather than a second ``argsort``.
    """
    order = a.argsort()
    n = order.size
    # inverse[k] is the position of a[k] within the sorted array.
    inverse = np.empty(n, dtype=np.int64)
    inverse[order] = np.arange(n)
    ranked = a[order]
    upper = ranked.searchsorted(ranked + d, side='right')
    lower = ranked.searchsorted(ranked - d, side='left')
    return (upper - lower)[inverse]
Test
from timeit import timeit

# Time pir1..pir4 on random arrays of length 1..49 (10000 calls each).
# dtype=float keeps the timing columns numeric; a frame built from only
# index/columns would otherwise have object dtype.
results = pd.DataFrame(
    index=np.arange(1, 50),
    columns=['pir%s' % i for i in range(1, 5)],
    dtype=float,
)
for i in results.index:
    np.random.seed([3, 1415])
    # Rebinds the module-level `x`; the timed statement looks it up by name.
    x = np.random.randint(1000000, size=i)
    for j in results.columns:
        # globals=globals() resolves `x` and the pir* function in this
        # namespace, so this also works when not run as __main__.
        results.loc[i, j] = timeit(
            '{}(x, 10)'.format(j), globals=globals(), number=10000)
results.plot()
Extended out to larger arrays (1,000 to 10,000 elements) and got rid of `pir1`.
from timeit import timeit

# Time pir2..pir4 on random arrays of length 1000..10000 (100 calls each).
# dtype=float keeps the timing columns numeric; a frame built from only
# index/columns would otherwise have object dtype.
results = pd.DataFrame(
    index=np.arange(1, 11) * 1000,
    columns=['pir%s' % i for i in range(2, 5)],
    dtype=float,
)
for i in results.index:
    np.random.seed([3, 1415])
    # Rebinds the module-level `x`; the timed statement looks it up by name.
    x = np.random.randint(1000000, size=i)
    for j in results.columns:
        # globals=globals() resolves `x` and the pir* function in this
        # namespace, so this also works when not run as __main__.
        results.loc[i, j] = timeit(
            '{}(x, 10)'.format(j), globals=globals(), number=100)
# Zero-filled placeholder column for the untimed pir1, inserted first;
# presumably so column order (and plot colors) match the earlier plot.
results.insert(0, 'pir1', 0)
results.plot()