Efficient elementwise argmin of matrix-vector difference

滥情空心 2021-02-15 14:06

Suppose an array a.shape == (N, M) and a vector v.shape == (N,). The goal is to compute argmin of abs of v subtr
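
For reference, here is a minimal brute-force sketch of one reading of the goal: for each element a[i, j], find the index of the entry of v closest to it. The question text is cut off, so this interpretation is an assumption, and the function name `closest_index_bruteforce` is hypothetical. The sample values are illustrative only.

```python
import numpy as np

def closest_index_bruteforce(a, v):
    # Broadcast a (N, M, 1) against v (N,) to get all pairwise
    # differences with shape (N, M, N), then take argmin over v's axis.
    return np.abs(a[..., None] - v).argmin(axis=-1)

# Tiny illustrative input: N = 2, M = 3
a = np.array([[0.0, 0.45, 1.0],
              [0.2, 0.8, 0.6]])
v = np.array([0.9, 0.1])

out = closest_index_bruteforce(a, v)
print(out)
# [[1 1 0]
#  [1 0 0]]
```

The temporary array has N*M*N elements, so this only works for small inputs. With N = 500 and M = 100000 it would need on the order of 200 GB in float64, which is why a sort-based approach is attractive.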

1 Answer
  •  走了就别回头了
    2021-02-15 14:47

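    Inspired by this post, we can leverage np.searchsorted. Sort v once, find the insertion point of each element of a, then keep whichever of the two neighbouring values is closer: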

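    import numpy as np

    def find_closest(a, v):
        # Sort v once; sidx maps sorted positions back to original indices.
        sidx = v.argsort()
        v_s = v[sidx]
        # Insertion point of each element of a into the sorted v.
        idx = np.searchsorted(v_s, a)
        # Values above max(v) get idx == len(v); clamp to the last valid index.
        idx[idx==len(v)] = len(v)-1
        # Left neighbour in sorted order (clamped at 0).
        idx0 = (idx-1).clip(min=0)

        # True where the left neighbour is at least as close as the right one.
        m = np.abs(a-v_s[idx]) >= np.abs(v_s[idx0]-a)
        m[idx==0] = False  # no left neighbour exists
        idx[m] -= 1
        # Map back to positions in the original, unsorted v.
        return sidx[idx]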
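
A quick sanity check of the searchsorted approach against a brute-force reference can be sketched as follows. The function is repeated so the block runs on its own. The seed and the small sizes are arbitrary choices for illustration. The check compares achieved distances rather than indices, so it stays valid even if two entries of v happen to be equally close.

```python
import numpy as np

def find_closest(a, v):
    # Same logic as the answer's function, repeated so this block runs alone.
    sidx = v.argsort()
    v_s = v[sidx]
    idx = np.searchsorted(v_s, a)
    idx[idx == len(v)] = len(v) - 1
    idx0 = (idx - 1).clip(min=0)
    m = np.abs(a - v_s[idx]) >= np.abs(v_s[idx0] - a)
    m[idx == 0] = False
    idx[m] -= 1
    return sidx[idx]

rng = np.random.default_rng(0)   # arbitrary seed
N, M = 5, 7                      # small sizes so brute force is cheap
a = rng.random((N, M))
v = rng.random(N)

fast = find_closest(a, v)

# Reference: smallest distance from each a[i, j] to any entry of v,
# computed via an (N, M, N) broadcast.
best = np.abs(a[..., None] - v).min(axis=-1)

# The index returned by find_closest should achieve that smallest distance.
ok = np.array_equal(np.abs(a - v[fast]), best)
print(ok)
```

Unlike the brute-force version, find_closest only allocates arrays of the same shape as a, which is what makes it usable at the sizes in the timings.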
    

    On large datasets, numexpr gives a further small performance boost:

    import numpy as np
    import numexpr as ne

    def find_closest_v2(a, v):
        sidx = v.argsort()
        v_s = v[sidx]
        idx = np.searchsorted(v_s, a)
        idx[idx==len(v)] = len(v)-1
        idx0 = (idx-1).clip(min=0)

        p1 = v_s[idx]
        p2 = v_s[idx0]
        # Pass every operand explicitly: a name missing from this dict falls
        # back to the caller's globals, not to this function's locals.
        m = ne.evaluate('(idx!=0) & (abs(a-p1) >= abs(p2-a))',
                        {'a':a, 'p1':p1, 'p2':p2, 'idx':idx})
        idx[m] -= 1
        return sidx[idx]
    

    Timings

    Setup:

    N,M = 500,100000
    a = np.random.rand(N,M)
    v = np.random.rand(N)
    
    In [22]: %timeit find_closest_v2(a, v)
    4.35 s ± 21.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    
    In [23]: %timeit find_closest(a, v)
    4.69 s ± 173 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    
