Why Numba doesn't improve this recursive function

有刺的猬 2021-01-26 06:28

I have an array of true/false values with a very simple structure:

# the real array has hundreds of thousands of items
pos         


        
3 Answers
  • 2021-01-26 07:08

    The main issue is that you are not making an apples-to-apples comparison. What you provide are not iterative and recursive versions of the same algorithm; you are proposing two fundamentally different algorithms, one of which happens to be recursive and the other iterative.

    In particular, you are using NumPy built-ins much more heavily in the recursive approach, so it is no wonder that there is such a staggering difference between the two. It should also come as no surprise that Numba's JIT compilation is more effective when you avoid NumPy built-ins. Finally, the recursive algorithm seems to be less efficient because of the hidden nested looping inside the np.all() and np.any() calls, which the iterative approach avoids. So even if you were to write all your code in pure Python so that Numba could accelerate it more effectively, the recursive approach would still be slower.
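    The question's own code is not reproduced on this page, so the following is only a hypothetical sketch of the two shapes being compared (the names `changes_recursive` and `changes_iterative` are made up). The recursive version bisects the array and uses np.all()/np.any() to skip uniform segments; each of those calls can scan its whole segment, which is the hidden inner loop mentioned above. The iterative version does one comparison per neighbouring pair. Both return the indices i where positions[i] != positions[i + 1].

```python
import numpy as np


def changes_recursive(arr, offset=0):
    """Hypothetical bisection: indices i with arr[i] != arr[i + 1]."""
    n = arr.size
    if n < 2:
        return []
    # Each of these calls can scan the whole segment (hidden inner loop).
    if np.all(arr) or not np.any(arr):
        return []  # uniform segment: no change inside
    if n == 2:
        return [offset]  # two elements that differ
    mid = n // 2
    # The halves overlap at index mid so the pair (mid - 1, mid) is not lost.
    left = changes_recursive(arr[:mid + 1], offset)
    right = changes_recursive(arr[mid:], offset + mid)
    return left + right


def changes_iterative(arr):
    """Single linear pass: one comparison per neighbouring pair."""
    out = []
    for i in range(arr.size - 1):
        if arr[i] != arr[i + 1]:
            out.append(i)
    return out


positions = np.array([True, False, False, False, True, True, True,
                      True, False, False, False])
print(changes_recursive(positions))  # [0, 3, 7]
print(changes_iterative(positions))  # [0, 3, 7]
```

    Both give the same answer, but the recursive sketch repeats full-segment scans at every level of the recursion on top of its function calls, while the loop touches each pair exactly once.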

    In general, iterative approaches are faster than their recursive equivalents because they avoid the function call overhead (although that overhead is much smaller for JIT-accelerated functions than for pure Python ones). So I would advise against rewriting the algorithm in recursive form, only to discover that it is slower.
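    A minimal pure-Python sketch of the call-overhead point, using a made-up example (summing a list) rather than the question's algorithm: both functions compute the same result, but the recursive one makes one extra Python function call per element. The timings are machine-dependent, so no particular ratio is claimed here.

```python
import timeit


def sum_recursive(values, i=0):
    """Sum values[i:] with one Python function call per element."""
    if i == len(values):
        return 0
    return values[i] + sum_recursive(values, i + 1)


def sum_iterative(values):
    """Sum values with a plain loop: no per-element function calls."""
    total = 0
    for v in values:
        total += v
    return total


# Kept small so the recursion stays well under the default recursion limit.
data = list(range(500))
print(sum_recursive(data), sum_iterative(data))  # 124750 124750

t_rec = timeit.timeit(lambda: sum_recursive(data), number=1000)
t_it = timeit.timeit(lambda: sum_iterative(data), number=1000)
print(f"recursive: {t_rec:.3f} s, iterative: {t_it:.3f} s")
```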


    EDIT

    On the premise that a simple np.diff() would do the trick, Numba can still be quite beneficial:

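    import numpy as np
    import numba as nb
    
    
    @nb.njit  # nopython mode: the loop is compiled to machine code
    def diff(arr):
        # Boolean-only diff: XOR is True where neighbouring values differ,
        # which is what np.diff returns for boolean arrays
        n = arr.size
        result = np.empty(n - 1, dtype=arr.dtype)
        for i in range(n - 1):
            result[i] = arr[i + 1] ^ arr[i]
        return result
    
    
    positions = np.random.randint(0, 2, size=300_000, dtype=bool)
    # exact comparison is the right check for boolean arrays
    print(np.array_equal(np.diff(positions), diff(positions)))
    # True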
    
    
    %timeit np.diff(positions)
    # 1000 loops, best of 3: 603 µs per loop
    %timeit diff(positions)
    # 10000 loops, best of 3: 43.3 µs per loop
    

    with the Numba approach being nearly 14x faster in this test (your mileage may vary, of course).

  • 2021-01-26 07:10

    You can find the positions of value changes by using np.diff; there is no need to run a more complicated algorithm or to use Numba:

    import numpy as np
    
    positions = np.array([True, False, False, False, True, True, True, True, False, False, False], dtype=bool)
    dpos = np.diff(positions)
    # array([ True, False, False,  True, False, False, False,  True, False, False])
    

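    This works because, for boolean arrays, np.diff compares neighbouring elements with np.not_equal instead of subtracting them (NumPy does not support `-` between two boolean arrays), so the result is True exactly where the value changes.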
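    A short sketch of that point: for boolean input, np.diff gives the same result as np.not_equal (equivalently, XOR) on neighbouring elements, and np.flatnonzero turns the mask into the indices i where positions[i] != positions[i + 1]. Whether you report i or i + 1 as "the position of the change" is a convention you would pick for your own task.

```python
import numpy as np

positions = np.array([True, False, False, False, True, True, True,
                      True, False, False, False])

dpos = np.diff(positions)  # boolean result: True where neighbours differ
same_as_xor = np.array_equal(dpos, positions[1:] ^ positions[:-1])
same_as_ne = np.array_equal(dpos, np.not_equal(positions[1:], positions[:-1]))
print(dpos.dtype, same_as_xor, same_as_ne)  # bool True True

# Indices i such that positions[i] != positions[i + 1]
change_idx = np.flatnonzero(dpos)
print(change_idx)  # [0 3 7]
```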

    It performs quite well even on my several-years-old laptop running on battery (and therefore throttled by power-saving mode):

    In [52]: positions = np.random.randint(0, 2, size=300_000, dtype=bool)          
    
    In [53]: %timeit np.diff(positions)                                             
    633 µs ± 4.09 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
    

    I'd imagine that writing your own diff in Numba would yield similar performance.

    EDIT: The last statement is false. I implemented a simple diff function using Numba, and it is more than 10 times faster than the NumPy one (it obviously has far fewer features, but it should be sufficient for this task):

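    import numba
    import numpy as np
    
    
    @numba.njit
    def ndiff(x):
        # Minimal 1-D diff: r[i] = x[i + 1] - x[i]
        s = x.size - 1
        r = np.empty(s, dtype=x.dtype)  # keep the input dtype, as np.diff does for bools
        for i in range(s):
            # for boolean input, a nonzero difference is stored as True
            r[i] = x[i+1] - x[i]
        return r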
    
    In [68]: np.all(ndiff(positions) == np.diff(positions))                            
    Out[68]: True
    
    In [69]: %timeit ndiff(positions)                                               
    46 µs ± 138 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
    
  • 2021-01-26 07:23

    The gist is that only the part of the logic that uses Python machinery can be accelerated, by replacing it with equivalent compiled code that strips away most of the complexity (and flexibility) of the Python runtime. (This is essentially what Numba does: it compiles Python functions to machine code via LLVM.)

    All the heavy lifting in NumPy operations is already implemented in C and is very simple (since NumPy arrays are contiguous chunks of memory holding regular C types), so Numba can only strip away the parts that interface with the Python machinery.
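    To make "contiguous chunks of memory holding regular C types" concrete: a small boolean array occupies one byte per element in a single buffer, which NumPy's compiled loops can walk directly. This sketch only inspects standard ndarray attributes.

```python
import numpy as np

positions = np.array([True, False, True, True])

print(positions.dtype, positions.itemsize)  # bool 1  (one byte per element)
print(positions.flags["C_CONTIGUOUS"])      # True   (one contiguous buffer)
print(positions.nbytes)                     # 4
print(positions.tobytes())                  # b'\x01\x00\x01\x01'
```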

    Your "binary search" algorithm does much more work and makes much heavier use of NumPy's vector operations while at it, so less of it can be accelerated this way.
