How to compute a sum of squares of sums with memory limitations?

闹比i 2021-01-17 02:53

This is a follow-up to this question:

How to do a sum of sums of the square of sum of sums?

Where I was asking for help to use einsum (to achieve a great spe

2 Answers
  • 2021-01-17 03:00

    You could sum over one index first and then continue the multiplication. I also tried versions that use numexpr for the last multiplication and reduction steps, but it doesn't seem to help much.
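
    In index notation, the reordering looks roughly like this. It assumes the index layout Fu[u,k], Fv[v,l], Fx[x,k], Fy[y,l], P[x,y], B[k,l]; the symbol S is shorthand introduced here, not part of the original code:

```latex
\begin{aligned}
S_{uvxy} &= \sum_{k,l} (F_u)_{uk}\,(F_v)_{vl}\,(F_x)_{xk}\,(F_y)_{yl}\,P_{xy}\,B_{kl} \\
         &= P_{xy} \sum_{k} (F_u)_{uk}\,(F_x)_{xk} \Big( \sum_{l} (F_v)_{vl}\,(F_y)_{yl}\,B_{kl} \Big), \\
I_{uv}   &= \sum_{x,y} S_{uvxy}^{2}
\end{aligned}
```

    The inner sum over l depends only on (v, y, k), so it can be computed once and reused for every u and x. The function below does this with broadcasting.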

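    import numpy as np

    def fun3(Fu, Fv, Fx, Fy, P, B):
        # Add singleton axes so the operands broadcast against each other
        P = P[None, None, ...]              # (1, 1, Nx, Ny)
        Fu = Fu[:, None, None, None, :]     # (Nu, 1, 1, 1, Nk)
        Fx = Fx[None, None, :, None, :]     # (1, 1, Nx, 1, Nk)
        Fv = Fv[:, None, None, :]           # (Nv, 1, 1, Nl)
        Fy = Fy[None, :, None, :]           # (1, Ny, 1, Nl)
        B = B[None, None, ...]              # (1, 1, Nk, Nl)
        # Sum over l first, then over k, then square and sum over x and y
        inner = np.sum(Fv*Fy*B, axis=-1)[None, :, None, :, :]   # (1, Nv, 1, Ny, Nk)
        return np.sum((P*np.sum(Fu*Fx*inner, axis=-1))**2, axis=(2, 3))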
    

    It's much faster on my computer:

    jit2 : 7.06 s

    fun3: 0.144 s

    Edit: Minor improvement - multiply first then square.

    Edit2: By leveraging what each library does best (numexpr for multiplication, numpy for dot/tensordot and summation), you can still get a more than 20-fold speedup over fun3.

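    import numpy as np
    import numexpr as ne

    def fun4(Fu, Fv, Fx, Fy, P, B):
        P = P[None, None, ...]      # (1, 1, Nx, Ny)
        Fu = Fu[:, None, :]         # (Nu, 1, Nk)
        Fx = Fx[None, ...]          # (1, Nx, Nk)
        Fy = Fy[:, None, :]         # (Ny, 1, Nl)
        B = B[None, ...]            # (1, Nk, Nl)
        s = ne.evaluate('Fu*Fx')    # (Nu, Nx, Nk)
        # Contract over l: r has shape (Nv, Ny, Nk)
        r = np.tensordot(Fv, ne.evaluate('Fy*B'), axes=(1, 2))
        # Contract over k, then reorder the axes to (Nu, Nv, Nx, Ny)
        I = np.tensordot(s, r, axes=(2, 2)).swapaxes(1, 2)
        r = ne.evaluate('(P*I)**2')
        r = np.sum(r, axis=(2, 3))
        return r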
    

    fun4: 0.007 s

    Moreover, since fun4 is no longer that memory hungry (thanks to tensordot), you can multiply much bigger arrays and see it use multiple cores.
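
    If the full (Nu, Nv, Nx, Ny) intermediate is still too large for memory, one option is to process u in blocks so that only a slice of it exists at any time. The following is a minimal sketch of that idea using numpy only; the function name sum_sq_chunked and the chunk parameter are made up for illustration, and the best block size will depend on the machine.

```python
import numpy as np

def sum_sq_chunked(Fu, Fv, Fx, Fy, P, B, chunk=8):
    """Sum over x, y of (sum over k, l of Fu*Fv*Fx*Fy*P*B)**2, in blocks of u."""
    # r[v, y, k] = sum_l Fv[v, l] * Fy[y, l] * B[k, l]  (computed once)
    r = np.einsum('vl,yl,kl->vyk', Fv, Fy, B)
    out = np.empty((Fu.shape[0], Fv.shape[0]))
    for start in range(0, Fu.shape[0], chunk):
        fu = Fu[start:start + chunk]
        # s[u, x, k] = Fu[u, k] * Fx[x, k] for the current block of u
        s = fu[:, None, :] * Fx[None, :, :]
        # I[u, x, v, y] = sum_k s[u, x, k] * r[v, y, k]
        I = np.tensordot(s, r, axes=(2, 2))
        I *= P[None, :, None, :]
        # Square and sum over x and y without another temporary
        out[start:start + chunk] = np.einsum('uxvy,uxvy->uv', I, I)
    return out
```

    With this layout the largest temporary has shape (chunk, Nx, Nv, Ny) instead of (Nu, Nx, Nv, Ny), so chunk trades memory for the number of tensordot calls.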

  • 2021-01-17 03:06

    The solution below presents three different methods for the simple sum of sums and four different methods for the sum of squares of sums.

    Sum of sums, 3 methods: for loop, JIT-compiled for loop, einsum (none run into memory problems).

    Sum of squares of sums, 4 methods: for loop, JIT-compiled for loop, expanded einsum, intermediate einsum.

    Of these four, the first three don't run into memory problems, but the plain for loop and the expanded einsum run into speed problems. That leaves the JIT solution looking like the best option.
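
    For reference, the two quantities can be written out as follows, read off the loops in fun1 and fun2 below. The labels I^{(1)} and I^{(2)} are shorthand introduced here:

```latex
\begin{aligned}
I^{(1)}_{uv} &= \sum_{x,y} \sum_{k,l} (F_u)_{uk}\,(F_v)_{vl}\,(F_x)_{xk}\,(F_y)_{yl}\,P_{xy}\,B_{kl} \\
I^{(2)}_{uv} &= \sum_{x,y} \Big( \sum_{k,l} (F_u)_{uk}\,(F_v)_{vl}\,(F_x)_{xk}\,(F_y)_{yl}\,P_{xy}\,B_{kl} \Big)^{2}
\end{aligned}
```

    Because the square sits between the two sums, the inner sum has to be formed for every (u, v, x, y) before squaring, unless the square is expanded into a double sum over (k, l) and (m, n) as in the expanded einsum.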

    import numpy as np
    import time
    from numba import jit
    
    def fun1(Fu, Fv, Fx, Fy, P, B):
        Nu = Fu.shape[0]
        Nv = Fv.shape[0]
        Nx = Fx.shape[0]
        Ny = Fy.shape[0]
        Nk = Fu.shape[1]
        Nl = Fv.shape[1]
        I1 = np.zeros((Nu, Nv))  # shape as a tuple (safer under Numba)
        for iu in range(Nu):
            for iv in range(Nv):
                for ix in range(Nx):
                    for iy in range(Ny):
                        S = 0.
                        for ik in range(Nk):
                            for il in range(Nl):
                                S += Fu[iu,ik]*Fv[iv,il]*Fx[ix,ik]*Fy[iy,il]*P[ix,iy]*B[ik,il]
                        I1[iu, iv] += S
        return I1
    
    def fun2(Fu, Fv, Fx, Fy, P, B):
        Nu = Fu.shape[0]
        Nv = Fv.shape[0]
        Nx = Fx.shape[0]
        Ny = Fy.shape[0]
        Nk = Fu.shape[1]
        Nl = Fv.shape[1]
        I2 = np.zeros((Nu, Nv))  # shape as a tuple (safer under Numba)
        for iu in range(Nu):
            for iv in range(Nv):
                for ix in range(Nx):
                    for iy in range(Ny):
                        S = 0.
                        for ik in range(Nk):
                            for il in range(Nl):
                                S += Fu[iu,ik]*Fv[iv,il]*Fx[ix,ik]*Fy[iy,il]*P[ix,iy]*B[ik,il]
                        I2[iu, iv] += S**2.
        return I2
    
    if __name__ == '__main__':
    
        Nx = 30
        Ny = 40
        Nk = 50
        Nl = 60
        Nu = 70
        Nv = 8
        Fx = np.random.rand(Nx, Nk)
        Fy = np.random.rand(Ny, Nl)
        Fu = np.random.rand(Nu, Nk)
        Fv = np.random.rand(Nv, Nl)
        P = np.random.rand(Nx, Ny)
        B = np.random.rand(Nk, Nl)
        fjit1 = jit(fun1)
        fjit2 = jit(fun2)
    
        # Plain for loop: too slow at these sizes, so commented out
        # t = time.time()
        # I1 = fun1(Fu, Fv, Fx, Fy, P, B)
        # print('fun1    :', time.time() - t)
    
        # JIT-compiled for loop: beyond a certain size it beats einsum
        # (the first call also includes compilation time)
        t = time.time()
        I1jit = fjit1(Fu, Fv, Fx, Fy, P, B)
        print('jit1    :', time.time() - t)
    
        # einsum: a great solution when no squaring is needed
        t = time.time()
        I1_ = np.einsum('uk, vl, xk, yl, xy, kl->uv', Fu, Fv, Fx, Fy, P, B)
        print('1 einsum:', time.time() - t)
    
        # Plain for loop: too slow at these sizes, so commented out
        # t = time.time()
        # I2 = fun2(Fu, Fv, Fx, Fy, P, B)
        # print('fun2    :', time.time() - t)
    
        # JIT-compiled for loop: beyond a certain size it beats einsum
        t = time.time()
        I2jit = fjit2(Fu, Fv, Fx, Fy, P, B)
        print('jit2    :', time.time() - t)
    
        # Expanded einsum: becomes very slow as the sizes increase
        # t = time.time()
        # I2_ = np.einsum('uk,vl,xk,yl,um,vn,xm,yn,kl,mn,xy->uv', Fu,Fv,Fx,Fy,Fu,Fv,Fx,Fy,B,B,P**2)
        # print('2 einsum:', time.time() - t)
    
        # Intermediate einsum: memory can become an issue as the sizes increase
        t = time.time()
        temp = np.einsum('uk, vl, xk, yl, xy, kl->uvxy', Fu, Fv, Fx, Fy, P, B)
        I2__ = np.einsum('uvxy->uv', np.square(temp))
        print('2 einsum:', time.time() - t)
    
        # print('I1 == I1_   :', np.allclose(I1, I1_))
        print('I1_ == Ijit1_:', np.allclose(I1_, I1jit))
        # print('I2 == I2_   :', np.allclose(I2, I2_))
        print('I2_ == Ijit2_:', np.allclose(I2__, I2jit))
    
    

    Comment: Please feel free to edit / improve this answer. It would be nice if someone had any suggestions with regards to making this parallel.
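
    On the parallel question, one possible direction (a sketch, not a benchmarked solution) is to split the work over blocks of u and hand the blocks to a thread pool. This relies on the assumption that numpy releases the GIL inside the heavy array operations, so whether it gives a speedup depends on the numpy/BLAS build and on the array sizes. The names sum_sq_threaded, _block, chunk and workers are made up for illustration.

```python
import numpy as np
from concurrent.futures import ThreadPoolExecutor

def _block(fu, r, Fx, P):
    """Result rows for one block of u (hypothetical helper)."""
    s = fu[:, None, :] * Fx[None, :, :]        # (c, Nx, Nk)
    I = np.tensordot(s, r, axes=(2, 2))        # (c, Nx, Nv, Ny)
    I *= P[None, :, None, :]
    return np.einsum('uxvy,uxvy->uv', I, I)    # (c, Nv)

def sum_sq_threaded(Fu, Fv, Fx, Fy, P, B, chunk=8, workers=4):
    # r[v, y, k] = sum_l Fv[v, l] * Fy[y, l] * B[k, l], shared by all blocks
    r = np.einsum('vl,yl,kl->vyk', Fv, Fy, B)
    blocks = [Fu[i:i + chunk] for i in range(0, Fu.shape[0], chunk)]
    with ThreadPoolExecutor(max_workers=workers) as pool:
        # pool.map returns results in the same order as the input blocks
        parts = list(pool.map(lambda fu: _block(fu, r, Fx, P), blocks))
    return np.concatenate(parts, axis=0)
```

    Each worker holds its own (chunk, Nx, Nv, Ny) temporary, so peak memory grows with workers * chunk rather than with Nu.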
