Numpy: How to find first non-zero value in every column of a numpy array?

粉色の甜心 2020-11-27 06:06

Suppose I have a numpy array of the form:

arr = numpy.array([[1, 1, 0], [1, 1, 0], [0, 0, 1], [0, 0, 0]])

I want to find the index of the first non-zero value in each column.

1 answer
  • 2020-11-27 06:34

    Indices of first occurrences

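    Build a boolean mask of the non-zero elements and call argmax along the axis of interest (axis 0 for columns here). Since argmax returns the position of the first maximum, on a boolean mask it gives the index of the first True value, i.e. the first match -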

    (arr!=0).argmax(axis=0)
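
As a quick check, here is the one-liner applied to the array from the question. This is only a minimal sketch; the names mask and first_idx are introduced here for illustration and are not part of the original answer.

```python
import numpy as np

# Array from the question
arr = np.array([[1, 1, 0],
                [1, 1, 0],
                [0, 0, 1],
                [0, 0, 0]])

# Boolean mask: True where the element is non-zero
mask = arr != 0

# argmax along axis 0 returns, per column, the row index of the first True
first_idx = mask.argmax(axis=0)
print(first_idx)  # [0 0 2]
```

Every column of this particular array contains at least one non-zero value, so the result is unambiguous.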
    

    To support an arbitrary axis, and to handle slices along that axis that contain no non-zero element at all, we can wrap this in a function that returns a caller-chosen invalid_val for such slices -

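    import numpy as np

    def first_nonzero(arr, axis, invalid_val=-1):
        mask = arr != 0  # True where arr is non-zero
        # First True along axis; invalid_val where the slice has no True at all
        return np.where(mask.any(axis=axis), mask.argmax(axis=axis), invalid_val)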
    

    Note that argmax() on an all-False slice returns 0. So if the invalid_val you need is 0, mask.argmax(axis=axis) already gives the final output directly; keep in mind that a 0 then cannot be told apart from a genuine match at index 0.
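
The ambiguity behind that note can be seen in a small sketch. It uses the sample array from the runs in this answer; the names raw and has_match are illustrative only.

```python
import numpy as np

arr = np.array([[1, 0, 0],
                [1, 1, 0],
                [0, 1, 0],
                [0, 0, 0]])

mask = arr != 0

# Plain argmax: column 0 (match at row 0) and column 2 (no match) both give 0
raw = mask.argmax(axis=0)
print(raw)        # [0 1 0]

# any() tells the two cases apart
has_match = mask.any(axis=0)
print(has_match)  # [ True  True False]
```

This is why the function combines mask.any() with np.where whenever invalid_val is something other than 0.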

    Sample runs -

    In [296]: arr    # Different from given sample for variety
    Out[296]: 
    array([[1, 0, 0],
           [1, 1, 0],
           [0, 1, 0],
           [0, 0, 0]])
    
    In [297]: first_nonzero(arr, axis=0, invalid_val=-1)
    Out[297]: array([ 0,  1, -1])
    
    In [298]: first_nonzero(arr, axis=1, invalid_val=-1)
    Out[298]: array([ 0,  0,  1, -1])
    

    Extending to cover all comparison operations

    To find the first zeros instead, build the mask inside the function as arr==0. To find the first elements equal to a certain value val, use arr == val, and so on for any other comparison.
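
One way to avoid editing the function for each comparison is to pass the mask in directly. The helper below is a hypothetical variant (the name first_match is an assumption, not from the original answer) that follows the same argmax idea.

```python
import numpy as np

def first_match(mask, axis, invalid_val=-1):
    """Index of the first True along axis, or invalid_val if there is none.

    Hypothetical generalisation of first_nonzero: the caller supplies the mask.
    """
    mask = np.asarray(mask, dtype=bool)
    return np.where(mask.any(axis=axis), mask.argmax(axis=axis), invalid_val)

arr = np.array([[1, 0, 0],
                [1, 1, 0],
                [0, 1, 0],
                [0, 0, 0]])

first_zero = first_match(arr == 0, axis=0)  # first zero in each column
first_pos = first_match(arr > 0, axis=1)    # first positive value in each row
first_two = first_match(arr == 2, axis=0)   # value 2 never occurs

print(first_zero)  # [2 0 0]
print(first_pos)   # [ 0  0  1 -1]
print(first_two)   # [-1 -1 -1]
```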


    Indices of last occurrences

    To find the last elements matching a certain comparison criterion, flip the mask along that axis and apply the same argmax idea. The first match in the flipped mask is the last match in the original, so we convert the index back by subtracting it from the axis length minus one, as shown below -

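    def last_nonzero(arr, axis, invalid_val=-1):
        mask = arr != 0  # True where arr is non-zero
        # First True in the flipped mask, mapped back to an index in the original
        val = arr.shape[axis] - np.flip(mask, axis=axis).argmax(axis=axis) - 1
        # invalid_val where the slice along axis has no non-zero element
        return np.where(mask.any(axis=axis), val, invalid_val)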
    

    Sample runs -

    In [320]: arr
    Out[320]: 
    array([[1, 0, 0],
           [1, 1, 0],
           [0, 1, 0],
           [0, 0, 0]])
    
    In [321]: last_nonzero(arr, axis=0, invalid_val=-1)
    Out[321]: array([ 1,  2, -1])
    
    In [322]: last_nonzero(arr, axis=1, invalid_val=-1)
    Out[322]: array([ 0,  1,  1, -1])
    

    Again, any other comparison is covered by using the corresponding comparator to build the mask and then using that mask inside the function listed above.
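
As with the first-occurrence case, the mask can be passed in rather than built inside the function. The sketch below assumes a hypothetical helper named last_match; it is not from the original answer and only reuses the flip-and-offset idea.

```python
import numpy as np

def last_match(mask, axis, invalid_val=-1):
    """Index of the last True along axis, or invalid_val if there is none.

    Hypothetical generalisation of last_nonzero: the caller supplies the mask.
    """
    mask = np.asarray(mask, dtype=bool)
    # Position of the first True when the axis is traversed backwards
    flipped_idx = np.flip(mask, axis=axis).argmax(axis=axis)
    # Map back to an index in the original orientation
    idx = mask.shape[axis] - 1 - flipped_idx
    return np.where(mask.any(axis=axis), idx, invalid_val)

arr = np.array([[1, 0, 0],
                [1, 1, 0],
                [0, 1, 0],
                [0, 0, 0]])

last_zero = last_match(arr == 0, axis=0)  # last zero in each column
last_one = last_match(arr == 1, axis=1)   # last 1 in each row

print(last_zero)  # [3 3 3]
print(last_one)   # [ 0  1  1 -1]
```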
