Numba code slower than pure python

栀梦 2021-02-01 06:54

I've been working on speeding up a resampling calculation for a particle filter. As Python has many ways to speed it up, I thought I'd try them all. Unfortunately, the numba ve…

2 Answers
  • 2021-02-01 07:31

    The problem is that numba can't infer the type of lookup. If you put print(nb.typeof(lookup)) in your method, you'll see that numba is treating it as a generic Python object, which is slow. Normally I would just declare the type of lookup in a locals dict, but I was getting a strange error. Instead, I created a small wrapper so that I could explicitly define its input and output types.

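    import numpy as np
    import numba as nb

    # Typed wrapper: explicit float64 1-D array in, float64 1-D array out,
    # so numba knows exactly what type the cumulative sum returns.
    @nb.jit(nb.f8[:](nb.f8[:]))
    def numba_cumsum(x):
        return np.cumsum(x)

    # Compiled lazily on first call. The original used @nb.autojit here;
    # autojit was removed in later Numba releases and plain @nb.jit replaces it.
    @nb.jit
    def numba_resample2(qs, xs, rands):
        n = qs.shape[0]
        # lookup = np.cumsum(qs)   # numba could not infer this type
        lookup = numba_cumsum(qs)  # typed result: float64 1-D array
        results = np.empty(n)

        # For each draw, pick the first particle whose cumulative weight exceeds it
        for j in range(n):
            for i in range(n):
                if rands[j] < lookup[i]:
                    results[j] = xs[i]
                    break
        return results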
    

    Then my timings are:

    print("Timing Numba Function:")
    %timeit numba_resample(qs, xs, rands)

    print("Timing Revised Numba Function:")
    %timeit numba_resample2(qs, xs, rands)
    

    Timing Numba Function:
    100 loops, best of 3: 8.1 ms per loop
    Timing Revised Numba Function:
    100000 loops, best of 3: 15.3 µs per loop
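    %timeit is an IPython magic, so the timing snippet above only runs in IPython or Jupyter. In a plain Python script, the standard-library timeit module gives a comparable best-of-N measurement. A minimal sketch; best_per_call and square_all are hypothetical helpers, and numba_resample2 with (qs, xs, rands) could be passed in the same way:

```python
import timeit


def best_per_call(func, *args, number=100, repeat=3):
    """Hypothetical helper: best-of-`repeat` time per call, in seconds.

    Mirrors the old "100 loops, best of 3" style of %timeit: run `number`
    calls `repeat` times, keep the fastest run, and divide by `number`.
    """
    runs = timeit.repeat(lambda: func(*args), number=number, repeat=repeat)
    return min(runs) / number


def square_all(values):
    # Hypothetical stand-in workload; substitute the function you want to time.
    return [v * v for v in values]


per_call = best_per_call(square_all, list(range(1000)))
print(f"{per_call * 1e6:.1f} µs per call (best of 3)")
```

    For a lazily compiled Numba function, call it once before timing so that compilation is not counted in the first run.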
    

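    You can go a little faster still by compiling eagerly with jit and an explicit signature instead of autojit:

    # Three float64 1-D arrays in (qs, xs, rands), one float64 1-D array out;
    # use this decorator on numba_resample2 in place of the signature-less one.
    @nb.jit(nb.f8[:](nb.f8[:], nb.f8[:], nb.f8[:]))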
    

    For me that lowers it from 15.3 microseconds to 12.5 microseconds, but it's still impressive how well autojit does.
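    The double loop in numba_resample2 is an inverse-CDF lookup: for each draw it scans lookup for the first cumulative weight strictly greater than the draw, which is O(n) per draw and O(n²) overall. Because the cumulative sum of non-negative weights is sorted, a binary search finds the same index in O(log n). A minimal pure-Python sketch with the standard-library bisect module; resample_bisect is a hypothetical name, and it assumes non-negative weights and draws below the total weight:

```python
from bisect import bisect_right
from itertools import accumulate


def resample_bisect(qs, xs, rands):
    """Hypothetical helper: same selection rule as numba_resample2's double loop.

    Assumes qs holds non-negative weights (so the running sum is sorted)
    and every draw in rands is below the total weight; otherwise the
    index would run past the end of xs.
    """
    lookup = list(accumulate(qs))  # cumulative weights, like np.cumsum(qs)
    results = []
    for r in rands:
        # bisect_right returns the first index i with lookup[i] > r,
        # i.e. the same i where the inner loop's rands[j] < lookup[i] first holds.
        i = bisect_right(lookup, r)
        results.append(xs[i])
    return results


qs = [0.25, 0.25, 0.5]
xs = [1.0, 2.0, 3.0]
rands = [0.0, 0.25, 0.3, 0.75]
print(resample_bisect(qs, xs, rands))  # [1.0, 2.0, 2.0, 3.0]
```

    Under these assumptions it picks exactly the particles the double loop would, including never selecting a zero-weight particle, just without the linear inner scan.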

  • 2021-02-01 07:33

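    A faster NumPy version (10x speedup compared to numpy_resample):

    import numpy as np

    def numpy_faster(qs, xs, rands):
        lookup = np.cumsum(qs)                  # cumulative weights, shape (n,)
        mm = lookup[None, :] > rands[:, None]   # (len(rands), n) boolean matrix via broadcasting
        I = np.argmax(mm, 1)                    # index of the first True in each row
        return xs[I]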
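    numpy_faster builds a len(rands)-by-n boolean matrix, so its time and memory grow quadratically with the number of particles. Since lookup is sorted, np.searchsorted with side="right" returns, for all draws at once, the first index whose cumulative weight is strictly greater than the draw, which is the index np.argmax(mm, 1) finds, using O(n log n) time and O(n) memory. This is a swapped-in technique rather than anything from the answer; numpy_searchsorted is a hypothetical name, it assumes non-negative weights and every draw below lookup[-1], and no timings are claimed for it:

```python
import numpy as np


def numpy_searchsorted(qs, xs, rands):
    """Hypothetical variant of numpy_faster: binary search instead of an n-by-n mask.

    Assumes non-negative weights qs and that every draw in rands is below
    lookup[-1]; otherwise searchsorted returns len(lookup), an out-of-range index.
    """
    lookup = np.cumsum(qs)  # non-decreasing cumulative weights
    # side="right" gives, for each draw r, the first index i with lookup[i] > r,
    # the same index np.argmax(lookup[None, :] > rands[:, None], 1) picks.
    idx = np.searchsorted(lookup, rands, side="right")
    return xs[idx]


qs = np.array([0.25, 0.25, 0.5])
xs = np.array([1.0, 2.0, 3.0])
rands = np.array([0.0, 0.25, 0.3, 0.75])
print(numpy_searchsorted(qs, xs, rands))  # [1. 2. 2. 3.]
```

    One behavioral difference outside the stated assumptions: if a draw is not below lookup[-1], argmax silently returns index 0, while searchsorted produces an out-of-range index and the final indexing raises an IndexError.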
    