Efficiently count zero elements in numpy array?

北海茫月 2021-02-03 19:18

I need to count the number of zero elements in numpy arrays. I'm aware of the numpy.count_nonzero function, but there appears to be no analog for counting zero elements.

1 Answer
  • 2021-02-03 19:51

    A roughly 2x faster approach is to keep using np.count_nonzero(), but pass it the boolean condition you need (here, arr == 0):

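    In [1]: import numpy as np

    In [2]: arr = np.array([[1, 2, 0, 3],
       ...:                 [3, 9, 0, 4]])

    In [3]: arr
    Out[3]: 
    array([[1, 2, 0, 3],
           [3, 9, 0, 4]])
    
    In [4]: np.count_nonzero(arr == 0)
    Out[4]: 2
    
    In [5]: def func_cnt():
       ...:     # arrs: a list of NumPy arrays to process
       ...:     for arr in arrs:
       ...:         # arr == 0 is a boolean mask; count_nonzero counts its True
       ...:         # entries, i.e. the number of zeros in arr
       ...:         zero_els = np.count_nonzero(arr == 0)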
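A minimal sketch of two equivalent ways to count zeros, assuming a plain numeric array: the mask approach used in this answer, and the complement approach of subtracting the non-zero count from the total. For a multi-dimensional array the total is arr.size, not len(arr), which only counts rows. The helper names count_zeros_mask and count_zeros_complement are hypothetical.

```python
import numpy as np


def count_zeros_mask(arr: np.ndarray) -> int:
    # Compare against 0 to get a boolean mask, then count its True entries.
    return int(np.count_nonzero(arr == 0))


def count_zeros_complement(arr: np.ndarray) -> int:
    # Total element count minus the non-zero count; no mask is materialized.
    return int(arr.size - np.count_nonzero(arr))


arr = np.array([[1, 2, 0, 3],
                [3, 9, 0, 4]])

print(count_zeros_mask(arr))        # 2
print(count_zeros_complement(arr))  # 2
# count_nonzero also accepts axis= to count zeros per column (axis=0) or per row (axis=1).
print(np.count_nonzero(arr == 0, axis=0))  # [0 0 2 0]
```

Which of the two is faster may depend on array size, dtype and NumPy version, so it is worth timing both on your own data.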
    

    You can also use np.where(), but it is slower than np.count_nonzero(). It returns a tuple of index arrays (one per dimension), so count the entries of one of them:

    In [6]: np.where(arr == 0)
    Out[6]: (array([0, 1]), array([2, 2]))
    
    In [7]: len(np.where(arr == 0)[0])
    Out[7]: 2
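Taking len() of the whole tuple returned by np.where counts dimensions, not zeros; for the 2x4 arr above both numbers happen to be 2. A small sketch with a hypothetical array holding three zeros makes the difference visible, and also shows np.argwhere, which returns an (N, ndim) array of indices:

```python
import numpy as np

# Hypothetical array with three zeros, so the count differs from the number of dimensions.
arr = np.array([[0, 2, 0, 3],
                [3, 9, 0, 4]])

# With a single argument, np.where returns a tuple of index arrays, one per dimension.
rows, cols = np.where(arr == 0)

wrong = len(np.where(arr == 0))          # number of dimensions, not number of zeros
right = rows.size                        # one row index per zero element
also_right = len(np.argwhere(arr == 0))  # argwhere has one row per zero element

print(wrong, right, also_right)  # 2 3 3
```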
    

    Timings, fastest first (func1 and func_where are not defined in this answer):

    In [8]: %timeit func_cnt()
    10 loops, best of 3: 29.2 ms per loop
    
    In [9]: %timeit func1()
    10 loops, best of 3: 46.5 ms per loop
    
    In [10]: %timeit func_where()
    10 loops, best of 3: 61.2 ms per loop
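Since func1 and func_where are not shown, here is a hypothetical harness for reproducing a comparison like this one. It assumes func1 counts with arr.size - np.count_nonzero(arr) and func_where uses np.where; the inputs mirror the arrs built in the JAX example. It prints per-call times for your own machine rather than reproducing the numbers above.

```python
import functools
import timeit

import numpy as np

# Assumed inputs, mirroring the JAX example: 1000 arrays of 10,000 ints in [-5, 5).
rng = np.random.default_rng(0)
arrs = [rng.integers(-5, 5, 10_000) for _ in range(1000)]


def zeros_count_nonzero(arr):
    # Boolean mask, then count its True entries (the func_cnt approach).
    return int(np.count_nonzero(arr == 0))


def zeros_complement(arr):
    # Hypothetical stand-in for func1: total size minus the non-zero count.
    return int(arr.size - np.count_nonzero(arr))


def zeros_where(arr):
    # Hypothetical stand-in for func_where: np.where gives one index array
    # per dimension, so count the entries of the first one.
    return int(np.where(arr == 0)[0].size)


def run(counter):
    # Apply one counting strategy to every input array.
    return [counter(arr) for arr in arrs]


STRATEGIES = {
    "count_nonzero(arr == 0)": zeros_count_nonzero,
    "arr.size - count_nonzero(arr)": zeros_complement,
    "np.where(arr == 0)[0].size": zeros_where,
}

if __name__ == "__main__":
    for name, counter in STRATEGIES.items():
        # Best of 3 repeats of 5 calls each, reported per call like %timeit.
        best = min(timeit.repeat(functools.partial(run, counter), number=5, repeat=3))
        print(f"{name}: {best / 5 * 1e3:.1f} ms per loop")
```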
    

    More speedups with accelerators

    With the help of JAX, you can get a further speed boost if you have access to accelerators (GPU/TPU); the timing below reports more than 3 orders of magnitude. Another advantage of JAX is that NumPy code needs very little modification to become JAX-compatible. Below is a reproducible example:

    In [1]: import numpy as np
       ...: import jax.numpy as jnp
    In [2]: from jax import jit
    
    # set up inputs
    In [3]: arrs = []
    In [4]: for _ in range(1000):
       ...:     arrs.append(np.random.randint(-5, 5, 10000))
    
    # JIT'd function that performs the counting task
    In [5]: @jit
       ...: def func_cnt():
       ...:     for arr in arrs:
       ...:         zero_els = jnp.count_nonzero(arr==0)
    

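    # efficiency test
    # Caveat: func_cnt takes no arguments and returns nothing, so JAX/XLA may
    # drop the counting as dead code; this timing may mostly reflect call overhead.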
    In [8]: %timeit func_cnt()
    15.6 µs ± 391 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
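The JIT'd func_cnt above takes no arguments and discards its results, which makes its timing hard to interpret. A sketch of a variant that passes the data in, returns the per-array counts, and waits for the result, assuming all arrays have the same length so they can be stacked into one 2-D array (count_zeros_per_row is a hypothetical name):

```python
import numpy as np
import jax
import jax.numpy as jnp

# Assumed inputs: same sizes as the answer's arrs (1000 arrays of length 10,000),
# stacked into one 2-D array so a single jitted call processes all of them.
rng = np.random.default_rng(0)
stacked = rng.integers(-5, 5, size=(1000, 10_000))


@jax.jit
def count_zeros_per_row(x):
    # Count zeros along each row; returning the result keeps the computation
    # from being discarded as dead code.
    return jnp.count_nonzero(x == 0, axis=1)


x = jax.device_put(stacked)      # transfer the data to the default device once
counts = count_zeros_per_row(x)  # first call traces and compiles
counts.block_until_ready()       # JAX dispatches asynchronously; wait for the result
```

In IPython, `%timeit count_zeros_per_row(x).block_until_ready()` would then time only the compiled computation, since tracing and compilation happen on the first call.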
    