How to speed up a vector cross product calculation

死守一世寂寞 2020-12-03 15:07

Hi, I'm relatively new here and trying to do some calculations with numpy. I'm experiencing a long elapsed time from one particular calculation and can't work out any faste

2 Answers
  • 2020-12-03 15:55

    If you look at the source code of np.cross, it basically moves the xyz dimension to the front of the shape tuple for all arrays, and then has the calculation of each of the components spelled out like this:

    x = a[1]*b[2] - a[2]*b[1]
    y = a[2]*b[0] - a[0]*b[2]
    z = a[0]*b[1] - a[1]*b[0]
    

    In your case, each of those products requires allocating huge arrays, so the overall behavior is not very efficient.
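    To make the allocation cost concrete, here is a minimal sketch of the same component formulas applied to every pair of rows from two arrays. The function name `cross_components` is hypothetical, and this is an illustration of the idea rather than the actual np.cross source:

```python
import numpy as np

def cross_components(a, b):
    """Pairwise cross products of a (m, 3) and b (n, 3), returned as (m, n, 3)."""
    # Put the xyz axis first, then add a broadcast axis to each operand.
    ax, ay, az = a.T[:, :, None]   # each has shape (m, 1)
    bx, by, bz = b.T[:, None, :]   # each has shape (1, n)
    out = np.empty((a.shape[0], b.shape[0], 3), dtype=np.result_type(a, b))
    # Every product below creates a temporary (m, n) array before the subtraction.
    out[..., 0] = ay * bz - az * by
    out[..., 1] = az * bx - ax * bz
    out[..., 2] = ax * by - ay * bx
    return out
```

    Each of the three components needs two full-size temporary products plus the difference, which is where the time goes for large inputs.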

    Let's set up some test data:

    u = np.random.rand(1000, 3)
    v = np.random.rand(2000, 3)
    
    In [13]: %timeit s1 = np.cross(u[:, None, :], v[None, :, :])
    1 loops, best of 3: 591 ms per loop
    

    We can try to compute it using Levi-Civita symbols and np.einsum as follows:

    eijk = np.zeros((3, 3, 3))
    eijk[0, 1, 2] = eijk[1, 2, 0] = eijk[2, 0, 1] = 1
    eijk[0, 2, 1] = eijk[2, 1, 0] = eijk[1, 0, 2] = -1
    
    In [14]: %timeit s2 = np.einsum('ijk,uj,vk->uvi', eijk, u, v)
    1 loops, best of 3: 706 ms per loop
    
    In [15]: np.allclose(s1, s2)
    Out[15]: True
    
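    For reference, the subscripts 'ijk,uj,vk->uvi' are the index form of the cross product, applied to every pair of a row of u (index u) and a row of v (index v), with summation over the repeated indices j and k:

```latex
(\mathbf{u} \times \mathbf{v})_i = \sum_{j=1}^{3} \sum_{k=1}^{3} \varepsilon_{ijk}\, u_j\, v_k,
\qquad
\varepsilon_{ijk} =
\begin{cases}
+1 & (i,j,k) \text{ an even permutation of } (1,2,3) \\
-1 & (i,j,k) \text{ an odd permutation of } (1,2,3) \\
\phantom{+}0 & \text{otherwise}
\end{cases}
```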

    So while it works, it performs worse than np.cross. The thing is that np.einsum has trouble when there are more than two operands, but has optimized pathways for two or fewer. So we can try to rewrite it as two nested calls, each with two operands, to see if it helps:

    In [16]: %timeit s3 = np.einsum('iuk,vk->uvi', np.einsum('ijk,uj->iuk', eijk, u), v)
    10 loops, best of 3: 63.4 ms per loop
    
    In [17]: np.allclose(s1, s3)
    Out[17]: True
    

    Bingo! Close to an order of magnitude improvement...
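    The steps above can be collected into a self-contained sketch. The function names are hypothetical. The `optimize=True` variant is not part of the measurements above: that keyword was added to np.einsum in NumPy 1.12 and lets einsum pick a pairwise contraction order itself, so whether it matches the hand-nested version on your data is an assumption you would need to time.

```python
import numpy as np

def levi_civita():
    """Return the 3x3x3 Levi-Civita symbol as a float array."""
    eijk = np.zeros((3, 3, 3))
    eijk[0, 1, 2] = eijk[1, 2, 0] = eijk[2, 0, 1] = 1
    eijk[0, 2, 1] = eijk[2, 1, 0] = eijk[1, 0, 2] = -1
    return eijk

def cross_nested(u, v):
    """All pairwise cross products via two two-operand einsum calls."""
    eijk = levi_civita()
    # Contract the symbol with u first: a small (3, len(u), 3) intermediate.
    tmp = np.einsum('ijk,uj->iuk', eijk, u)
    # Then contract with v to get the (len(u), len(v), 3) result.
    return np.einsum('iuk,vk->uvi', tmp, v)

def cross_optimized(u, v):
    """Same contraction in one call, letting einsum choose the order."""
    return np.einsum('ijk,uj,vk->uvi', levi_civita(), u, v, optimize=True)
```

    Both functions should agree with np.cross(u[:, None, :], v[None, :, :]) up to floating-point rounding.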

    Some performance figures for NumPy 1.11.0 with a=numpy.random.rand(n,3), b=numpy.random.rand(n,3): the nested einsum is about twice as fast as cross for the largest n tested.

  • 2020-12-03 16:00

    While writing dynamic simulations for underwater vehicles, I found this method for a fast cross product:

    https://github.com/simena86/Simulink-Underwater-Robotics-Simulator/blob/master/3rdparty/gnc_mfiles/Smtrx.m

    It works well. It is written in Matlab, but the code is very simple. Just read the comments at the top.
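    As far as I understand it, the linked file builds the skew-symmetric matrix S(a) for which a × b = S(a) b. A minimal NumPy sketch of that idea follows; the name `skew` is hypothetical and this is not a line-by-line translation of the Matlab file:

```python
import numpy as np

def skew(a):
    """Skew-symmetric matrix S(a) such that S(a) @ b equals the cross product a x b."""
    ax, ay, az = a
    return np.array([[0.0, -az,  ay],
                     [ az, 0.0, -ax],
                     [-ay,  ax, 0.0]])

a = np.array([1.0, 2.0, 3.0])
b = np.array([4.0, 5.0, 6.0])
c = skew(a) @ b                  # same as np.cross(a, b): [-3., 6., -3.]

# One fixed vector against many: a single matrix product over the rows of b_many.
b_many = np.array([[4.0, 5.0, 6.0],
                   [1.0, 0.0, 0.0]])
c_many = b_many @ skew(a).T      # row i is a x b_many[i]
```

    This form is mainly useful when one fixed vector is crossed with many others, since the whole batch becomes one matrix multiplication.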
