Why doesn't Numba improve this recursive function?

有刺的猬 2021-01-26 06:28

I have an array of true/false values with a very simple structure:

# the real array has hundreds of thousands of items
pos         


        
  •  隐瞒了意图╮
    2021-01-26 07:10

    You can find the positions where the value changes with np.diff; there is no need for a more complicated algorithm or for numba:

    import numpy as np
    
    positions = np.array([True, False, False, False, True, True, True, True, False, False, False], dtype=bool)
    dpos = np.diff(positions)
    # array([ True, False, False,  True, False, False, False,  True, False, False])
    

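    This works because, for boolean arrays, np.diff compares neighbouring elements with not_equal instead of subtracting them, so the result is True exactly where the value changes. (Subtraction would give the same answer, since False - True == -1 and bool(-1) == True, but NumPy does not support the - operator on boolean arrays.)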
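np.diff returns a boolean mask, not the positions themselves. A minimal sketch of turning that mask into indices with np.flatnonzero, assuming a "position" means the index of the first element of each new run (the names mask and change_idx are only illustrative):

```python
import numpy as np

positions = np.array([True, False, False, False, True, True, True, True, False, False, False])

# For bool input np.diff keeps the bool dtype: True where neighbours differ.
mask = np.diff(positions)
# The same result written out explicitly as an element-wise inequality.
assert np.array_equal(mask, positions[1:] != positions[:-1])

# mask[i] compares positions[i + 1] with positions[i], so add 1 to get the
# index of the first element of each new run (assumed meaning of "position").
change_idx = np.flatnonzero(mask) + 1
print(change_idx)  # [1 4 8]
```
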

    It performs quite well even on my several-years-old, battery-powered laptop (throttled by energy-saving mode):

    In [52]: positions = np.random.randint(0, 2, size=300_000, dtype=bool)          
    
    In [53]: %timeit np.diff(positions)                                             
    633 µs ± 4.09 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
    

    I'd expect a hand-written diff in numba to perform about the same.

    EDIT: That last statement is wrong. I implemented a simple diff function with numba, and it is more than 10 times faster than the NumPy one (46 µs vs. 633 µs). It obviously has far fewer features, but it is sufficient for this task:

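    import numba
    import numpy as np

    @numba.njit
    def ndiff(x):
        s = x.size - 1
        r = np.empty(s, dtype=x.dtype)  # output keeps the input dtype (bool here)
        for i in range(s):
            # For bool input the difference is non-zero wherever x[i+1] != x[i],
            # and storing it in the bool array turns that into True
            r[i] = x[i+1] - x[i]
        return r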
    
    In [68]: np.all(ndiff(positions) == np.diff(positions))                            
    Out[68]: True
    
    In [69]: %timeit ndiff(positions)                                               
    46 µs ± 138 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
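If the goal is the change positions rather than the mask, the same loop idea can return the indices directly. This is a sketch under assumptions: change_indices is a hypothetical helper, not part of NumPy or Numba, and the import fallback only exists so the example still runs (uncompiled, and therefore slowly) where Numba is not installed. It has not been benchmarked against ndiff.

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # fallback so the sketch runs without numba (no compilation)
    def njit(func):
        return func


@njit
def change_indices(x):
    """Return the indices i where x[i] != x[i - 1] (hypothetical helper)."""
    n = x.size
    # First pass: count the changes so the output can be allocated exactly.
    count = 0
    for i in range(1, n):
        if x[i] != x[i - 1]:
            count += 1
    out = np.empty(count, dtype=np.int64)
    # Second pass: record the index of each element that differs from its predecessor.
    k = 0
    for i in range(1, n):
        if x[i] != x[i - 1]:
            out[k] = i
            k += 1
    return out


positions = np.array([True, False, False, False, True, True, True, True, False, False, False])
print(change_indices(positions))  # [1 4 8]
```

Counting first lets the output be allocated at its exact size without growing a list inside the compiled function; the trade-off is scanning the array twice.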
    
