Get indices of numpy.argmax elements over an axis

泄露秘密 submitted on 2019-12-10 03:48:58

Question


I have an N-dimensional matrix which contains the values of a function with N parameters. Each parameter has a discrete number of values. I need to maximize the function over all parameters but one, resulting in a one-dimensional vector whose size equals the number of values of the non-maximized parameter. I also need to save which values are taken by the other parameters.

To do so I wanted to iteratively apply numpy.max over different axes to reduce the dimensionality of the matrix to find what I need. The final vector will then depend on just the parameter I left out.

However, I'm having trouble finding the original indices of the final elements (which contain the information about the values taken by the other parameters). I thought about using numpy.argmax in the same way as numpy.max, but I can't get back the original indices.

An example of what I'm trying is:

import numpy as np

x = [[[1,2],[0,1]],[[3,4],[6,7]]]
args = np.argmax(x, 0)

This returns

[[1 1]
 [1 1]]

This means that argmax is selecting the elements (3,4,6,7) within the original matrix. But how do I get their indices? I tried unravel_index, using args directly as an index into x, and a number of other NumPy indexing functions, with no success.

Using numpy.where is not a solution, since the input matrix may contain equal values, so I would not be able to tell which original element a given value came from.


Answer 1:


x.argmax(0) gives the indices along the first axis of the maximum values. Use np.indices to generate the indices for the other axes.

x = np.array([[[1,2],[0,1]],[[3,4],[6,7]]])
x.argmax(0)
    array([[1, 1],
           [1, 1]])
a1, a2 = np.indices((2,2))
(x.argmax(0),a1,a2)
    (array([[1, 1],
            [1, 1]]),
     array([[0, 0],
            [1, 1]]),
     array([[0, 1],
            [0, 1]]))


x[x.argmax(0),a1,a2]
    array([[3, 4],
           [6, 7]])

x[a1,x.argmax(1),a2] 
    array([[1, 2],
           [6, 7]])

x[a1,a2,x.argmax(2)] 
    array([[2, 1],
           [4, 7]])

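If x has a different shape or more dimensions, generate the index arrays (a1, a2, and so on) to match the shape of the argmax result, which is the shape of x with the reduced axis removed.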
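The same idea can be wrapped up for an arbitrary number of dimensions. The following is only a sketch: the helper name argmax_full_indices is made up for illustration and is not a NumPy function, and it assumes axis is a non-negative integer.

```python
import numpy as np

def argmax_full_indices(x, axis):
    """Hypothetical helper: index arrays locating the maxima of x along `axis`.

    Assumes `axis` is a non-negative integer.
    """
    x = np.asarray(x)
    arg = x.argmax(axis)                  # position of each maximum along `axis`
    grids = list(np.indices(arg.shape))   # one index grid per remaining axis
    grids.insert(axis, arg)               # put the argmax result back in its axis slot
    return tuple(grids)

x = np.array([[[1, 2], [0, 1]], [[3, 4], [6, 7]]])
idx = argmax_full_indices(x, 2)
print(x[idx])   # the same values as x.max(2)
```

Indexing x with the returned tuple gives the maxima, and the tuple itself holds the full coordinates of each maximum, one array per axis of x.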

The official documentation does not say much about how to use argmax, but earlier SO threads have discussed it. I got the general idea from the thread "Using numpy.argmax() on multidimensional arrays".
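For the original goal in the question (maximizing over every axis except one and remembering where each maximum came from), a possible alternative to repeated argmax calls is to flatten the axes being maximized and call argmax once, then convert the flat positions back with np.unravel_index. This is a sketch under assumptions: the function name argmax_except is hypothetical, and ties are resolved the way argmax does it, by taking the first occurrence.

```python
import numpy as np

def argmax_except(x, keep_axis):
    """Hypothetical helper: maximize x over every axis except `keep_axis`.

    Returns (maxima, coords), where coords holds one index array per
    maximized axis, listed in their original order.
    """
    x = np.asarray(x)
    moved = np.moveaxis(x, keep_axis, 0)           # bring the kept axis to the front
    rest_shape = moved.shape[1:]                   # shape of the axes being maximized
    flat = moved.reshape(moved.shape[0], -1)       # one row per value of the kept axis
    flat_arg = flat.argmax(axis=1)                 # flat position of each row's maximum
    maxima = flat[np.arange(flat.shape[0]), flat_arg]
    coords = np.unravel_index(flat_arg, rest_shape)
    return maxima, coords

x = np.array([[[1, 2], [0, 1]], [[3, 4], [6, 7]]])
maxima, coords = argmax_except(x, 0)
print(maxima)   # [2 7]
print(coords)   # (array([0, 1]), array([1, 1]))
```

Here maxima[k] is the largest value in x[k], and (coords[0][k], coords[1][k]) is its position along the other two axes. Because each maximum gets exactly one position, duplicate values elsewhere in the array do not cause the ambiguity described for numpy.where.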



Source: https://stackoverflow.com/questions/20128837/get-indices-of-numpy-argmax-elements-over-an-axis
