Index an n-dimensional array with an (n-1)-dimensional array

醉梦人生 2020-11-22 14:17

What is the most elegant way to index an n-dimensional array with an (n-1)-dimensional array along a given dimension, as in this dummy example?

a = np.random.r         


        
1 answer
  • 2020-11-22 15:14

    Make use of advanced indexing:

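    import numpy as np

    # Assumes a, b have shape (k, m, n) and idx = a.argmax(axis=0) has shape (m, n)
    m, n = a.shape[1:]
    I, J = np.ogrid[:m, :n]      # open grids: I has shape (m, 1), J has shape (1, n)
    a_max_values = a[idx, I, J]  # idx, I, J broadcast to (m, n); same as a.max(axis=0)
    b_max_values = b[idx, I, J]  # b's values at the positions where a is largest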
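A self-contained check of the 3-D case, assuming (as the variable names suggest) that `idx = a.argmax(axis=0)`; the shapes and the random seed are arbitrary choices for illustration:

```python
import numpy as np

# Illustrative data: the shape (3, 4, 5) and the seed are arbitrary assumptions
rng = np.random.default_rng(0)
a = rng.random((3, 4, 5))
b = rng.random((3, 4, 5))

idx = a.argmax(axis=0)           # shape (4, 5): winning position along axis 0
m, n = a.shape[1:]
I, J = np.ogrid[:m, :n]          # I: shape (4, 1), J: shape (1, 5)

a_max_values = a[idx, I, J]      # idx, I, J broadcast together to shape (4, 5)
b_max_values = b[idx, I, J]      # b at the positions where a is largest

print(a_max_values.shape)
print(np.array_equal(a_max_values, a.max(axis=0)))
```

The open grids from `np.ogrid` avoid materializing full `(m, n)` index arrays; broadcasting expands them only during the indexing step.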
    

    For the general case:

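    import numpy as np

    def argmax_to_max(arr, argmax, axis):
        """argmax_to_max(arr, arr.argmax(axis), axis) == arr.max(axis)"""
        axis = axis % arr.ndim            # normalize negative axes, e.g. -1 -> arr.ndim - 1
        new_shape = list(arr.shape)
        del new_shape[axis]               # shape of arr with `axis` removed

        # Open (broadcastable) grids for every other axis; list() so .insert works
        # whether np.ogrid returns a list or a tuple (this varies across NumPy versions)
        grid = list(np.ogrid[tuple(map(slice, new_shape))])
        grid.insert(axis, argmax)         # argmax supplies the index along `axis`

        return arr[tuple(grid)]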
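To sanity-check the general helper, the sketch below repeats it (with the `list()` and negative-axis adjustments) and compares it against `arr.max(axis)` on every axis of a 4-D array, negative axes included; the array shape and seed are arbitrary assumptions:

```python
import numpy as np

def argmax_to_max(arr, argmax, axis):
    """argmax_to_max(arr, arr.argmax(axis), axis) == arr.max(axis)"""
    axis = axis % arr.ndim                      # normalize negative axes
    new_shape = list(arr.shape)
    del new_shape[axis]
    grid = list(np.ogrid[tuple(map(slice, new_shape))])
    grid.insert(axis, argmax)
    return arr[tuple(grid)]

arr = np.random.default_rng(1).random((2, 3, 4, 5))
for ax in range(-arr.ndim, arr.ndim):           # every axis, negative ones too
    out = argmax_to_max(arr, arr.argmax(axis=ax), ax)
    assert np.array_equal(out, arr.max(axis=ax))
print("all axes match")
```

Without the `axis % arr.ndim` step, `list.insert(-1, ...)` would place the index array before the last grid rather than after it, so `axis=-1` would index the wrong dimensions.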
    

    Unfortunately, this is quite a bit more awkward than such a natural operation should be.
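For what it's worth, NumPy 1.15 added `np.take_along_axis`, which covers this pattern directly. It expects an index array with the same number of dimensions as `arr`, so the (n-1)-d argmax is expanded with `np.expand_dims` and the result squeezed afterwards. This is a swapped-in alternative rather than part of the original answer, and the helper name `argmax_to_max_tal` is hypothetical:

```python
import numpy as np

def argmax_to_max_tal(arr, argmax, axis):
    """Hypothetical helper: same result as arr.max(axis), via np.take_along_axis."""
    # take_along_axis needs an index array with arr.ndim dimensions,
    # so re-insert the reduced axis with length 1, then drop it afterwards
    picked = np.take_along_axis(arr, np.expand_dims(argmax, axis), axis=axis)
    return np.squeeze(picked, axis=axis)

arr = np.random.default_rng(2).random((3, 4, 5))
for ax in (0, 1, 2, -1):
    out = argmax_to_max_tal(arr, arr.argmax(axis=ax), ax)
    print(ax, out.shape, np.array_equal(out, arr.max(axis=ax)))
```

If you control the `argmax` call, `arr.argmax(axis=ax, keepdims=True)` (NumPy 1.22+) already returns the expanded shape, so the `expand_dims` step can be dropped.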

    For indexing an n-dim array with an (n-1)-dim array, we can simplify this a bit and build the grid of indices for all axes directly from idx, like so:

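    import numpy as np

    def all_idx(idx, axis):
        """Index tuple that selects along `axis` of an array with idx.ndim + 1 dims."""
        axis = axis % (idx.ndim + 1)   # normalize negative axes for the n-dim target
        # Open grids over idx.shape; list() so .insert works even if np.ogrid returns a tuple
        grid = list(np.ogrid[tuple(map(slice, idx.shape))])
        grid.insert(axis, idx)         # idx supplies the index along `axis`
        return tuple(grid)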
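To see why this works, it helps to look at the shapes in the returned tuple: each open grid has length 1 in every dimension except its own, so it broadcasts against `idx`. A small sketch, assuming an `idx` of shape (4, 5) filled with zeros and `axis=1` (both arbitrary choices):

```python
import numpy as np

def all_idx(idx, axis):
    axis = axis % (idx.ndim + 1)               # allow negative axes
    grid = list(np.ogrid[tuple(map(slice, idx.shape))])
    grid.insert(axis, idx)
    return tuple(grid)

idx = np.zeros((4, 5), dtype=int)              # placeholder indices, shape (4, 5)
tup = all_idx(idx, axis=1)
print([t.shape for t in tup])                  # [(4, 1), (4, 5), (1, 5)]

arr = np.arange(4 * 3 * 5).reshape(4, 3, 5)    # n-dim target: axis 1 has length 3
print(arr[tup].shape)                          # (4, 5)
```

Because every entry of this `idx` is 0, `arr[tup]` picks `arr[i, 0, j]` for each `(i, j)`, which is the same as `arr[:, 0, :]`.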
    

    Then use it to index into the input arrays:

    axis = 0
    a_max_values = a[all_idx(idx, axis=axis)]
    b_max_values = b[all_idx(idx, axis=axis)]
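An end-to-end sketch of this usage along a non-leading axis, assuming `idx = a.argmax(axis=axis)` and random 3-D inputs (shapes, seed and `axis = 1` are illustrative assumptions). It cross-checks `b_max_values` against a plain Python loop:

```python
import numpy as np

def all_idx(idx, axis):
    axis = axis % (idx.ndim + 1)
    grid = list(np.ogrid[tuple(map(slice, idx.shape))])
    grid.insert(axis, idx)
    return tuple(grid)

# Illustrative inputs (shapes and seed are assumptions); reduce along axis 1
rng = np.random.default_rng(3)
a = rng.random((3, 4, 5))
b = rng.random((3, 4, 5))
axis = 1
idx = a.argmax(axis=axis)                        # shape (3, 5)

a_max_values = a[all_idx(idx, axis=axis)]
b_max_values = b[all_idx(idx, axis=axis)]

# Cross-check b_max_values against an explicit loop over the remaining axes
expected = np.empty_like(b_max_values)
for i in range(a.shape[0]):
    for k in range(a.shape[2]):
        expected[i, k] = b[i, idx[i, k], k]
print(np.array_equal(a_max_values, a.max(axis=axis)), np.array_equal(b_max_values, expected))
```

Since `all_idx` depends only on `idx` and `axis`, the tuple could also be computed once and reused for both `a` and `b`.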
    