Performance issue in python with nested loop

后端 未结 2 1714
醉话见心
醉话见心 2021-01-27 09:41

I was able to improve a code written in python a lot with numpy because of the dot product. Now I still have one part of the code which is still very slow. I still don't unders

2条回答
  •  旧时难觅i
    2021-01-27 10:29

    I attempted to re-create the conditions described in the question, but first here is a smaller test case to illustrate the strategy. First, the author's original implementation:

    import numpy as np
    import numba as nb
    import numpy
    
    def func(re, ws, a, l, r):
        """Fill ``re`` in place, one entry per index triple (x1, x2, x3).

        This is the plain-Python/NumPy baseline: three nested loops, each
        running over ``range(a**l)``, with small temporary arrays created on
        every iteration.

        Args:
            re: Output array indexed as ``re[x1][x2][x3]``; every entry with
                indices in ``range(a**l)`` is overwritten.
            ws: Indexable by row (``ws[x]``); rows are combined element-wise.
            a, l: Only used through ``a**l``, the number of rows visited.
            r: Weight; ``r/(a**l-2)`` scales the product term and
                ``(1-r)/2`` scales the match count.

        Returns:
            None. The result is written into ``re``.
        """
        for x1 in range(a**l):
            for x2 in range(a**l):
                for x3 in range(a**l):
                    # f11 counts how many of rows x2, x3 are identical to row
                    # x1 (a zero sum of absolute differences means "equal").
                    f11 = 0
                    cv1 = numpy.sum(
                        numpy.absolute(numpy.subtract(ws[x1], ws[x2])))
                    cv2 = numpy.sum(
                        numpy.absolute(numpy.subtract(ws[x1], ws[x3])))
                    if cv1 == 0:
                        f11 += 1
                    if cv2 == 0:
                        f11 += 1
                    # numpy.prod replaces numpy.product, which was removed in
                    # NumPy 2.0.
                    re[x1][x2][x3] = 1.0*r/(a**l-2)*(numpy.prod(numpy.absolute(
                        numpy.subtract((2*ws[x1]+ws[x2]+ws[x3]), 2)))-f11)
                    f11 *= 1.0*(1-r)/2
                    re[x1][x2][x3] += f11
    

    Now with a simple translation to Numba, which is really well suited to these types of deeply nested looping problems when you're dealing with numpy arrays and numerical calculations:

    @nb.njit
    def func2(re, ws, a, l, r):
        """Numba-compiled version of the triple loop; fills ``re`` in place.

        For every index triple (i, j, k) in ``range(a**l)`` the entry
        ``re[i, j, k]`` is overwritten. Nothing is returned.
        """
        n = a**l
        for i in range(n):
            for j in range(n):
                for k in range(n):
                    # Number of rows among j, k that are identical to row i.
                    matches = 0.0
                    dist_ij = np.sum(np.abs(ws[i] - ws[j]))
                    dist_ik = np.sum(np.abs(ws[i] - ws[k]))

                    if dist_ij == 0:
                        matches += 1
                    if dist_ik == 0:
                        matches += 1
                    combined = np.abs(2*ws[i] + ws[j] + ws[k] - 2)
                    y = np.prod(combined) - matches
                    re[i, j, k] = 1.0*r/(n-2)*y
                    matches *= 1.0*(1-r)/2
                    re[i, j, k] += matches
    

    and then with some further optimizations to get rid of temporary array creation:

    @nb.njit
    def func3(re, ws, a, l, r):
        """Numba-compiled triple loop using scalar accumulation only.

        Instead of building temporary arrays, the columns of ``ws`` are
        walked element by element. ``re[i, j, k]`` is overwritten for every
        index triple in ``range(a**l)``; nothing is returned.
        """
        n = a**l
        width = ws.shape[1]
        for i in range(n):
            for j in range(n):
                for k in range(n):
                    dist_ij = 0.0
                    dist_ik = 0.0
                    y = 1.0
                    # One pass over the columns feeds all three accumulators;
                    # each keeps its own left-to-right order of operations.
                    for c in range(width):
                        dist_ij += np.abs(ws[i, c] - ws[j, c])
                        dist_ik += np.abs(ws[i, c] - ws[k, c])
                        y *= np.abs(2.0*ws[i, c] + ws[j, c] + ws[k, c] - 2)

                    # Number of rows among j, k that are identical to row i.
                    matches = 0.0
                    if dist_ij == 0:
                        matches += 1
                    if dist_ik == 0:
                        matches += 1
                    y -= matches
                    re[i, j, k] = 1.0*r/(n-2)*y
                    matches *= 1.0*(1-r)/2
                    re[i, j, k] += matches
    

    So some simple test data:

    # Small test case: parameters first, then the derived arrays.
    a, l, r = 2, 5, 0.2
    # Row n holds the l binary digits of n, most significant bit first.
    wp = numpy.right_shift(numpy.arange(2**l)[:, None], numpy.arange(l)[::-1]) & 1
    # Prepend each row's digit sum as an extra leading column.
    wp = numpy.concatenate([wp.sum(axis=1, keepdims=True), wp], axis=1)
    ws = wp[:, 3:l+3]
    # Zero-initialised output array, one entry per index triple.
    re = numpy.zeros((a**l,) * 3)
    

    and now let's check that all three functions produce the same result:

    # Run each implementation on its own freshly zeroed output array.
    re = numpy.zeros((a**l, a**l, a**l))
    func(re, ws, a, l, r)

    re2 = numpy.zeros((a**l, a**l, a**l))
    func2(re2, ws, a, l, r)

    re3 = numpy.zeros((a**l, a**l, a**l))
    func3(re3, ws, a, l, r)

    # Single-argument print(...) works on both Python 2 and Python 3;
    # the bare print statement is a SyntaxError on Python 3.
    print(np.allclose(re, re2))  # True
    print(np.allclose(re, re3))  # True
    

    And some initial timings using the jupyter notebook %timeit magic:

    # IPython/Jupyter %timeit magic (not plain Python): time each
    # implementation with the same ws, a, l, r arguments.
    %timeit func(re, ws, a, l, r)
    %timeit func2(re2, ws, a, l, r)
    %timeit func3(re3, ws, a, l, r)
    
    1 loop, best of 3: 404 ms per loop
    100 loops, best of 3: 14.2 ms per loop
    1000 loops, best of 3: 605 µs per loop
    

    func2 is ~28x faster than the original implementation. func3 is ~680x faster. Note that I'm running on a Macbook laptop with an i7 processor, 16 GB of RAM and using Numba 0.25.0.

    Ok, so now let's time the a=2 l=10 case that everyone is wringing their hands about:

    a = 2
    l = 10
    r = 0.2
    wp = (numpy.arange(2**l)[:,None] >> numpy.arange(l)[::-1]) & 1
    wp = numpy.hstack([wp.sum(1,keepdims=True), wp])
    ws = wp[:, 3:l+3]
    re = numpy.zeros((a**l, a**l, a**l))
    print 'setup complete'
    
    %timeit -n 1 -r 1 func3(re, ws, a, l, r)
    
    # setup complete
    # 1 loop, best of 1: 45.4 s per loop
    

    So this took 45 seconds single-threaded on my machine, which seems reasonable as long as you don't need to repeat this calculation too many times.

提交回复
热议问题