NumPy 3D to 2D transformation based on a 2D mask array

栀梦 2021-01-19 08:36

If I have an ndarray like this:

>>> a = np.arange(27).reshape(3,3,3)
>>> a
array([[[ 0,  1,  2],
        [ 3,  4,  5],
        [ 6,  7,  8]],

       [[ 9, 10, 11],
        [12, 13, 14],
        [15, 16, 17]],

       [[18, 19, 20],
        [21, 22, 23],
        [24, 25, 26]]])
2 answers
  • 2021-01-19 09:17

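    Here is some magic NumPy indexing that does what you want, but unfortunately it's pretty unreadable. The idea is to build a full index tuple: one open-mesh range (via np.ix_) for every dimension of indices, with indices itself inserted at position axis, so that all the pieces broadcast to the shape of indices.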

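    import numpy as np

    def apply_mask(a, indices, axis):
        # One open-mesh range per dimension of `indices`, shaped (n, 1), (1, m), ...
        magic_index = np.ix_(*[np.arange(i) for i in indices.shape])
        # Put `indices` in the slot for `axis`; the ranges broadcast against it
        magic_index = magic_index[:axis] + (indices,) + magic_index[axis:]
        return a[magic_index]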
    

    Or, equally unreadable, with np.ogrid:

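    def apply_mask(a, indices, axis):
        # np.ogrid builds the same open mesh; list() so we can insert into it
        magic_index = list(np.ogrid[tuple(slice(i) for i in indices.shape)])
        magic_index.insert(axis, indices)
        # Index with a tuple: modern NumPy rejects a list of arrays as a multi-axis index
        return a[tuple(magic_index)]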
    
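A small check of what the "magic" index contains may help. This sketch reuses the first apply_mask above unchanged and picks along axis 0 of the question's 3x3x3 array. The indices array is a hypothetical example, and the explicit double loop only spells out the rule result[i, j] == a[indices[i, j], i, j].

```python
import numpy as np

def apply_mask(a, indices, axis):
    # Same construction as the np.ix_ version: open mesh + `indices` at `axis`
    magic_index = np.ix_(*[np.arange(i) for i in indices.shape])
    magic_index = magic_index[:axis] + (indices,) + magic_index[axis:]
    return a[magic_index]

a = np.arange(27).reshape(3, 3, 3)
# Hypothetical 2-D index array: for each (i, j), which slice along axis 0 to take
indices = np.array([[0, 1, 2],
                    [2, 1, 0],
                    [1, 1, 1]])

# The open mesh for a (3, 3) index array has shapes (3, 1) and (1, 3),
# so it broadcasts with `indices` to a full (3, 3) grid of positions
rows, cols = np.ix_(np.arange(3), np.arange(3))
print(rows.shape, cols.shape)  # (3, 1) (1, 3)

# np.ogrid gives arrays of the same shapes, which is why both versions agree
og_rows, og_cols = np.ogrid[:3, :3]

result = apply_mask(a, indices, axis=0)

# Equivalent explicit loop: result[i, j] == a[indices[i, j], i, j]
expected = np.empty_like(indices)
for i in range(3):
    for j in range(3):
        expected[i, j] = a[indices[i, j], i, j]

print(result)
print(np.array_equal(result, expected))  # True
```

With axis=0 the index tuple is (indices, rows, cols), so every output position (i, j) reads one element from a, which is why the result is 2-D.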
  • 2021-01-19 09:20

    I use a helper, index_at(), to build the full index tuple; it also accepts a negative axis:

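    import numpy as np

    def index_at(idx, shape, axis=-1):
        # Normalize a negative axis, e.g. -1 -> len(shape) - 1
        if axis < 0:
            axis += len(shape)
        # Shape of the result: the input shape with `axis` dropped
        shape = shape[:axis] + shape[axis+1:]
        # Open-mesh ranges for the remaining axes, then `idx` in the slot for `axis`
        index = list(np.ix_(*[np.arange(n) for n in shape]))
        index.insert(axis, idx)
        return tuple(index)

    a = np.random.randint(0, 10, (3, 4, 5))

    axis = 1
    idx = np.argmax(a, axis=axis)
    # Both lines print the same (3, 5) array
    print(a[index_at(idx, a.shape, axis=axis)])
    print(np.max(a, axis=axis))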
    
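For comparison, NumPy 1.15 and later provide np.take_along_axis, which performs this "pick one element along an axis" selection without building the index by hand. This is a swapped-in alternative, not the approach from the answer above, and only a sketch: take_along_axis expects the index array to have the same number of dimensions as a, so the reduced axis is restored with np.expand_dims and removed again with np.squeeze.

```python
import numpy as np

a = np.random.randint(0, 10, (3, 4, 5))
axis = 1
idx = np.argmax(a, axis=axis)  # shape (3, 5): the reduced axis is gone

# take_along_axis needs an index with the same ndim as `a`,
# so re-insert the reduced axis as length 1, then drop it afterwards
picked = np.take_along_axis(a, np.expand_dims(idx, axis=axis), axis=axis)
picked = np.squeeze(picked, axis=axis)

# The value at each argmax position is the maximum along that axis
print(np.array_equal(picked, np.max(a, axis=axis)))  # True
```

On NumPy 1.22 or later, np.argmax(a, axis=axis, keepdims=True) produces the expanded index directly, which removes the expand_dims step.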