If I have an ndarray like this:
>>> a = np.arange(27).reshape(3,3,3)
>>> a
array([[[ 0,  1,  2],
        [ 3,  4,  5],
        [ 6,  7,  8]],

       [[ 9, 10, 11],
        [12, 13, 14],
        [15, 16, 17]],

       [[18, 19, 20],
        [21, 22, 23],
        [24, 25, 26]]])
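Here is some NumPy indexing magic that does what you want, though it is admittedly pretty unreadable. `indices` must have the shape of `a` with `axis` removed (for example, the result of `a.argmax(axis=axis)`), and the result has that same shape.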
def apply_mask(a, indices, axis):
    """Pick one element of ``a`` along ``axis`` for every position of ``indices``.

    ``indices`` must have the shape of ``a`` with ``axis`` removed (for
    example the output of ``a.argmax(axis=axis)``). For a 3-D ``a`` and
    ``axis=1`` the result is ``result[i, k] == a[i, indices[i, k], k]``.

    Parameters
    ----------
    a : ndarray
        Array to select from.
    indices : array_like of int
        Positions along ``axis``; its shape is ``a.shape`` without ``axis``.
    axis : int
        Axis of ``a`` that ``indices`` index into. Negative values count
        from the end, as elsewhere in NumPy.

    Returns
    -------
    ndarray
        The selected values, shaped like ``indices``.

    Raises
    ------
    ValueError
        If ``axis`` is out of range for an array with ``indices.ndim + 1``
        dimensions.
    """
    indices = np.asarray(indices)
    # The final index tuple has one entry per axis of `a`: one open-mesh
    # grid per axis of `indices`, plus `indices` itself.
    ndim = indices.ndim + 1
    if not -ndim <= axis < ndim:
        raise ValueError(
            f"axis {axis} is out of bounds for an array of dimension {ndim}"
        )
    # Normalize first: slicing the tuple with a negative bound would splice
    # `indices` in one slot too early.
    axis %= ndim
    # np.ix_ turns the aranges into an open mesh that broadcasts against
    # `indices`, so each output position selects exactly one element.
    grids = np.ix_(*(np.arange(n) for n in indices.shape))
    return a[grids[:axis] + (indices,) + grids[axis:]]
Or, equally unreadable:
def apply_mask(a, indices, axis):
    """Pick one element of ``a`` along ``axis`` for every position of ``indices``.

    ``indices`` must have the shape of ``a`` with ``axis`` removed (for
    example the output of ``a.argmax(axis=axis)``). For a 3-D ``a`` and
    ``axis=1`` the result is ``result[i, k] == a[i, indices[i, k], k]``.

    Parameters
    ----------
    a : ndarray
        Array to select from.
    indices : array_like of int
        Positions along ``axis``; its shape is ``a.shape`` without ``axis``.
    axis : int
        Axis of ``a`` that ``indices`` index into. Negative values count
        from the end, as elsewhere in NumPy.

    Returns
    -------
    ndarray
        The selected values, shaped like ``indices``.

    Raises
    ------
    ValueError
        If ``axis`` is out of range for an array with ``indices.ndim + 1``
        dimensions.
    """
    indices = np.asarray(indices)
    # The final index has one entry per axis of `a`: one open-mesh grid per
    # axis of `indices`, plus `indices` itself.
    ndim = indices.ndim + 1
    if not -ndim <= axis < ndim:
        raise ValueError(
            f"axis {axis} is out of bounds for an array of dimension {ndim}"
        )
    # Depending on the NumPy version, np.ogrid returns a list or a tuple;
    # copy into a list so that insert() works either way.
    index = list(np.ogrid[tuple(slice(n) for n in indices.shape)])
    # Normalize first: list.insert with a negative position would place
    # `indices` one slot too early.
    index.insert(axis % ndim, indices)
    # Must index with a tuple: NumPy >= 1.23 treats a plain list of arrays
    # as a single fancy-index array rather than one index per axis.
    return a[tuple(index)]