How do I get indices of N maximum values in a NumPy array?

长情又很酷 2020-11-22 04:25

NumPy provides a way to get the index of the maximum value of an array via np.argmax.

I would like something similar, but returning the indices of the N maximum values.

18 answers
  •  你的背包
    2020-11-22 04:47

    For multidimensional arrays, you can use the axis keyword to apply the partitioning along the desired axis.

    import numpy as np

    # For a 2D array: column indices of the N largest values in each row (unordered)
    indices = np.argpartition(arr, -N, axis=1)[:, -N:]
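A self-contained sketch of the same argpartition call, wrapped in a hypothetical helper `top_n_indices` (not a NumPy function) that works along the last axis, so it covers both the 1-D case from the question and the row-wise 2-D case above. The arrays are made up for illustration, and the returned indices come in no particular order, which is why the output is sorted before printing.

```python
import numpy as np


def top_n_indices(arr, n):
    """Hypothetical helper: indices of the n largest values along the last axis.

    np.argpartition with kth=-n guarantees that the last n positions hold the
    indices of the n largest elements; their relative order is unspecified.
    """
    return np.argpartition(arr, -n, axis=-1)[..., -n:]


# 1-D case
a1 = np.array([1, 3, 2, 4, 5])
print(sorted(top_n_indices(a1, 3).tolist()))  # → [1, 3, 4]

# 2-D case: one set of column indices per row
a2 = np.array([[7, 11, 12, 0, 2],
               [16, 9, 4, 3, 18]])
print(np.sort(top_n_indices(a2, 2), axis=1).tolist())  # → [[1, 2], [0, 4]]
```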
    

    To grab the corresponding values:

    # Pair each selected column index with its row index, then reshape to (rows, N)
    x = arr.shape[0]
    arr[np.repeat(np.arange(x), N), indices.ravel()].reshape(x, N)
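Since NumPy 1.15, `np.take_along_axis` performs the same row-wise gather without building the row index by hand. A minimal sketch comparing it with the `np.repeat`/`ravel`/`reshape` expression above, using a small made-up array:

```python
import numpy as np

arr = np.array([[7, 11, 12, 0, 2, 3],
                [16, 9, 4, 3, 18, 5],
                [2, 9, 15, 12, 18, 3]])
N = 2

# Column indices of the N largest values in each row (unordered)
indices = np.argpartition(arr, -N, axis=1)[:, -N:]

# Manual gather: pair every selected column index with its row index
x = arr.shape[0]
manual = arr[np.repeat(np.arange(x), N), indices.ravel()].reshape(x, N)

# Same gather with take_along_axis; indices must have the same ndim as arr
gathered = np.take_along_axis(arr, indices, axis=1)

print(np.array_equal(manual, gathered))  # → True
```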
    

    Note that this won't return the items in sorted order. If you need them sorted, you can use np.argsort() along the intended axis instead:

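    # Column indices of the N largest values per row, smallest of the N first
    indices = np.argsort(arr, axis=1)[:, -N:]

    # Retrieve the corresponding values (ascending order)
    x = arr.shape[0]
    arr[np.repeat(np.arange(x), N), indices.ravel()].reshape(x, N)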
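A full `np.argsort` sorts every element of each row. A common alternative, sketched here assuming you want the top N in descending order, is to partition first and then sort only the N selected entries. The helper name `top_n_sorted` is hypothetical, and the array is made up.

```python
import numpy as np


def top_n_sorted(arr, n):
    """Hypothetical helper: (indices, values) of the n largest entries per row,
    ordered from largest to smallest."""
    # Step 1: select the n largest per row without fully sorting (unordered)
    part = np.argpartition(arr, -n, axis=1)[:, -n:]
    vals = np.take_along_axis(arr, part, axis=1)
    # Step 2: sort only those n values; reversing gives descending order
    order = np.argsort(vals, axis=1)[:, ::-1]
    idx = np.take_along_axis(part, order, axis=1)
    return idx, np.take_along_axis(arr, idx, axis=1)


arr = np.array([[7, 11, 12, 0, 2, 3],
                [16, 9, 4, 3, 18, 5]])
idx, vals = top_n_sorted(arr, 3)
print(idx.tolist())   # → [[2, 1, 0], [4, 0, 1]]
print(vals.tolist())  # → [[12, 11, 7], [18, 16, 9]]
```

Reversing the argsort output rather than negating the values keeps the sketch valid for unsigned integer dtypes.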
    

    Here is an example:

    In [42]: a = np.random.randint(0, 20, (10, 10))
    
    In [44]: a
    Out[44]:
    array([[ 7, 11, 12,  0,  2,  3,  4, 10,  6, 10],
           [16, 16,  4,  3, 18,  5, 10,  4, 14,  9],
           [ 2,  9, 15, 12, 18,  3, 13, 11,  5, 10],
           [14,  0,  9, 11,  1,  4,  9, 19, 18, 12],
           [ 0, 10,  5, 15,  9, 18,  5,  2, 16, 19],
           [14, 19,  3, 11, 13, 11, 13, 11,  1, 14],
           [ 7, 15, 18,  6,  5, 13,  1,  7,  9, 19],
           [11, 17, 11, 16, 14,  3, 16,  1, 12, 19],
           [ 2,  4, 14,  8,  6,  9, 14,  9,  1,  5],
           [ 1, 10, 15,  0,  1,  9, 18,  2,  2, 12]])
    
    In [45]: np.argpartition(a, np.argmin(a, axis=0))[:, 1:]  # [:, 1:] skips position 0, which holds each row's minimum
    Out[45]:
    array([[4, 5, 6, 8, 0, 7, 9, 1, 2],
           [2, 7, 5, 9, 6, 8, 1, 0, 4],
           [5, 8, 1, 9, 7, 3, 6, 2, 4],
           [4, 5, 2, 6, 3, 9, 0, 8, 7],
           [7, 2, 6, 4, 1, 3, 8, 5, 9],
           [2, 3, 5, 7, 6, 4, 0, 9, 1],
           [4, 3, 0, 7, 8, 5, 1, 2, 9],
           [5, 2, 0, 8, 4, 6, 3, 1, 9],
           [0, 1, 9, 4, 3, 7, 5, 2, 6],
           [0, 4, 7, 8, 5, 1, 9, 2, 6]])
    
    In [46]: np.argpartition(a, np.argmin(a, axis=0))[:, -3:]
    Out[46]:
    array([[9, 1, 2],
           [1, 0, 4],
           [6, 2, 4],
           [0, 8, 7],
           [8, 5, 9],
           [0, 9, 1],
           [1, 2, 9],
           [3, 1, 9],
           [5, 2, 6],
           [9, 2, 6]])
    
    In [89]: a[np.repeat(np.arange(x), 3), ind.ravel()].reshape(x, 3)  # here ind is Out[46] and x = a.shape[0]
    Out[89]:
    array([[10, 11, 12],
           [16, 16, 18],
           [13, 15, 18],
           [14, 18, 19],
           [16, 18, 19],
           [14, 14, 19],
           [15, 18, 19],
           [16, 17, 19],
           [ 9, 14, 14],
           [12, 15, 18]])
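For comparison, the same top-3 values can be obtained from the array printed above with the plain `kth=-N` call from the start of this answer, instead of passing `np.argmin(a, axis=0)` as `kth`. This sketch copies the array literally so it is reproducible, and sorts each row of the result so it can be compared with `Out[89]`; since `argpartition` leaves the selected indices unordered, that sort is needed for the comparison.

```python
import numpy as np

# The array shown in Out[44], copied so the example is reproducible
a = np.array([[ 7, 11, 12,  0,  2,  3,  4, 10,  6, 10],
              [16, 16,  4,  3, 18,  5, 10,  4, 14,  9],
              [ 2,  9, 15, 12, 18,  3, 13, 11,  5, 10],
              [14,  0,  9, 11,  1,  4,  9, 19, 18, 12],
              [ 0, 10,  5, 15,  9, 18,  5,  2, 16, 19],
              [14, 19,  3, 11, 13, 11, 13, 11,  1, 14],
              [ 7, 15, 18,  6,  5, 13,  1,  7,  9, 19],
              [11, 17, 11, 16, 14,  3, 16,  1, 12, 19],
              [ 2,  4, 14,  8,  6,  9, 14,  9,  1,  5],
              [ 1, 10, 15,  0,  1,  9, 18,  2,  2, 12]])
N = 3

# Column indices of the 3 largest values per row (unordered)
ind = np.argpartition(a, -N, axis=1)[:, -N:]

# Gather the values and sort each row ascending, matching the layout of Out[89]
top = np.sort(np.take_along_axis(a, ind, axis=1), axis=1)
print(top.tolist())
```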
    
