Python, Pairwise 'distance', need a fast way to do it

渐次进展 2021-01-12 16:27

For a side project in my PhD, I set out to model a system in Python. Efficiency-wise, my program hits a bottleneck in the following problem, which I'll e…

3 Answers
  • 2021-01-12 16:29

    You can use numpy's vectorization capabilities to speed up the calculation. My version computes all elements of the distance matrix at once and then sets the diagonal and the lower triangle to zero.

    import numpy as np

    def pairwise_distance2(segs):
        N = len(segs)  # number of segments

        # we repeat p0 and p1 along all columns
        p0 = np.repeat(segs[:,0:3].reshape((N, 1, 3)), N, axis=1)
        p1 = np.repeat(segs[:,3:6].reshape((N, 1, 3)), N, axis=1)
        # and q0, q1 along all rows
        q0 = np.repeat(segs[:,0:3].reshape((1, N, 3)), N, axis=0)
        q1 = np.repeat(segs[:,3:6].reshape((1, N, 3)), N, axis=0)

        # element-wise dot product over the last dimension,
        # while keeping the number of dimensions at 3
        # (so we can use them together with the p* and q*)
        a = np.sum((p1 - p0) * (p1 - p0), axis=-1).reshape((N, N, 1))
        b = np.sum((p1 - p0) * (q1 - q0), axis=-1).reshape((N, N, 1))
        c = np.sum((q1 - q0) * (q1 - q0), axis=-1).reshape((N, N, 1))
        d = np.sum((p1 - p0) * (p0 - q0), axis=-1).reshape((N, N, 1))
        e = np.sum((q1 - q0) * (p0 - q0), axis=-1).reshape((N, N, 1))

        # parallel segments (including every segment paired with itself)
        # make a*c - b*b zero, so silence the divide/invalid warnings;
        # errstate restores the previous settings when the block exits
        with np.errstate(divide="ignore", invalid="ignore"):
            # same as above
            s = (b*e-c*d)/(a*c-b*b)
            t = (a*e-b*d)/(a*c-b*b)

            # almost same as above
            pairwise = np.sqrt(np.sum((p0 + (p1 - p0) * s - (q0 + (q1 - q0) * t))**2, axis=-1))

        # set everything at or below the diagonal to 0
        pairwise[np.tril_indices(N)] = 0.0

        return pairwise
    

    Now let's take it for a spin. With your example (N = 1000), I get the following timings:

    %timeit pairwise_distance(List_of_segments)
    1 loops, best of 3: 10.5 s per loop
    
    %timeit pairwise_distance2(List_of_segments)
    1 loops, best of 3: 398 ms per loop
    

    And of course, the results are the same:

    (pairwise_distance2(List_of_segments) == pairwise_distance(List_of_segments)).all()
    

    returns True. I'm also pretty sure there's a matrix multiplication hidden somewhere in the algorithm, so there should be some potential for further speedup (and also cleanup).
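    The hidden matrix multiplication can be made concrete. With u_i = p1_i - p0_i, every dot product above is either a per-row term, a per-column term, or an entry of `U @ U.T` or `U @ P0.T`, so the coefficients a to e no longer need the repeated (N, N, 3) arrays. Below is a minimal sketch under the same assumptions as `pairwise_distance2` (rows laid out as x0, y0, z0, x1, y1, z1; distance between the infinite lines through the segments; diagonal and lower triangle zeroed). The function name is hypothetical, and the final closest-point step still uses an (N, N, 3) broadcast.

```python
import numpy as np


def pairwise_line_distance_matmul(segs):
    """Distance between the lines through each pair of segments.

    segs: (N, 6) array, row i = (x0, y0, z0, x1, y1, z1) of segment i.
    Returns an (N, N) array with the diagonal and lower triangle set to 0.
    Parallel pairs give a zero denominator and come out non-finite.
    """
    segs = np.asarray(segs, dtype=float)
    P0 = segs[:, 0:3]          # start points p0_i, shape (N, 3)
    U = segs[:, 3:6] - P0      # directions u_i = p1_i - p0_i, shape (N, 3)

    UU = U @ U.T               # UU[i, j] = u_i . u_j
    UP = U @ P0.T              # UP[i, j] = u_i . p0_j
    sq = np.diag(UU)           # u_i . u_i
    up = np.diag(UP)           # u_i . p0_i

    a = sq[:, None]            # u_i . u_i             (row term)
    b = UU                     # u_i . u_j
    c = sq[None, :]            # u_j . u_j             (column term)
    d = up[:, None] - UP       # u_i . (p0_i - p0_j)
    e = UP.T - up[None, :]     # u_j . (p0_i - p0_j)

    with np.errstate(divide="ignore", invalid="ignore"):
        denom = a * c - b * b
        s = (b * e - c * d) / denom
        t = (a * e - b * d) / denom
        # closest points p0_i + s*u_i and p0_j + t*u_j, broadcast to (N, N, 3)
        diff = (P0[:, None, :] + s[..., None] * U[:, None, :]
                - (P0[None, :, :] + t[..., None] * U[None, :, :]))
        dist = np.sqrt(np.sum(diff ** 2, axis=-1))

    dist[np.tril_indices(len(segs))] = 0.0
    return dist
```

    When checking it against `pairwise_distance2`, `np.allclose(A, B, equal_nan=True)` is safer than `==`: the dot products are accumulated in a different order, so results can differ in the last bits, and parallel pairs are non-finite in both versions.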

    By the way: I've tried simply using numba first without success. Not sure why, though.

  • 2021-01-12 16:30

    This is more of a meta answer, at least for starters. Your problem might already be in "my program hits a bottleneck" and "I realize this is extremely inefficient".

    Extremely inefficient? By what measure? Do you have comparison? Is your code too slow to finish in a reasonable amount of time? What is a reasonable amount of time for you? Can you throw more computing power at the problem? Equally important -- do you use a proper infrastructure to run your code on (numpy/scipy compiled with vendor compilers, possibly with OpenMP support)?

    Then, if you have answers to all of the questions above and still need to optimize your code further: where exactly is the bottleneck in your current code? Did you profile it? Is the body of the loop possibly much more heavyweight than the evaluation of the loop itself? If so, then "the loop" is not your bottleneck, and you do not need to worry about the nested loop in the first place. Optimize the body first, possibly by coming up with unorthodox matrix representations of your data so that you can perform all these single calculations in one step, by matrix multiplication, for instance. If your problem cannot be solved with efficient linear algebra operations, you can write a C extension, use Cython, or use PyPy (which just very recently got some basic numpy support!). There are endless possibilities for optimizing; the real questions are how close to a practical solution you already are, how much you need to optimize, and how much effort you are willing to invest.
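    The "did you profile it?" step can be done with the standard-library `cProfile` and `pstats` modules. A minimal sketch, assuming the per-pair work lives in its own function; `segment_body`, `pairwise_loop` and `profile_report` are hypothetical stand-ins, not code from the question.

```python
import cProfile
import io
import pstats

import numpy as np


def segment_body(p, q):
    # hypothetical stand-in for the per-pair computation in the question
    return float(np.linalg.norm(p[:3] - q[:3]))


def pairwise_loop(segs):
    # the doubly-nested loop whose cost we want to attribute
    n = len(segs)
    out = np.zeros((n, n))
    for i in range(n):
        for j in range(i + 1, n):
            out[i, j] = segment_body(segs[i], segs[j])
    return out


def profile_report(func, *args, top=5):
    """Run func(*args) under cProfile; return (result, report text)."""
    profiler = cProfile.Profile()
    result = profiler.runcall(func, *args)
    buf = io.StringIO()
    stats = pstats.Stats(profiler, stream=buf)
    stats.sort_stats("cumulative").print_stats(top)
    return result, buf.getvalue()


if __name__ == "__main__":
    segs = np.random.default_rng(0).random((200, 6))
    _, report = profile_report(pairwise_loop, segs)
    print(report)
```

    In the report, compare the `tottime` of the loop function (time spent in the loop itself) with the `cumtime` of the body function. If the body dominates, the nested loop is not the problem and the body is what to optimize.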

    Disclaimer: I have done non-canonical pairwise-distance stuff with scipy/numpy for my PhD, too ;-). For one particular distance metric, I ended up coding the "pairwise" part in plain Python (i.e. I also used the doubly nested loop), but spent some effort on making the body as efficient as possible, with a combination of i) a cryptic matrix-multiplication representation of my problem and ii) the bottleneck package.
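    The "plain Python loop, efficient body" approach can look like the sketch below: the outer loop stays in Python, but each iteration computes the distances from one segment to all later segments in a single vectorized call. This is an illustration, not the code this answer refers to; it reuses the closest-point formula from the first answer, and both function names are hypothetical. A side benefit is memory: each iteration only allocates arrays of length N instead of the (N, N, 3) temporaries of the fully vectorized version.

```python
import numpy as np


def line_distances_from(seg, others):
    """Vectorized body: distances from one segment's line to many others.

    seg: (6,) array; others: (M, 6) array. Parallel pairs give a zero
    denominator and come out non-finite.
    """
    p0, u = seg[:3], seg[3:] - seg[:3]
    q0 = others[:, :3]
    v = others[:, 3:] - q0               # (M, 3) directions
    w0 = p0 - q0                         # (M, 3)
    a = u @ u                            # scalar
    b = v @ u                            # (M,)
    c = np.einsum("ij,ij->i", v, v)      # (M,) row-wise dot products
    d = w0 @ u                           # (M,)
    e = np.einsum("ij,ij->i", v, w0)     # (M,)
    with np.errstate(divide="ignore", invalid="ignore"):
        denom = a * c - b * b
        s = (b * e - c * d) / denom
        t = (a * e - b * d) / denom
        diff = p0 + s[:, None] * u - (q0 + t[:, None] * v)
        return np.sqrt(np.einsum("ij,ij->i", diff, diff))


def pairwise_rowwise(segs):
    # plain Python loop over rows; each row is one vectorized call,
    # so only the upper triangle is filled and the rest stays 0
    segs = np.asarray(segs, dtype=float)
    n = len(segs)
    out = np.zeros((n, n))
    for i in range(n - 1):
        out[i, i + 1:] = line_distances_from(segs[i], segs[i + 1:])
    return out
```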

  • 2021-01-12 16:51

    You can pass your own distance function to `scipy.spatial.distance.cdist`, something like this:

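    import numpy as np
    from scipy.spatial import distance

    def distance3d(p, q):
        # identical segments (e.g. a segment paired with itself) are 0 apart
        if (p == q).all():
            return 0.0

        p0, p1 = p[0:3], p[3:6]
        q0, q1 = q[0:3], q[3:6]

        # distance computation using the formula above
        u, v, w0 = p1 - p0, q1 - q0, p0 - q0
        a, b, c, d, e = u @ u, u @ v, v @ v, u @ w0, v @ w0
        s = (b*e - c*d) / (a*c - b*b)  # zero denominator for parallel segments
        t = (a*e - b*d) / (a*c - b*b)
        return np.linalg.norm(p0 + s*u - (q0 + t*v))

    print(distance.cdist(List_of_segments, List_of_segments, distance3d))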
    

    It doesn't seem to be any faster, though, since it executes the same loop internally.
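    Since the distance is symmetric, `scipy.spatial.distance.pdist` fits better than `cdist`: it calls the metric once per unordered pair (N*(N-1)/2 calls instead of N*N) and returns a condensed vector, which `squareform` expands into a square matrix. It is still a Python-level loop, so expect roughly a halving of the time at best, not a vectorized speedup. A sketch, assuming the same segment layout and closest-point formula as above; `line_distance` and `pairwise_pdist` are hypothetical names, and returning nan for exactly parallel segments is a choice made here.

```python
import numpy as np
from scipy.spatial.distance import pdist, squareform


def line_distance(p, q):
    # closest-point formula from above; p, q are (6,) rows
    p0, u = p[:3], p[3:] - p[:3]
    q0, v = q[:3], q[3:] - q[:3]
    w0 = p0 - q0
    a, b, c, d, e = u @ u, u @ v, v @ v, u @ w0, v @ w0
    denom = a * c - b * b
    if denom == 0.0:
        # exactly parallel lines: the formula is undefined (choice: nan)
        return float("nan")
    s = (b * e - c * d) / denom
    t = (a * e - b * d) / denom
    return float(np.linalg.norm(p0 + s * u - (q0 + t * v)))


def pairwise_pdist(segs):
    # pdist evaluates each unordered pair once and returns a condensed
    # vector; squareform expands it to a symmetric (N, N) matrix, and
    # np.triu zeroes the lower triangle to match the other answers
    condensed = pdist(np.asarray(segs, dtype=float), line_distance)
    return np.triu(squareform(condensed))
```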
