N largest values in each row of ndarray

独厮守ぢ 2021-01-12 20:31

I have an ndarray where each row is a separate histogram. For each row, I wish to find the top N values.

I am aware of a solution for the global top N values (A fas

2 Answers
  •  北荒 (original poster)
    2021-01-12 21:15

    You can use np.argsort along each row (axis=1), like so -

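    import numpy as np
    
    # Example input: each row is a histogram; N is how many values to keep per row
    A = np.array([[0, 3, 3, 2, 5],
                  [4, 2, 6, 3, 1],
                  [2, 1, 1, 8, 8],
                  [6, 6, 3, 2, 6]])
    N = 3
    
    # Sort each row's column indices by value and keep the last N (the N largest, ascending)
    sorted_row_idx = np.argsort(A, axis=1)[:, A.shape[1]-N:]
    
    # Set up the row-indexing array, shape (rows, 1)
    col_idx = np.arange(A.shape[0])[:, None]
    
    # Use the row and column indices to pick the elements from the input array.
    # The row-indexing array has shape (rows, 1) while the sorted column indices
    # have shape (rows, N), so the row indices are broadcast across each row.
    out = A[col_idx, sorted_row_idx]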
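A full argsort of every row costs O(n log n) per row even though only N values are wanted. As a swapped-in alternative (not the approach above), here is a minimal sketch assuming NumPy 1.15 or newer for np.take_along_axis: np.argpartition moves the N largest entries of each row into the last N slots in linear time (in no particular order), and only those N candidates are then sorted. The helper name top_n_per_row is hypothetical.

```python
import numpy as np


def top_n_per_row(A, N):
    """Hypothetical helper: the N largest values of each row of a 2-D array,
    in ascending order, together with their column indices."""
    A = np.asarray(A)
    n_cols = A.shape[1]
    if not 1 <= N <= n_cols:
        raise ValueError("N must be between 1 and the number of columns")
    # After partitioning at kth = n_cols - N, the last N slots of each row
    # hold the N largest entries, in no particular order
    part_idx = np.argpartition(A, n_cols - N, axis=1)[:, n_cols - N:]
    part_vals = np.take_along_axis(A, part_idx, axis=1)
    # Sort only those N candidates in each row
    order = np.argsort(part_vals, axis=1)
    idx = np.take_along_axis(part_idx, order, axis=1)
    vals = np.take_along_axis(part_vals, order, axis=1)
    return vals, idx


A = np.array([[0, 3, 3, 2, 5],
              [4, 2, 6, 3, 1],
              [2, 1, 1, 8, 8],
              [6, 6, 3, 2, 6]])
vals, idx = top_n_per_row(A, 3)
print(vals)
print(idx)
```

On the sample array, vals has the same values as out in the run below. When a row contains ties, idx may pick different (equally valid) column positions than the argsort version.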
    

    Sample run -

    In [417]: A
    Out[417]: 
    array([[0, 3, 3, 2, 5],
           [4, 2, 6, 3, 1],
           [2, 1, 1, 8, 8],
           [6, 6, 3, 2, 6]])
    
    In [418]: N
    Out[418]: 3
    
    In [419]: sorted_row_idx = np.argsort(A, axis=1)[:,A.shape[1]-N::]
    
    In [420]: sorted_row_idx
    Out[420]: 
    array([[1, 2, 4],
           [3, 0, 2],
           [0, 3, 4],
           [0, 1, 4]], dtype=int64)
    
    In [421]: col_idx = np.arange(A.shape[0])[:,None]
    
    In [422]: col_idx
    Out[422]: 
    array([[0],
           [1],
           [2],
           [3]])
    
    In [423]: out = A[col_idx,sorted_row_idx]
    
    In [424]: out
    Out[424]: 
    array([[3, 3, 5],
           [3, 4, 6],
           [2, 8, 8],
           [6, 6, 6]])
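The indexing step A[col_idx, sorted_row_idx] is a per-row gather: a (4, 1) array of row numbers is broadcast against a (4, 3) array of column numbers. As a sketch, assuming NumPy 1.15 or newer, np.take_along_axis performs the same gather along axis 1 without building the row-index array by hand:

```python
import numpy as np

A = np.array([[0, 3, 3, 2, 5],
              [4, 2, 6, 3, 1],
              [2, 1, 1, 8, 8],
              [6, 6, 3, 2, 6]])
N = 3

sorted_row_idx = np.argsort(A, axis=1)[:, A.shape[1] - N:]

# Manual broadcasting: (4, 1) row indices against (4, 3) column indices
col_idx = np.arange(A.shape[0])[:, None]
out_fancy = A[col_idx, sorted_row_idx]

# The same gather, expressed along axis 1
out_take = np.take_along_axis(A, sorted_row_idx, axis=1)

assert np.array_equal(out_fancy, out_take)
print(out_take)
```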
    

    If you would like to have the elements in descending order, you can use this additional step -

    In [425]: out[:,::-1]
    Out[425]: 
    array([[5, 3, 3],
           [6, 4, 3],
           [8, 8, 2],
           [6, 6, 6]])
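If only the values are needed and not their column positions, the argsort-then-index round trip can be skipped, because gathering A at its row-wise argsort order simply yields each row sorted. A minimal sketch, returning the top N of each row in descending order:

```python
import numpy as np

A = np.array([[0, 3, 3, 2, 5],
              [4, 2, 6, 3, 1],
              [2, 1, 1, 8, 8],
              [6, 6, 3, 2, 6]])
N = 3

# Sort each row ascending (returns a copy), reverse the columns, keep the first N
top_desc = np.sort(A, axis=1)[:, ::-1][:, :N]
print(top_desc)
```

np.sort returns a new array and leaves A untouched; calling A.sort(axis=1) instead would sort A in place.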
    
