I need to count the number of zero elements in numpy
arrays. I'm aware of the numpy.count_nonzero function, but there appears to be no analog for counting zero elements.
A faster approach (roughly 2x compared with np.where() in the timings below) is to keep using np.count_nonzero(), but pass it the condition you need as a boolean mask.
In [3]: arr
Out[3]:
array([[1, 2, 0, 3],
       [3, 9, 0, 4]])
In [4]: np.count_nonzero(arr==0)
Out[4]: 2
In [5]: def func_cnt():
   ...:     for arr in arrs:
   ...:         # arr==0 is a boolean mask, so this counts the zeros in arr
   ...:         zero_els = np.count_nonzero(arr==0)
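The same idea as a self-contained script may help. This is a minimal sketch using the array from the session above. The variable names are illustrative, and the `arr.size` variant and the `axis` argument are additions that do not appear in the original answer.

```python
import numpy as np

arr = np.array([[1, 2, 0, 3],
                [3, 9, 0, 4]])

# Count zeros via a boolean mask: True wherever the element equals 0.
zeros_mask = np.count_nonzero(arr == 0)

# Equivalent count: total number of elements minus the non-zero ones.
zeros_size = arr.size - np.count_nonzero(arr)

# Per-column zero counts, using the axis argument of count_nonzero.
zeros_per_col = np.count_nonzero(arr == 0, axis=0)

print(zeros_mask, zeros_size, zeros_per_col)
```

The mask form generalizes to any condition (for example `arr < 0`), while the `arr.size` form only works for counting zeros.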
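You can also use np.where(), but it's slower than np.count_nonzero(). Note that np.where() returns a tuple with one index array per dimension, so the count is the length of one of those arrays, not of the tuple:
In [6]: np.where( arr == 0)
Out[6]: (array([0, 1]), array([2, 2]))
In [7]: len(np.where( arr == 0)[0])
Out[7]: 2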
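A caution when counting with np.where(): called with only a condition, it returns a tuple holding one index array per dimension, so the length of the tuple is the number of dimensions rather than the number of matches. A minimal sketch, using a made-up array with three zeros so that the two numbers differ:

```python
import numpy as np

# Hypothetical 2-D array with three zeros (not the array from the session above).
arr = np.array([[1, 0, 0, 3],
                [3, 9, 0, 4]])

idx = np.where(arr == 0)   # tuple of (row_indices, col_indices)
n_dims = len(idx)          # length of the tuple: number of dimensions
n_zeros = len(idx[0])      # length of one index array: number of matches

print(n_dims, n_zeros)
```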
Efficiency (fastest first):
In [8]: %timeit func_cnt()
10 loops, best of 3: 29.2 ms per loop
In [9]: %timeit func1()
10 loops, best of 3: 46.5 ms per loop
In [10]: %timeit func_where()
10 loops, best of 3: 61.2 ms per loop
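To reproduce a comparison like this outside IPython, something along these lines should work. It is only a sketch: the definitions of func1 and func_where are not shown in the answer, so func1 is left out and the func_where below is an assumed reconstruction. The inputs are assumed to match the arrs setup used in the JAX section.

```python
import timeit
import numpy as np

# Assumed inputs: 1000 one-dimensional arrays of 10000 small integers each.
arrs = [np.random.randint(-5, 5, 10000) for _ in range(1000)]

def func_cnt():
    # Count zeros with a boolean mask and count_nonzero.
    return [np.count_nonzero(arr == 0) for arr in arrs]

def func_where():
    # Count zeros via the single index array np.where returns for 1-D input.
    return [len(np.where(arr == 0)[0]) for arr in arrs]

# Average time per call; absolute numbers depend on hardware and NumPy version.
t_cnt = timeit.timeit(func_cnt, number=10) / 10
t_where = timeit.timeit(func_where, number=10) / 10
print(f"count_nonzero: {t_cnt * 1e3:.1f} ms per call")
print(f"where:         {t_where * 1e3:.1f} ms per call")
```

Returning the counts from each function makes it easy to check that both variants agree before comparing their speed.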
More speedups with accelerators
With JAX, the timing below shows a speedup of more than three orders of magnitude, provided you have access to an accelerator (GPU/TPU). Another advantage of JAX is that NumPy code needs very little modification to become JAX-compatible. Below is a reproducible example:
In [1]: import jax.numpy as jnp
In [2]: from jax import jit
# set up inputs
In [3]: arrs = []
In [4]: for _ in range(1000):
   ...:     arrs.append(np.random.randint(-5, 5, 10000))
# JIT'd function that performs the counting task
In [5]: @jit
   ...: def func_cnt():
   ...:     for arr in arrs:
   ...:         zero_els = jnp.count_nonzero(arr==0)
# efficiency test
In [8]: %timeit func_cnt()
15.6 µs ± 391 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
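One caveat when reading this number: func_cnt above returns nothing and closes over arrs, so the compiled function may be doing little more than dispatch, and JAX also dispatches work asynchronously. A hedged sketch of a variant that returns its result and waits for it is shown below. The function name count_zeros_per_row and the stacked batch input are assumptions for illustration, not part of the original answer.

```python
import numpy as np
import jax.numpy as jnp
from jax import jit

# Stack the inputs into one (1000, 10000) array so a single call handles all of them.
batch = np.random.randint(-5, 5, size=(1000, 10000))

@jit
def count_zeros_per_row(x):
    # Returning the counts keeps the computation from being discarded as unused.
    return jnp.count_nonzero(x == 0, axis=1)

counts = count_zeros_per_row(batch)
counts.block_until_ready()  # wait for the asynchronous computation to finish

# Sanity check against plain NumPy.
expected = np.count_nonzero(batch == 0, axis=1)
print(bool(np.array_equal(np.asarray(counts), expected)))
```

When timing this variant, call `count_zeros_per_row(batch).block_until_ready()` inside `%timeit`, after one warm-up call so that compilation time is excluded.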