Vectorize large NumPy multiplication

醉酒成梦 2021-01-18 21:18

I am interested in calculating a large NumPy array. I have a large array A which contains a bunch of numbers. I want to calculate the sum of different combinations…

2 Answers
  • 2021-01-18 21:37

    Instead of np.dot you could use np.tensordot. Your current method is equivalent to:

    np.tensordot(A, Combinations, [2, 1]).transpose(2, 0, 1)
    

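    Note the transpose at the end: np.tensordot returns the result with the combinations axis last, and .transpose(2, 0, 1) moves it to the front so the axes are in the correct order.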
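
A minimal sketch checking that equivalence on small random inputs. It assumes the "current method" is the per-row loop that the second answer uses as its timing baseline; the sizes M, N, K and the random seed are illustrative choices, not values from the question.

```python
import numpy as np

rng = np.random.default_rng(0)
M, N, K = 4, 5, 6                      # illustrative sizes, not from the question
A = rng.uniform(0, 1, (M, N, 3))
Combinations = rng.integers(0, 3, (K, 3))

# Baseline: one weighted sum over the last axis of A per row of Combinations
loop_result = np.array([np.sum(A * cb, axis=2) for cb in Combinations])

# tensordot contracts axis 2 of A with axis 1 of Combinations -> shape (M, N, K)
td = np.tensordot(A, Combinations, [2, 1])
print(td.shape)                        # (4, 5, 6)

# Move the combinations axis to the front -> shape (K, M, N)
td_result = td.transpose(2, 0, 1)
print(td_result.shape)                 # (6, 4, 5)
print(np.allclose(loop_result, td_result))  # True
```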

    Like dot, the tensordot function can call down to the fast BLAS/LAPACK libraries (if you have them installed) and so should perform well for large arrays.
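
As a variation that is not part of this answer, the same contraction can be written with the @ operator (np.matmul). It broadcasts A, with shape (M, N, 3), against Combinations.T, with shape (3, K), and NumPy can also route floating-point matrix products like this to BLAS. np.moveaxis then takes the place of the transpose. A sketch with illustrative sizes:

```python
import numpy as np

rng = np.random.default_rng(1)
M, N, K = 4, 5, 6                      # illustrative sizes
A = rng.uniform(0, 1, (M, N, 3))
Combinations = rng.integers(0, 3, (K, 3))

# (M, N, 3) @ (3, K): each of the M stacked (N, 3) matrices is multiplied by (3, K)
mm = A @ Combinations.T                # shape (M, N, K)

# Put the combinations axis first, like .transpose(2, 0, 1)
mm_result = np.moveaxis(mm, -1, 0)     # shape (K, M, N)

td_result = np.tensordot(A, Combinations, [2, 1]).transpose(2, 0, 1)
print(mm_result.shape)                 # (6, 4, 5)
print(np.allclose(mm_result, td_result))  # True
```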

  • 2021-01-18 21:41

    np.dot() won't give you the desired output unless you add extra steps, which would probably include reshaping. Here's one vectorized approach that uses np.einsum to do it in one shot without any extra memory overhead:

    Final_Product = np.einsum('ijk,lk->lij',A,Combinations)
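
To read the subscripts 'ijk,lk->lij': k appears in both inputs but not in the output, so it is summed over, while l, i and j are kept and written out in that order. A sketch comparing the np.einsum call against an explicit (slow) nested loop on tiny, illustrative arrays:

```python
import numpy as np

rng = np.random.default_rng(2)
A = rng.uniform(0, 1, (2, 3, 3))           # tiny illustrative shape (M, N, R)
Combinations = rng.integers(0, 3, (4, 3))  # (K, R)

fast = np.einsum('ijk,lk->lij', A, Combinations)

# Explicit loops spelling out the same index expression:
# out[l, i, j] = sum over k of A[i, j, k] * Combinations[l, k]
M, N, R = A.shape
K = Combinations.shape[0]
slow = np.zeros((K, M, N))
for l in range(K):
    for i in range(M):
        for j in range(N):
            for k in range(R):
                slow[l, i, j] += A[i, j, k] * Combinations[l, k]

print(fast.shape)                      # (4, 2, 3)
print(np.allclose(fast, slow))         # True
```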
    

    For completeness, here's the np.dot approach with the reshaping mentioned earlier:

    M,N,R = A.shape
    Final_Product = A.reshape(-1,R).dot(Combinations.T).T.reshape(-1,M,N)
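
The reshaping works because the contraction involves only the last axis. Flattening the first two axes of A turns the problem into one ordinary 2-D matrix product, and the result is then transposed and unflattened. A sketch tracing the intermediate shapes, with illustrative sizes:

```python
import numpy as np

rng = np.random.default_rng(3)
A = rng.uniform(0, 1, (4, 5, 3))           # illustrative (M, N, R)
Combinations = rng.integers(0, 3, (6, 3))  # illustrative (K, R)

M, N, R = A.shape
flat = A.reshape(-1, R)                # (M*N, R)  = (20, 3)
prod = flat.dot(Combinations.T)        # (M*N, K)  = (20, 6)
out = prod.T.reshape(-1, M, N)         # (K, M, N) = (6, 4, 5)

print(flat.shape, prod.shape, out.shape)  # (20, 3) (20, 6) (6, 4, 5)
print(np.allclose(out, np.einsum('ijk,lk->lij', A, Combinations)))  # True
```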
    

    Runtime tests and output verification:

    In [138]: # Inputs (smaller versions of those listed in the question)
         ...: A = np.random.uniform(0,1, (374, 138, 3))
         ...: Combinations = np.random.randint(0,3, (30,3))
         ...: 
    
    In [139]: %timeit np.array([  np.sum( A*cb, axis=2)  for cb in Combinations])
    1 loops, best of 3: 324 ms per loop
    
    In [140]: %timeit np.einsum('ijk,lk->lij',A,Combinations)
    10 loops, best of 3: 32 ms per loop
    
    In [141]: M,N,R = A.shape
    
    In [142]: %timeit A.reshape(-1,R).dot(Combinations.T).T.reshape(-1,M,N)
    100 loops, best of 3: 15.6 ms per loop
    
    In [143]: Final_Product =np.array([np.sum( A*cb, axis=2)  for cb in Combinations])
         ...: Final_Product2 = np.einsum('ijk,lk->lij',A,Combinations)
         ...: M,N,R = A.shape
         ...: Final_Product3 = A.reshape(-1,R).dot(Combinations.T).T.reshape(-1,M,N)
         ...: 
    
    In [144]: print(np.allclose(Final_Product,Final_Product2))
    True
    
    In [145]: print(np.allclose(Final_Product,Final_Product3))
    True
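
The session above relies on IPython's %timeit magic, which is not available in a plain Python script. A sketch of the same comparison using the standard-library timeit module follows. It uses the same reduced input shapes, also includes the tensordot version from the other answer, checks that all methods agree, and prints whatever timings your machine produces (no figures are claimed here). The function names are hypothetical helpers introduced for illustration.

```python
import timeit

import numpy as np

# Same reduced input shapes as in the session above
A = np.random.uniform(0, 1, (374, 138, 3))
Combinations = np.random.randint(0, 3, (30, 3))
M, N, R = A.shape

def loop():
    return np.array([np.sum(A * cb, axis=2) for cb in Combinations])

def with_einsum():
    return np.einsum('ijk,lk->lij', A, Combinations)

def with_dot():
    return A.reshape(-1, R).dot(Combinations.T).T.reshape(-1, M, N)

def with_tensordot():
    return np.tensordot(A, Combinations, [2, 1]).transpose(2, 0, 1)

# Verify every vectorized version against the loop baseline
reference = loop()
for fn in (with_einsum, with_dot, with_tensordot):
    assert np.allclose(reference, fn()), fn.__name__

for fn in (loop, with_einsum, with_dot, with_tensordot):
    # Best of 3 repeats, 5 calls each, reported per call
    best = min(timeit.repeat(fn, repeat=3, number=5)) / 5
    print(f"{fn.__name__:>15}: {best * 1e3:.2f} ms per call")
```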
    