I have an ndarray where each row is a separate histogram. For each row, I wish to find the top N values.
I am aware of a solution for the global top N values (A fas
You can use np.argsort along the rows (axis=1), like so:
import numpy as np
# A: 2D array, one histogram per row; N: number of top values to keep per row
# Column indices that sort each row ascending; keep the last N (the N largest)
sorted_row_idx = np.argsort(A, axis=1)[:,A.shape[1]-N::]
# Row indices as a column vector of shape (n_rows, 1)
col_idx = np.arange(A.shape[0])[:,None]
# Pair each row index with that row's sorted column indices to pick the elements.
# col_idx has shape (n_rows, 1) and sorted_row_idx has shape (n_rows, N),
# so col_idx is broadcast across the N columns
out = A[col_idx,sorted_row_idx]
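The same idea can be wrapped in a small reusable function. This is a sketch: the name top_n_per_row is hypothetical, and it swaps the explicit broadcast row-index array for np.take_along_axis (available in NumPy 1.15 and later), which does the same row-wise pairing for you.

```python
import numpy as np

def top_n_per_row(A, N):
    """Return the N largest values of each row of A, in ascending order.

    Hypothetical helper: same argsort approach as above, but
    np.take_along_axis replaces the broadcast row-index array.
    """
    A = np.asarray(A)
    # Column indices of the N largest entries in each row (ascending by value)
    idx = np.argsort(A, axis=1)[:, A.shape[1] - N:]
    # Select A[i, idx[i, j]] for every row i and position j
    return np.take_along_axis(A, idx, axis=1)

A = np.array([[0, 3, 3, 2, 5],
              [4, 2, 6, 3, 1],
              [2, 1, 1, 8, 8],
              [6, 6, 3, 2, 6]])
result = top_n_per_row(A, 3)
print(result)
```

For the sample A below, this gives the same values as the out array in the sample run.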
Sample run:
In [417]: A
Out[417]:
array([[0, 3, 3, 2, 5],
       [4, 2, 6, 3, 1],
       [2, 1, 1, 8, 8],
       [6, 6, 3, 2, 6]])
In [418]: N
Out[418]: 3
In [419]: sorted_row_idx = np.argsort(A, axis=1)[:,A.shape[1]-N::]
In [420]: sorted_row_idx
Out[420]:
array([[1, 2, 4],
       [3, 0, 2],
       [0, 3, 4],
       [0, 1, 4]], dtype=int64)
In [421]: col_idx = np.arange(A.shape[0])[:,None]
In [422]: col_idx
Out[422]:
array([[0],
       [1],
       [2],
       [3]])
In [423]: out = A[col_idx,sorted_row_idx]
In [424]: out
Out[424]:
array([[3, 3, 5],
       [3, 4, 6],
       [2, 8, 8],
       [6, 6, 6]])
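To see what the broadcasting in A[col_idx, sorted_row_idx] actually does, np.broadcast_arrays can expand both index arrays to their common shape. This sketch reuses the sample A and N = 3 from the session above and checks the fancy-indexing result against an explicit element-by-element lookup.

```python
import numpy as np

A = np.array([[0, 3, 3, 2, 5],
              [4, 2, 6, 3, 1],
              [2, 1, 1, 8, 8],
              [6, 6, 3, 2, 6]])
N = 3
sorted_row_idx = np.argsort(A, axis=1)[:, A.shape[1] - N:]  # shape (4, 3)
col_idx = np.arange(A.shape[0])[:, None]                    # shape (4, 1)

# Expand both index arrays to the common broadcast shape (4, 3)
rows, cols = np.broadcast_arrays(col_idx, sorted_row_idx)
print(rows.shape)  # (4, 3)
print(rows)        # each row index repeated N times across its row

# Element (i, j) of the result is A[rows[i, j], cols[i, j]]
manual = np.array([[A[rows[i, j], cols[i, j]] for j in range(N)]
                   for i in range(A.shape[0])])
print(np.array_equal(manual, A[col_idx, sorted_row_idx]))  # True
```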
If you want the elements in descending order, reverse each row with one additional step:
In [425]: out[:,::-1]
Out[425]:
array([[5, 3, 3],
       [6, 4, 3],
       [8, 8, 2],
       [6, 6, 6]])
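If you only need the values (not their column positions), np.sort along axis=1 gives the descending top N directly. For long rows, np.argpartition is a swapped-in alternative to the full argsort: it only guarantees that the last N positions hold the N largest entries, unordered, so those N values are sorted afterwards. This is a sketch using the sample A above, not a benchmark-backed recommendation.

```python
import numpy as np

A = np.array([[0, 3, 3, 2, 5],
              [4, 2, 6, 3, 1],
              [2, 1, 1, 8, 8],
              [6, 6, 3, 2, 6]])
N = 3

# Values only: sort each row ascending, reverse, keep the first N
top_desc = np.sort(A, axis=1)[:, ::-1][:, :N]

# argpartition: indices of the N largest entries per row, in no particular order
part_idx = np.argpartition(A, -N, axis=1)[:, -N:]
part_vals = np.take_along_axis(A, part_idx, axis=1)
# Sort only those N values per row, descending
order = np.argsort(part_vals, axis=1)[:, ::-1]
top_desc_part = np.take_along_axis(part_vals, order, axis=1)

print(top_desc)
print(top_desc_part)
```

Both arrays match out[:,::-1] from the session above. The argpartition route avoids sorting every element of each row, which may help when rows are much longer than N.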