Numpy: change max in each row to 1, all other numbers to 0

隐瞒了意图╮ 2020-12-15 03:32

I'm trying to implement a NumPy function that replaces the max in each row of a 2D array with 1, and all other numbers with zero:

>>> a = np.array(         


        
3 Answers
  • 2020-12-15 04:14

    I know this question is pretty old, but I think I have a decent solution that's a bit different from the others.

    Comparing a against per-row maxima of shape (n,) does not broadcast row by row and can raise an error when the shapes are incompatible, so here's a tweaked version that reshapes the maxima so the comparison broadcasts correctly.

    # get max by row and convert from (n,) -> (n, 1) so it broadcasts against a
    row_maxes = a.max(axis=1).reshape(-1, 1)
    np.where(a == row_maxes, 1, 0)
    # equivalent: cast the boolean mask directly
    (a == row_maxes).astype(int)
    

    If the update needs to be in place, you can do:

    a[:] = np.where(a == row_maxes, 1, 0)
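
A minimal, self-contained sketch of the approach above. The sample array is borrowed from another answer in this thread, and `keepdims=True` is swapped in as an alternative to `reshape(-1, 1)`; the variable name `one_hot` is only illustrative.

```python
import numpy as np

# Sample data (taken from another answer in this thread)
a = np.array([[0, 1], [2, 3], [4, 5], [6, 7], [9, 8]])

# keepdims=True keeps the reduced axis, so the result already has shape (n, 1)
row_maxes = a.max(axis=1, keepdims=True)

# 1 where an element equals its row's maximum, 0 elsewhere
one_hot = np.where(a == row_maxes, 1, 0)

# In-place variant: overwrite the contents of a instead of rebinding the name
a[:] = one_hot
```

Because `row_maxes` is computed before the assignment, overwriting `a` afterwards does not affect the comparison.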
    
  • 2020-12-15 04:18

    Method #1, tweaking yours:

    >>> a = np.array([[0, 1], [2, 3], [4, 5], [6, 7], [9, 8]])
    >>> b = np.zeros_like(a)
    >>> b[np.arange(len(a)), a.argmax(1)] = 1
    >>> b
    array([[0, 1],
           [0, 1],
           [0, 1],
           [0, 1],
           [1, 0]])
    

    [Actually, range will work just fine; I wrote arange out of habit.]

    Method #2, using max instead of argmax to handle the case where multiple elements reach the maximum value:

    >>> a = np.array([[0, 1], [2, 2], [4, 3]])
    >>> (a == a.max(axis=1)[:,None]).astype(int)
    array([[0, 1],
           [1, 1],
           [1, 0]])
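
To make the difference between the two methods concrete, here is a small sketch that runs both on the same input with a tie. The function names `first_max_only` and `all_maxes` are hypothetical helpers introduced for illustration, not part of NumPy.

```python
import numpy as np

def first_max_only(a):
    """Method #1: argmax marks only the first maximum in each row."""
    b = np.zeros_like(a)
    b[np.arange(len(a)), a.argmax(axis=1)] = 1
    return b

def all_maxes(a):
    """Method #2: comparing against the row max marks every tied maximum."""
    return (a == a.max(axis=1)[:, None]).astype(int)

a = np.array([[0, 1], [2, 2], [4, 3]])
first = first_max_only(a)   # row [2, 2] gets a single 1
every = all_maxes(a)        # row [2, 2] gets two 1s
```

Which one is appropriate depends on whether a row is allowed to contain more than one 1.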
    
  • 2020-12-15 04:20

    I prefer using numpy.where like so:

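    # compare against per-row maxima (not the global np.max(a)) and write 1/0 back in place
    a[:] = np.where(a == a.max(axis=1, keepdims=True), 1, 0)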
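
Note that `np.max(a)` without an `axis` argument is the maximum of the whole array, not of each row. The sketch below contrasts the two on sample data borrowed from another answer; the names `global_hit` and `row_wise` are illustrative only.

```python
import numpy as np

a = np.array([[0, 1], [2, 3], [4, 5], [6, 7], [9, 8]])

# Global maximum: a single scalar (9), so only that one element is set to 1
# and every other value is left unchanged
global_hit = a.copy()
global_hit[np.where(a == np.max(a))] = 1

# Per-row maxima kept as a column of shape (n, 1), so each row is compared
# against its own maximum and all other entries become 0
row_wise = a.copy()
row_wise[:] = np.where(a == a.max(axis=1, keepdims=True), 1, 0)
```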
    