I'm trying to implement a numpy function that replaces the max in each row of a 2D array with 1, and all other numbers with zero:
>>> a = np.array(
a==np.max(a) will raise an error in the future, so here's a tweaked version that will continue to broadcast correctly.
I know this question is pretty ancient, but I think I have a decent solution that's a bit different from the other solutions.
import numpy as np

# get max by row and convert from (n,) -> (n, 1), which will broadcast
row_maxes = a.max(axis=1).reshape(-1, 1)
np.where(a == row_maxes, 1, 0)
# or, equivalently, cast the boolean mask to int
(a == row_maxes).astype(int)
If the update needs to be in place, you can do:
a[:] = np.where(a == row_maxes, 1, 0)
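A minimal sketch putting these pieces together, assuming a small integer array with made-up values. The helper name mark_row_max is hypothetical, and keepdims=True is used here as an equivalent to .reshape(-1, 1). It also shows what "in place" buys you: a view taken before the update sees the new values.

```python
import numpy as np

def mark_row_max(a: np.ndarray) -> np.ndarray:
    """Hypothetical helper: 1 wherever an element equals its row's maximum, else 0."""
    # keepdims=True keeps the reduced axis, giving shape (n, 1) like .reshape(-1, 1)
    row_maxes = a.max(axis=1, keepdims=True)
    return np.where(a == row_maxes, 1, 0)

a = np.array([[0, 1], [2, 3], [9, 8]])
print(a.max(axis=1).shape)                 # (3,)
print(a.max(axis=1, keepdims=True).shape)  # (3, 1)

last_row = a[2]  # a view that shares memory with a

# Slice assignment writes into a's existing buffer instead of rebinding the name
a[:] = mark_row_max(a)
print(a)
print(last_row)  # [1 0]: the view sees the in-place update
```

Rebinding with a = mark_row_max(a) would instead leave last_row pointing at the old data ([9 8]).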
Method #1, tweaking yours:
>>> import numpy as np
>>> a = np.array([[0, 1], [2, 3], [4, 5], [6, 7], [9, 8]])
>>> b = np.zeros_like(a)
>>> b[np.arange(len(a)), a.argmax(1)] = 1
>>> b
array([[0, 1],
       [0, 1],
       [0, 1],
       [0, 1],
       [1, 0]])
[Actually, range will work just fine; I wrote arange out of habit.]
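A minimal sketch wrapping Method #1 in a function (the name one_hot_argmax is hypothetical), with a check that a plain range and np.arange give the same result. Since argmax returns the index of the first maximum only, a row with a tie gets a single 1, which is the case Method #2 addresses.

```python
import numpy as np

def one_hot_argmax(a: np.ndarray) -> np.ndarray:
    """Hypothetical helper: 1 at the first maximum of each row, 0 elsewhere."""
    b = np.zeros_like(a)
    # Pair row index i with column index a[i].argmax(); exactly one element set per row
    b[np.arange(len(a)), a.argmax(axis=1)] = 1
    return b

a = np.array([[0, 1], [2, 3], [4, 5], [6, 7], [9, 8]])
b = one_hot_argmax(a)

# A plain range also works as the row index
c = np.zeros_like(a)
c[range(len(a)), a.argmax(axis=1)] = 1
print(np.array_equal(b, c))  # True

# argmax picks only the first maximum, so a tied row gets a single 1
print(one_hot_argmax(np.array([[2, 2]])))  # [[1 0]]
```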
Method #2, using max instead of argmax to handle the case where multiple elements reach the maximum value:
>>> a = np.array([[0, 1], [2, 2], [4, 3]])
>>> (a == a.max(axis=1)[:,None]).astype(int)
array([[0, 1],
       [1, 1],
       [1, 0]])
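The [:, None] indexing adds a trailing axis, giving the same (n, 1) shape as keepdims=True or .reshape(-1, 1). A small sketch, reusing the example array above, that checks the three spellings agree and counts the 1s per row to expose ties:

```python
import numpy as np

a = np.array([[0, 1], [2, 2], [4, 3]])

# Three ways to get the row maxima as a column of shape (n, 1)
m1 = a.max(axis=1)[:, None]
m2 = a.max(axis=1, keepdims=True)
m3 = a.max(axis=1).reshape(-1, 1)
assert m1.shape == m2.shape == m3.shape == (3, 1)

# Every element equal to its row's maximum becomes 1, so ties keep all of them
result = (a == m2).astype(int)
print(result)
# Number of 1s per row; more than one means the row had a tie
print(result.sum(axis=1))  # [1 2 1]
```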
I prefer using numpy.where like so, comparing each element with its own row's maximum:
a = np.where(a == np.max(a, axis=1, keepdims=True), 1, 0)
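Without an axis argument, np.max reduces over the whole array, so a == np.max(a) marks only positions equal to the single global maximum. A small sketch, with made-up values, contrasting that with the per-row maximum the question asks for (np.argwhere is used here just to list the matching positions):

```python
import numpy as np

a = np.array([[0, 1], [2, 3], [9, 8]])

print(np.max(a))          # 9: one scalar for the whole array
print(np.max(a, axis=1))  # [1 3 9]: one maximum per row

# Positions equal to the global maximum vs. positions equal to each row's maximum
global_hits = np.argwhere(a == np.max(a))
row_hits = np.argwhere(a == np.max(a, axis=1, keepdims=True))
print(global_hits)  # only position (2, 0)
print(row_hits)     # positions (0, 1), (1, 1), (2, 0)
```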