Count how many elements in a NumPy array are within delta of every other element

面向向阳花 2020-12-18 08:12

Consider the array x and the delta variable d:

import numpy as np

np.random.seed([3,1415])
x = np.random.randint(100, size=10)
d = 10

For
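A deliberately naive sketch of the count, assuming the definition used by pir1 in the answers below: for each element x[i], count the elements x[j] (x[i] itself included) with |x[i] - x[j]| <= d. The function name count_within_naive is hypothetical and only meant to pin down the semantics.

```python
import numpy as np

def count_within_naive(a, d):
    """For each a[i], count the a[j] (j == i included) with |a[i] - a[j]| <= d."""
    counts = np.zeros(len(a), dtype=np.int64)
    for i, ai in enumerate(a):
        for aj in a:
            # int() avoids surprises with unsigned or small integer dtypes
            if abs(int(ai) - int(aj)) <= d:
                counts[i] += 1
    return counts

print(count_within_naive(np.array([0, 10, 10, 25]), 10).tolist())  # [3, 3, 3, 1]
```

This is O(n^2) in Python loops, so it is only useful as a reference for checking the vectorized versions.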

2 answers
  • 2020-12-18 08:48

    Strategy

    • Since x is not necessarily sorted, we'll sort it and record the sorting permutation s = x.argsort() so that we can undo the permutation at the end.
    • We'll call searchsorted on the sorted array with the values x - d (default side='left') to find, for each element, the first position whose value is at least x - d.
    • Do the same on the other side, this time with side='right' and the values x + d, to find the first position whose value exceeds x + d.
    • The difference between the right and left positions is the number of elements within +/- d of each element (the element itself included).
    • Index the result with s.argsort(), the inverse of the sorting permutation, to put the counts back in the original order.
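The steps above can be traced on a small array. This is a sketch with a made-up five-element input (not from the post) and d = 10, printing each intermediate so the left/right positions can be checked by hand.

```python
import numpy as np

a = np.array([25, 0, 10, 40, 12])
d = 10

s = a.argsort()                                 # sorting permutation: [1, 2, 4, 0, 3]
a_ = a[s]                                       # sorted copy: [0, 10, 12, 25, 40]
left = a_.searchsorted(a_ - d)                  # first index with a_ >= a_ - d
right = a_.searchsorted(a_ + d, side='right')   # first index with a_ > a_ + d
counts_sorted = right - left                    # counts, still in sorted order
counts = counts_sorted[s.argsort()]             # back to the original order of a

print(left.tolist())           # [0, 0, 1, 3, 4]
print(right.tolist())          # [2, 3, 3, 4, 5]
print(counts_sorted.tolist())  # [2, 3, 2, 1, 1]
print(counts.tolist())         # [1, 2, 3, 1, 2]
```

For example, 12 is within 10 of 10 and of itself, but not of 0 or 25, hence a count of 2.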

    Define the method presented in the question as pir1:

    def pir1(a, d):
        return (np.abs(a[:, None] - a) <= d).sum(-1)
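pir1 relies on broadcasting: a[:, None] has shape (n, 1), so subtracting a of shape (n,) yields the full n x n matrix of pairwise differences. A small sketch on a made-up four-element array:

```python
import numpy as np

a = np.array([0, 10, 10, 25])
d = 10

diff = a[:, None] - a          # shape (4, 4): diff[i, j] = a[i] - a[j]
within = np.abs(diff) <= d     # True where |a[i] - a[j]| <= d
counts = within.sum(-1)        # row sums: one count per a[i], a[i] itself included

print(diff.shape, counts.tolist())  # (4, 4) [3, 3, 3, 1]
```

The n x n intermediates are what make this approach O(n^2) in both time and memory.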
    

    Now define a new function pir2:

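    def pir2(a, d):
        s = a.argsort()                              # permutation that sorts a
        a_ = a[s]                                    # sorted copy of a
        return (
            a_.searchsorted(a_ + d, side='right')    # one past the last value <= a_ + d
            - a_.searchsorted(a_ - d)                # first value >= a_ - d
        )[s.argsort()]                               # undo the sort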
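Why side='right' for the upper bound: a sketch on an already-sorted, made-up array containing a duplicate and values exactly d apart. With side='left' the upper search stops before values equal to a_ + d, so those neighbours would be dropped and the comparison would no longer be inclusive.

```python
import numpy as np

a_ = np.array([0, 10, 10, 25])   # already sorted
d = 10

upper_right = a_.searchsorted(a_ + d, side='right')  # includes values == a_ + d
upper_left = a_.searchsorted(a_ + d, side='left')    # excludes values == a_ + d
lower = a_.searchsorted(a_ - d)                      # default 'left': includes values == a_ - d

print((upper_right - lower).tolist())  # [3, 3, 3, 1]  matches |diff| <= d
print((upper_left - lower).tolist())   # [1, 3, 3, 1]  0 loses its two neighbours at exactly 10
```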
    

    demo

    pir1(x, d)
    
    [5 2 1 2 5 1 5 5 5 1]    
    

    pir2(x, d)
    
    [5 2 1 2 5 1 5 5 5 1]    
    

    timing
    Compared with pir1, pir2 is the clear winner!

    code

    functions

    import numpy as np
    
    def pir1(a, d):
        return (np.abs(a[:, None] - a) <= d).sum(-1)
    
    def pir2(a, d):
        s = a.argsort()
        a_ = a[s]
        return (
            a_.searchsorted(a_ + d, 'right')
            - a_.searchsorted(a_ - d)
        )[s.argsort()]
    
    #######################
    # From Divakar's post #
    #######################
    def pir3(a,d):  # Short & less efficient
        sidx = a.argsort()
        p1 = a.searchsorted(a+d,'right',sorter=sidx)
        p2 = a.searchsorted(a-d,sorter=sidx)
        return p1 - p2
    
    def pir4(a, d):   # Long & more efficient
        s = a.argsort()
    
        y = np.empty(s.size,dtype=np.int64)
        y[s] = np.arange(s.size)
    
        a_ = a[s]
        return (
            a_.searchsorted(a_ + d, 'right')
            - a_.searchsorted(a_ - d)
        )[y]
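Before timing, it is worth confirming that the four functions return identical counts. A sketch, assuming pir2 uses s = a.argsort() (not the global x); the array sizes are arbitrary choices, and the value range is narrowed to 0..999 so that neighbours within 10 actually occur.

```python
import numpy as np

def pir1(a, d):
    return (np.abs(a[:, None] - a) <= d).sum(-1)

def pir2(a, d):
    s = a.argsort()
    a_ = a[s]
    return (a_.searchsorted(a_ + d, 'right') - a_.searchsorted(a_ - d))[s.argsort()]

def pir3(a, d):
    sidx = a.argsort()
    p1 = a.searchsorted(a + d, 'right', sorter=sidx)
    p2 = a.searchsorted(a - d, sorter=sidx)
    return p1 - p2

def pir4(a, d):
    s = a.argsort()
    y = np.empty(s.size, dtype=np.int64)
    y[s] = np.arange(s.size)          # inverse permutation of s
    a_ = a[s]
    return (a_.searchsorted(a_ + d, 'right') - a_.searchsorted(a_ - d))[y]

# Compare every searchsorted variant with the brute-force pir1.
np.random.seed([3, 1415])
for n in (1, 5, 49, 1000):
    x = np.random.randint(1000, size=n)
    expected = pir1(x, 10)
    for f in (pir2, pir3, pir4):
        assert np.array_equal(f(x, 10), expected), f.__name__
print("all four agree")
```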
    

    test

    from timeit import timeit
    import pandas as pd
    
    results = pd.DataFrame(
        index=np.arange(1, 50),
        columns=['pir%s' %i for i in range(1, 5)])
    
    for i in results.index:
        np.random.seed([3,1415])
        x = np.random.randint(1000000, size=i)
        for j in results.columns:
            setup = 'from __main__ import x, {}'.format(j)
            results.loc[i, j] = timeit('{}(x, 10)'.format(j), setup=setup, number=10000)
    
    results.plot()
    


    Extended to larger arrays, dropping pir1:
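The post does not say why pir1 was dropped here. One plausible factor is memory: a back-of-the-envelope sketch, assuming int64 input, of the n x n intermediates that pir1 allocates at the largest benchmark size (n = 10000). The helper pairwise_bytes is hypothetical.

```python
import numpy as np

def pairwise_bytes(n, dtype=np.int64):
    """Bytes needed for one n x n array of the given dtype."""
    return n * n * np.dtype(dtype).itemsize

n = 10_000
print(pairwise_bytes(n) / 1e6, "MB")            # 800.0 MB for a[:, None] - a (int64)
print(pairwise_bytes(n, np.bool_) / 1e6, "MB")  # 100.0 MB for the boolean mask
```

np.abs adds another int64 array of the same size, so pir1 needs well over 1 GB at n = 10000, while the searchsorted variants only allocate a handful of length-n arrays.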

    from timeit import timeit
    
    results = pd.DataFrame(
        index=np.arange(1, 11) * 1000,
        columns=['pir%s' %i for i in range(2, 5)])
    
    for i in results.index:
        np.random.seed([3,1415])
        x = np.random.randint(1000000, size=i)
        for j in results.columns:
            setup = 'from __main__ import x, {}'.format(j)
            results.loc[i, j] = timeit('{}(x, 10)'.format(j), setup=setup, number=100)
    
    results.insert(0, 'pir1', 0)  # placeholder column: pir1 was not timed at these sizes
    
    results.plot()
    

  • 2020-12-18 08:50

    Listed here are two more variants based on the searchsorted strategy from the OP's own answer.

    def pir3(a,d):  # Short & less efficient
        sidx = a.argsort()
        p1 = a.searchsorted(a+d,'right',sorter=sidx)
        p2 = a.searchsorted(a-d,sorter=sidx)
        return p1 - p2
    
    def pir4(a, d):   # Long & more efficient
        s = a.argsort()
    
        y = np.empty(s.size,dtype=np.int64)
        y[s] = np.arange(s.size)
    
        a_ = a[s]
        return (
            a_.searchsorted(a_ + d, 'right')
            - a_.searchsorted(a_ - d)
        )[y]
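pir3 never builds the sorted copy. With sorter=sidx, searchsorted treats a as if it were a[sidx], and because the queries a + d and a - d are in the original order, the results already are too, so no inverse permutation is needed. A sketch on a made-up array; the slowdown reported below is likely due to the indirect access through sidx, though the post does not say so.

```python
import numpy as np

a = np.array([25, 0, 10, 40, 12])   # unsorted
d = 10
sidx = a.argsort()                   # [1, 2, 4, 0, 3], so a[sidx] = [0, 10, 12, 25, 40]

# Positions refer to the virtual sorted order a[sidx]; queries keep a's order.
p1 = a.searchsorted(a + d, 'right', sorter=sidx)
p2 = a.searchsorted(a - d, sorter=sidx)

print(p1.tolist(), p2.tolist())  # [4, 2, 3, 5, 3] [3, 0, 0, 4, 1]
print((p1 - p2).tolist())        # [1, 2, 3, 1, 2]
```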
    

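    The more efficient approach, pir4, takes from this post the idea of computing the inverse permutation (what s.argsort() returns) by scattering, y[s] = np.arange(s.size), which avoids a second O(n log n) sort.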
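A sketch checking that the scatter in pir4 produces the same inverse permutation as a second argsort. The random input is arbitrary; since s is a permutation of 0..n-1, its argsort is the unique inverse, so the two must match exactly.

```python
import numpy as np

rng = np.random.default_rng(1)
a = rng.integers(0, 1_000_000, size=1_000)
s = a.argsort()

inv_sort = s.argsort()                 # inverse via a second sort, O(n log n)

y = np.empty(s.size, dtype=np.intp)    # inverse via scatter, O(n)
y[s] = np.arange(s.size)               # y[s[k]] = k: where original index s[k] landed

print(np.array_equal(inv_sort, y))     # True
```

Indexing the sorted array with y restores the original order, which is exactly how pir4 maps the counts back.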

    Runtime test:

    In [155]: # Inputs
         ...: a = np.random.randint(0,1000000,(10000))
         ...: d = 10
    
    
    In [156]: %timeit pir2(a,d) #@ piRSquared's post solution
         ...: %timeit pir3(a,d)
         ...: %timeit pir4(a,d)
         ...: 
    100 loops, best of 3: 2.43 ms per loop
    100 loops, best of 3: 4.44 ms per loop
    1000 loops, best of 3: 1.66 ms per loop
    