From the answer to this question (Sort a numpy array by another array, along a particular axis, using less memory), I learned how to sort a multidimensional numpy array along a particular axis:
Use `np.take` with the `axis` keyword argument:
```python
>>> import numpy as np
>>> a = np.arange(2*3*4).reshape(2, 3, 4)
>>> a
array([[[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]],

       [[12, 13, 14, 15],
        [16, 17, 18, 19],
        [20, 21, 22, 23]]])
>>> b = np.arange(3)
>>> np.random.shuffle(b)  # random permutation; this run gave [1, 0, 2]
>>> b
array([1, 0, 2])
>>> np.take(a, b, axis=1)
array([[[ 4,  5,  6,  7],
        [ 0,  1,  2,  3],
        [ 8,  9, 10, 11]],

       [[16, 17, 18, 19],
        [12, 13, 14, 15],
        [20, 21, 22, 23]]])
```
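To turn this into an actual sort "by another array", the permutation can come from `np.argsort` of a 1-D key array instead of a shuffle. The sketch below assumes a hypothetical `keys` array holding one sort key per position along axis 1; it is an illustration, not part of the quoted answer.

```python
import numpy as np

a = np.arange(2 * 3 * 4).reshape(2, 3, 4)

# Hypothetical 1-D key array: one sort key per position along axis 1.
keys = np.array([30, 10, 20])

# argsort returns the permutation that sorts `keys` in ascending order.
order = np.argsort(keys)

# Apply that same permutation to `a` along axis 1.
sorted_a = np.take(a, order, axis=1)
print(order)
print(sorted_a[0])
```

Because `order` is a single 1-D permutation, every slice `a[i]` is reordered identically, which is exactly the case `np.take(..., axis=...)` covers.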
If you want to use fancy indexing, you just need to pad the indexing tuple with enough full slices (`:`) in front of the index array to reach the target axis:
```python
>>> a[:, b]
array([[[ 4,  5,  6,  7],
        [ 0,  1,  2,  3],
        [ 8,  9, 10, 11]],

       [[16, 17, 18, 19],
        [12, 13, 14, 15],
        [20, 21, 22, 23]]])
```
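A quick check of the padding rule, using a fixed permutation in place of the random shuffle: one leading `:` targets axis 1, two target axis 2. The last line uses an Ellipsis (`...`), a standard NumPy indexing alternative not mentioned in the answer, which stands in for all leading axes and so targets the last one.

```python
import numpy as np

a = np.arange(2 * 3 * 4).reshape(2, 3, 4)
b = np.array([1, 0, 2])     # fixed permutation along axis 1
c = np.array([3, 2, 1, 0])  # fixed permutation along axis 2

# One leading full slice: the index array applies to axis 1.
axis1_ok = np.array_equal(a[:, b], np.take(a, b, axis=1))

# Two leading full slices: the index array applies to axis 2.
axis2_ok = np.array_equal(a[:, :, c], np.take(a, c, axis=2))

# Ellipsis fills in every leading axis, so `c` applies to the last axis.
last_ok = np.array_equal(a[..., c], np.take(a, c, axis=-1))

print(axis1_ok, axis2_ok, last_ok)
```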
Or, more generally, for an arbitrary `axis`:
```python
>>> axis = 1
>>> idx = (slice(None),) * axis + (b,)
>>> a[idx]
array([[[ 4,  5,  6,  7],
        [ 0,  1,  2,  3],
        [ 8,  9, 10, 11]],

       [[16, 17, 18, 19],
        [12, 13, 14, 15],
        [20, 21, 22, 23]]])
```
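The general tuple can be wrapped in a small function. The name `reorder_along_axis` is hypothetical, and normalizing negative axes with `%` is an added assumption so that `axis=-1` works like it does for `np.take`. The loop checks the helper against `np.take` on every axis.

```python
import numpy as np


def reorder_along_axis(arr, order, axis):
    """Hypothetical helper: reorder `arr` along `axis` by the index array `order`.

    Builds the same (slice(None),) * axis + (order,) tuple as the answer.
    Negative axes are normalized with %, but out-of-range axes are NOT
    validated (e.g. axis=5 on a 3-D array would silently map to axis 2).
    """
    axis = axis % arr.ndim  # map -1 -> arr.ndim - 1, etc.
    idx = (slice(None),) * axis + (order,)
    return arr[idx]  # index with a tuple, not a list


a = np.arange(2 * 3 * 4).reshape(2, 3, 4)

# Compare against np.take on every axis, positive and negative,
# using a reversed order along that axis as the permutation.
for ax in range(-a.ndim, a.ndim):
    order = np.arange(a.shape[ax])[::-1]
    assert np.array_equal(reorder_along_axis(a, order, ax),
                          np.take(a, order, axis=ax))
```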
But `np.take` should really be your first option.
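One `np.take` feature that may matter given the "less memory" goal in the question title: unlike `a[idx]`, it accepts an `out` array, so a preallocated buffer can be reused across calls. This is a sketch under the assumption that `b` contains only valid indices. `mode='clip'` is used because NumPy's documentation notes that `out` is always buffered when `mode='raise'`; the trade-off is that out-of-range indices are clipped silently instead of raising. `out` is a separate array here; writing into `a` itself is not attempted.

```python
import numpy as np

a = np.arange(2 * 3 * 4).reshape(2, 3, 4)
b = np.array([1, 0, 2])

# Preallocate the result once; it must match the output shape and dtype.
out = np.empty_like(a)

# Write the reordered array straight into `out`. mode='clip' skips the
# internal buffering used with mode='raise', and is safe here only
# because every entry of b is a valid index along axis 1.
np.take(a, b, axis=1, out=out, mode='clip')

print(np.array_equal(out, a[:, b]))
```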