Question
I have an N-dimensional matrix that contains the values of a function of N parameters. Each parameter takes a discrete number of values. I need to maximize the function over all parameters but one, resulting in a one-dimensional vector whose size equals the number of values of the non-maximized parameter. I also need to save which values the other parameters take at each maximum.
To do so I wanted to iteratively apply numpy.max over different axes to reduce the dimensionality of the matrix to find what I need. The final vector will then depend on just the parameter I left out.
However, I'm having trouble finding the original indices of the final elements (which contain the information about the values taken by the other parameters). I thought about using numpy.argmax in the same way as numpy.max, but I can't get the original indices back.
An example of what I'm trying is:
import numpy as np
x = [[[1,2],[0,1]],[[3,4],[6,7]]]
args = np.argmax(x, 0)
This returns
[[1 1]
 [1 1]]
This means that argmax is selecting the elements (3,4,6,7) within the original matrix. But how do I get their indices? I tried unravel_index, using args directly as an index into the matrix x, and a number of other NumPy indexing functions, with no success.
Using numpy.where is not a solution, since the input matrix may contain equal values, so I would not be able to tell which of the matching positions a given maximum came from.
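One way to read this concern is sketched below. The array y is a hypothetical example (it is not from the question) chosen so that the maximum value is repeated; it shows that looking maxima up by value can return more positions than there are maxima, while argmax returns exactly one position per reduced slice.

```python
import numpy as np

# Hypothetical 2-D array with a repeated maximum value.
y = np.array([[5, 5],
              [5, 1]])

col_max = y.max(axis=0)              # maximum of each column
rows, cols = np.where(y == col_max)  # every position equal to its column maximum

# More positions match than there are columns, so the match list alone
# does not identify which element the reduction picked.
matches = list(zip(rows.tolist(), cols.tolist()))
print(matches)

# argmax reports exactly one row per column (the first one on ties).
first_rows = y.argmax(axis=0)
print(first_rows)
```

Here np.where reports three matching positions for two columns, whereas argmax gives a single row index for each column.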
Answer 1:
x.argmax(0) gives the indices along the first axis of the maximum values. Use np.indices to generate the indices for the other axes.
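>>> import numpy as np
>>> x = np.array([[[1,2],[0,1]],[[3,4],[6,7]]])
>>> x.argmax(0)  # position of each maximum along axis 0
array([[1, 1],
       [1, 1]])
>>> a1, a2 = np.indices((2,2))  # index grids for the two remaining axes
>>> (x.argmax(0),a1,a2)
(array([[1, 1],
        [1, 1]]),
 array([[0, 0],
        [1, 1]]),
 array([[0, 1],
        [0, 1]]))
>>> x[x.argmax(0),a1,a2]  # maxima along axis 0
array([[3, 4],
       [6, 7]])
>>> x[a1,x.argmax(1),a2]  # maxima along axis 1
array([[1, 2],
       [6, 7]])
>>> x[a1,a2,x.argmax(2)]  # maxima along axis 2
array([[2, 1],
       [4, 7]])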
If x has a different shape or more dimensions, generate the index arrays (a1, a2, and so on) accordingly.
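The same construction can be written for an arbitrary number of dimensions. The following is a minimal sketch, assuming the goal is a tuple of index arrays that can be passed straight to x[...]; the helper name argmax_full_indices is hypothetical and not part of NumPy.

```python
import numpy as np

def argmax_full_indices(x, axis):
    """Return a tuple of index arrays selecting the maxima of x along `axis`.

    Hypothetical helper for illustration; not a NumPy function.
    """
    x = np.asarray(x)
    axis = axis % x.ndim            # normalize negative axes
    args = x.argmax(axis)           # shape is x.shape with `axis` removed
    # One index grid per remaining axis, as with a1, a2 in the answer.
    grids = list(np.indices(args.shape))
    # Put the argmax result at the position of the reduced axis.
    grids.insert(axis, args)
    return tuple(grids)

x = np.array([[[1, 2], [0, 1]], [[3, 4], [6, 7]]])
for axis in range(x.ndim):
    idx = argmax_full_indices(x, axis)
    # Indexing with the full tuple reproduces the maxima along that axis.
    assert np.array_equal(x[idx], x.max(axis))
```

For axis 0, 1 and 2 this yields the tuples (x.argmax(0), a1, a2), (a1, x.argmax(1), a2) and (a1, a2, x.argmax(2)) used above. If only the maximum values are needed rather than the indices, np.take_along_axis (NumPy 1.15 and later) combined with np.expand_dims on the argmax result is an alternative.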
The official documentation does not say much about how to use argmax, but earlier SO threads have discussed it. I got this general idea from "Using numpy.argmax() on multidimensional arrays".
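For the original goal (maximizing over every parameter except one and keeping the positions of the maxima), one possible alternative to repeated reductions is to flatten all the other axes and call argmax once. This is a sketch that swaps in np.moveaxis, reshape and np.unravel_index rather than the np.indices approach above; the helper name max_over_all_but is hypothetical, and the array is assumed to have at least two dimensions.

```python
import numpy as np

def max_over_all_but(x, keep_axis):
    """Maximize x over every axis except `keep_axis`.

    Returns (values, indices): values[k] is the maximum with the kept
    parameter fixed at k, and indices[k] is the full index of that
    maximum in x. Hypothetical helper; assumes x.ndim >= 2.
    """
    x = np.asarray(x)
    keep_axis = keep_axis % x.ndim
    # Bring the kept axis to the front and flatten all the other axes.
    moved = np.moveaxis(x, keep_axis, 0)
    flat = moved.reshape(moved.shape[0], -1)
    flat_args = flat.argmax(axis=1)          # first maximum wins on ties
    values = flat[np.arange(flat.shape[0]), flat_args]
    # Convert flat positions back to indices over the non-kept axes.
    rest = list(np.unravel_index(flat_args, moved.shape[1:]))
    # Re-insert the kept axis index at its original position.
    rest.insert(keep_axis, np.arange(x.shape[keep_axis]))
    return values, np.stack(rest, axis=1)

x = np.array([[[1, 2], [0, 1]], [[3, 4], [6, 7]]])
values, indices = max_over_all_but(x, keep_axis=0)
print(values)    # [2 7]
print(indices)   # rows [0 0 1] and [1 1 1]

# Each row of `indices` addresses the corresponding maximum in x.
assert np.array_equal(x[tuple(indices.T)], values)
```

With the first parameter kept, the maxima are 2 at x[0, 0, 1] and 7 at x[1, 1, 1]. On ties, argmax returns the first maximum in the flattened order, so only one position is reported for each value of the kept parameter.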
Source: https://stackoverflow.com/questions/20128837/get-indices-of-numpy-argmax-elements-over-an-axis