Python, Pairwise 'distance', need a fast way to do it

渐次进展 2021-01-12 16:27

For a side project in my PhD, I engaged in the task of modelling some system in Python. Efficiency-wise, my program hits a bottleneck in the following problem, which I'll e

3 answers
  •  北恋 (OP)
    2021-01-12 16:29

    You can use numpy's vectorization capabilities to speed up the calculation. My version computes all elements of the distance matrix at once and then sets the diagonal and the lower triangle to zero.

    import numpy as np

    def pairwise_distance2(segments):
        # segments: (N, 6) array, columns 0:3 = start point, 3:6 = end point
        segments = np.asarray(segments, dtype=float)
        N = len(segments)

        # we repeat p0 and p1 along all columns
        p0 = np.repeat(segments[:, 0:3].reshape((N, 1, 3)), N, axis=1)
        p1 = np.repeat(segments[:, 3:6].reshape((N, 1, 3)), N, axis=1)
        # and q0, q1 along all rows
        q0 = np.repeat(segments[:, 0:3].reshape((1, N, 3)), N, axis=0)
        q1 = np.repeat(segments[:, 3:6].reshape((1, N, 3)), N, axis=0)

        # element-wise dot product over the last dimension,
        # while keeping the number of dimensions at 3
        # (so we can use them together with the p* and q*)
        a = np.sum((p1 - p0) * (p1 - p0), axis=-1).reshape((N, N, 1))
        b = np.sum((p1 - p0) * (q1 - q0), axis=-1).reshape((N, N, 1))
        c = np.sum((q1 - q0) * (q1 - q0), axis=-1).reshape((N, N, 1))
        d = np.sum((p1 - p0) * (p0 - q0), axis=-1).reshape((N, N, 1))
        e = np.sum((q1 - q0) * (p0 - q0), axis=-1).reshape((N, N, 1))

        # a*c - b*b is zero on the diagonal (and for parallel segments), so we
        # divide by zero here; silence those warnings for this block only
        with np.errstate(all="ignore"):
            # closest-point parameters, same formulas as in the question
            s = (b*e - c*d) / (a*c - b*b)
            t = (a*e - b*d) / (a*c - b*b)

            # distance between the two closest points (almost the same as in the question)
            pairwise = np.sqrt(np.sum((p0 + (p1 - p0)*s - (q0 + (q1 - q0)*t))**2, axis=-1))

        # set everything at or below the diagonal to 0
        pairwise[np.tril_indices(N)] = 0.0

        return pairwise
    
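For reference, the `s` and `t` lines use the standard closed form for the closest points of two lines. With `u = p1 - p0`, `v = q1 - q0` and `w = p0 - q0`, requiring the connecting vector `r(s, t)` to be perpendicular to both directions gives a 2x2 linear system whose solution is exactly what the code computes:

```latex
\begin{aligned}
&\mathbf{u} = \mathbf{p}_1 - \mathbf{p}_0,\quad
 \mathbf{v} = \mathbf{q}_1 - \mathbf{q}_0,\quad
 \mathbf{w} = \mathbf{p}_0 - \mathbf{q}_0,\\
&a = \mathbf{u}\cdot\mathbf{u},\quad b = \mathbf{u}\cdot\mathbf{v},\quad
 c = \mathbf{v}\cdot\mathbf{v},\quad d = \mathbf{u}\cdot\mathbf{w},\quad
 e = \mathbf{v}\cdot\mathbf{w},\\
&\mathbf{r}(s,t) = \mathbf{w} + s\,\mathbf{u} - t\,\mathbf{v},\\
&\mathbf{u}\cdot\mathbf{r} = 0 \;\Rightarrow\; a s - b t = -d,\qquad
 \mathbf{v}\cdot\mathbf{r} = 0 \;\Rightarrow\; b s - c t = -e,\\
&s = \frac{b e - c d}{a c - b^2},\qquad
 t = \frac{a e - b d}{a c - b^2},\qquad
 \text{dist} = \lVert \mathbf{r}(s,t) \rVert .
\end{aligned}
```

The denominator `a*c - b*b` equals |u|²|v|² sin²θ, so it vanishes on the diagonal and for parallel segments, which is why the division runs with the error state relaxed. Note also that nothing clamps `s` and `t` to [0, 1], so as written this measures the distance between the infinite lines through the segments rather than between the segments themselves; whether that is intended depends on the question's original formulation.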

    Now let's take it for a spin. With your example (N = 1000), I get these timings:

    %timeit pairwise_distance(List_of_segments)
    1 loops, best of 3: 10.5 s per loop
    
    %timeit pairwise_distance2(List_of_segments)
    1 loops, best of 3: 398 ms per loop
    

    And of course, the results are the same:

    (pairwise_distance2(List_of_segments) == pairwise_distance(List_of_segments)).all()
    

    returns True. I'm also pretty sure there's a matrix multiplication hidden somewhere in the algorithm, so there should be some potential for further speedup (and also cleanup).
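The matrix multiplication suspected above can be made explicit. Writing `P` for the (N, 3) array of start points and `U` for the directions `u_i = p1_i - p0_i`, the pairwise dot products become `b = U @ U.T`, `d_ij = u_i·p0_i - (U @ P.T)_ij` and `e_ij = (P @ U.T)_ij - u_j·p0_j`, while `a` and `c` depend only on the row or only on the column. The sketch below uses those identities plus broadcasting in place of `np.repeat`. The names `pairwise_distance_matmul` and `pairwise_distance_loop` are hypothetical, the (N, 6) start/end layout is assumed from the slicing in the answer, and it has not been timed against the versions above. The naive loop is only a reference for checking (it is not the question's code), and the check uses `np.allclose` because a different order of floating-point operations can change the last bits, which an exact `==` would flag.

```python
import numpy as np


def pairwise_distance_matmul(segments):
    """Hypothetical matmul/broadcasting variant; same output layout as pairwise_distance2."""
    segments = np.asarray(segments, dtype=float)
    n = len(segments)
    P = segments[:, 0:3]                # start points p0_i, shape (N, 3)
    U = segments[:, 3:6] - P            # directions u_i = p1_i - p0_i, shape (N, 3)

    uu = np.einsum("ij,ij->i", U, U)    # u_i . u_i
    up = np.einsum("ij,ij->i", U, P)    # u_i . p0_i
    a = uu[:, None]                     # depends on the row only, shape (N, 1)
    c = uu[None, :]                     # depends on the column only, shape (1, N)
    b = U @ U.T                         # u_i . u_j
    d = up[:, None] - U @ P.T           # u_i . (p0_i - p0_j)
    e = P @ U.T - up[None, :]           # u_j . (p0_i - p0_j)

    # diagonal and parallel pairs give a*c - b*b == 0 (or nearly), so ignore warnings
    with np.errstate(all="ignore"):
        den = a * c - b * b
        s = (b * e - c * d) / den
        t = (a * e - b * d) / den
        # r_ij = p0_i + s_ij u_i - (p0_j + t_ij u_j), built by broadcasting
        r = (P[:, None, :] + s[..., None] * U[:, None, :]
             - P[None, :, :] - t[..., None] * U[None, :, :])
        dist = np.sqrt(np.einsum("ijk,ijk->ij", r, r))

    dist[np.tril_indices(n)] = 0.0      # keep only the strict upper triangle
    return dist


def pairwise_distance_loop(segments):
    """Hypothetical naive reference: one pair at a time, same formulas."""
    segments = np.asarray(segments, dtype=float)
    n = len(segments)
    out = np.zeros((n, n))
    for i in range(n):
        p0, p1 = segments[i, 0:3], segments[i, 3:6]
        for j in range(i + 1, n):
            q0, q1 = segments[j, 0:3], segments[j, 3:6]
            u, v, w = p1 - p0, q1 - q0, p0 - q0
            a, b, c = u @ u, u @ v, v @ v
            d, e = u @ w, v @ w
            den = a * c - b * b
            s = (b * e - c * d) / den
            t = (a * e - b * d) / den
            out[i, j] = np.linalg.norm(w + s * u - t * v)
    return out


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    segs = rng.random((100, 6))         # random stand-in for List_of_segments
    print(np.allclose(pairwise_distance_matmul(segs), pairwise_distance_loop(segs)))
```

Both functions zero the diagonal and the lower triangle like `pairwise_distance2`, so only the upper triangle carries information. The final distance step still builds (N, N, 3) temporaries; only the five dot-product arrays are reduced to matrix products.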

    By the way, I first tried simply using Numba, without success. Not sure why, though.
