Efficient tensor contraction with Python

坚强是说给别人听的谎言 submitted on 2020-07-22 06:08:41

Question


I have a piece of code with a bottleneck calculation involving tensor contractions. Let's say I want to calculate a tensor A_{i,j,k,l}(X) whose number of non-zero entries for a single x ∈ X is N ~ 10^5, where X represents a grid with M total points, with M ~ 1000 approximately. For a single element of the tensor A, the rhs of the equation looks something like:

A_{i,j,k,l}(M) = Sum_{m,n,p,q} S_{i,j,m,n}(M) B_{m,n,p,q}(M) T_{p,q,k,l}(M)

In addition, the middle tensor B_{m,n,p,q}(M) is obtained by numerical convolution of arrays so that:

B_{m,n,p,q}(M) = ( L_{m,n} * F_{p,q} )(M)

where "*" is the convolution operator, and all tensors have appoximately the same number of elements as A. My problem has to do with efficiency of the sums; to compute a single rhs of A, it takes very long times given the complexity of the problem. I have a "keys" system, where each tensor element is accessed by its unique key combination ( ( p,q,k,l ) for T for example ) taken from a dictionary. Then the dictionary for that specific key gives the Numpy array associated to that key to perform an operation, and all operations (convolutions, multiplications...) are done using Numpy. I have seen that the most time consuming part is actually due to the nested loop (I loop over all keys (i,j,k,l) of the A tensor, and for each key, a rhs like the one above needs to be computed). Is there any efficient way to do this? Consider that:

1) Using plain (4+1)-dimensional NumPy arrays results in high memory usage, since all tensors are of complex type.

2) I have tried several approaches: Numba is quite limited when working with dictionaries, and some important NumPy features that I need are not currently supported. For instance, numpy.convolve() there only takes the first two arguments and does not take the "mode" argument, which would considerably reduce the needed convolution interval; in this case I don't need the "full" output of the convolution.

3) My most recent approach is trying to implement everything using Cython for this part... but this is quite time-consuming as well as more error-prone, given the logic of the code.
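For reference, a minimal sketch of the kind of dictionary-keyed loop described above; every name here (M, A_keys, B_keys, S, T, L, F) is a hypothetical stand-in, not the actual code:

import numpy as np

#hypothetical toy data: dictionaries mapping index tuples to complex arrays over the grid (length M)
M=1000
A_keys=[(0,0,0,0),(0,1,1,0)]                     #example non-zero keys (i,j,k,l) of A
B_keys=[(0,0,0,0),(1,0,0,1)]                     #example keys (m,n,p,q) of the contracted indices
S={(i,j,m,n): np.random.rand(M)+0j for (i,j,_,_) in A_keys for (m,n,_,_) in B_keys}
T={(p,q,k,l): np.random.rand(M)+0j for (_,_,k,l) in A_keys for (_,_,p,q) in B_keys}
L={(m,n): np.random.rand(M) for (m,n,_,_) in B_keys}
F={(p,q): np.random.rand(M) for (_,_,p,q) in B_keys}

#the bottleneck loop: one rhs per key of A, each summed over all contracted keys
A={}
for (i,j,k,l) in A_keys:
    rhs=np.zeros(M,dtype=np.complex128)
    for (m,n,p,q) in B_keys:
        B_mnpq=np.convolve(L[(m,n)],F[(p,q)],mode='same')   #numerical convolution on the grid
        rhs+=S[(i,j,m,n)]*B_mnpq*T[(p,q,k,l)]
    A[(i,j,k,l)]=rhs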

Any ideas on how to deal with such complexity using Python?

Thanks!


Answer 1:


You have to make your question a bit more precise, which also includes a working code example of what you have already tried. It is, for example, unclear why you use dictionaries in these tensor contractions. Dictionary lookups look like a weird thing to do for this calculation, but maybe I didn't get the point of what you really want to do.

Tensor contraction is actually very easy to implement in Python (NumPy); there are methods to find the best way to contract the tensors, and they are really easy to use (np.einsum).

Creating some data (this should be part of the question)

import numpy as np
import time

i=20
j=20
k=20
l=20

m=20
n=20
p=20
q=20

#I don't know what "complex 2" means; I assume it is complex128 (real and imaginary parts are float64)

#size of all arrays is 1.6e5
Sum_=np.random.rand(m,n,p,q).astype(np.complex128)
S_=np.random.rand(i,j,m,n).astype(np.complex128)
B_=np.random.rand(m,n,p,q).astype(np.complex128)
T_=np.random.rand(p,q,k,l).astype(np.complex128)

The naive way

This code is basically the same as writing it as loops in Cython or Numba without calling BLAS routines (ZGEMM) and without optimizing the contraction order -> 8 nested loops to do the job.

t1=time.time()
A=np.einsum("mnpq,ijmn,mnpq,pqkl",Sum_,S_,B_,T_)
print(time.time()-t1)

This results in a very slow runtime of about 330 seconds.
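For illustration only (this sketch is not part of the original answer), the 8-nested-loop equivalent of the unoptimized contraction, using the arrays defined above, looks like this; it is far too slow to actually run in pure Python and is shown only to make the loop structure explicit, e.g. as a starting point for a Numba/Cython port:

#naive contraction: A_ijkl = sum over m,n,p,q of Sum_[m,n,p,q]*S_[i,j,m,n]*B_[m,n,p,q]*T_[p,q,k,l]
A_loop=np.zeros((i,j,k,l),dtype=np.complex128)
for ii in range(i):
    for jj in range(j):
        for kk in range(k):
            for ll in range(l):
                for mm in range(m):
                    for nn in range(n):
                        for pp in range(p):
                            for qq in range(q):
                                A_loop[ii,jj,kk,ll]+=Sum_[mm,nn,pp,qq]*S_[ii,jj,mm,nn]*B_[mm,nn,pp,qq]*T_[pp,qq,kk,ll]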

How to increase the speed by a factor of 7700

%timeit A=np.einsum("mnpq,ijmn,mnpq,pqkl",Sum_,S_,B_,T_,optimize="optimal")
#42.9 ms ± 2.71 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

Why is this so much faster?

Let's have a look at the contraction path and the internals.

path=np.einsum_path("mnpq,ijmn,mnpq,pqkl",Sum_,S_,B_,T_,optimize="optimal")
print(path[1])

    #  Complete contraction:  mnpq,ijmn,mnpq,pqkl->ijkl
#         Naive scaling:  8
#     Optimized scaling:  6
#      Naive FLOP count:  1.024e+11
#  Optimized FLOP count:  2.562e+08
#   Theoretical speedup:  399.750
#  Largest intermediate:  1.600e+05 elements
#--------------------------------------------------------------------------
#scaling                  current                                remaining
#--------------------------------------------------------------------------
#   4             mnpq,mnpq->mnpq                     ijmn,pqkl,mnpq->ijkl
#   6             mnpq,ijmn->ijpq                          pqkl,ijpq->ijkl
#   6             ijpq,pqkl->ijkl                               ijkl->ijkl

and

path=np.einsum_path("mnpq,ijmn,mnpq,pqkl",Sum_,S_,B_,T_,optimize="optimal",einsum_call=True)
print(path[1])

#[((2, 0), set(), 'mnpq,mnpq->mnpq', ['ijmn', 'pqkl', 'mnpq'], False), ((2, 0), {'n', 'm'}, 'mnpq,ijmn->ijpq', ['pqkl', 'ijpq'], True), ((1, 0), {'p', 'q'}, 'ijpq,pqkl->ijkl', ['ijkl'], True)]

Doing the contraction in multiple well-chosen steps reduces the required FLOPs by a factor of 400. But that's not the only thing einsum does here. Just have a look at 'mnpq,ijmn->ijpq', ['pqkl', 'ijpq'], True), ((1, 0): the True stands for a BLAS contraction -> tensordot call -> (matrix-matrix multiplication).

Internally this looks basically as follows:

#consider X as a 4th order tensor {mnpq}, here the result of the first step 'mnpq,mnpq->mnpq'
#consider Y as a 4th order tensor {ijmn}, here S_

X=Sum_*B_                   #-> elementwise product from the first contraction step
Y=S_

X_=X.reshape(m*n,p*q)       #-> just another view on the data (2D), costs almost nothing (no copy, just a view)
Y_=Y.reshape(i*j,m*n)       #-> just another view on the data (2D), costs almost nothing (no copy, just a view)
res=np.dot(Y_,X_)           #-> dot is just a wrapper for highly optimized BLAS functions, in case of complex128 ZGEMM
output=res.reshape(i,j,p,q) #-> just another view on the data (4D), costs almost nothing (no copy, just a view)
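
As a small addition beyond the original answer: since the question loops over roughly 1000 grid points, the contraction path from np.einsum_path only has to be found once and can then be reused for every grid point by passing it to the optimize argument of np.einsum:

#path[0] is the contraction path list, path[1] the human-readable report shown above
path=np.einsum_path("mnpq,ijmn,mnpq,pqkl",Sum_,S_,B_,T_,optimize="optimal")

#reuse the precomputed path; in the real problem the arrays would change from grid point to grid point
A=np.einsum("mnpq,ijmn,mnpq,pqkl",Sum_,S_,B_,T_,optimize=path[0])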


Source: https://stackoverflow.com/questions/62395075/efficient-tensor-contraction-with-python
