Swap zeros in numpy matrix

眼角桃花 2021-01-20 13:08

I have a numpy matrix like so:

array([[2,  1, 23, 32],
       [34, 3, 3, 0],
       [3, 33, 0, 0],
       [32, 0, 0, 0]], dtype=int32)

Now

6 answers
  •  傲寒 (original poster)
    2021-01-20 13:45

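    Here's a vectorized approach with masking that pushes each row's non-zero values to the right (keeping their order) and the zeros to the left -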

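    import numpy as np

    valid_mask = a != 0  # True where a holds a non-zero value
    # Per row, mark the rightmost k slots, where k is that row's non-zero count
    flipped_mask = valid_mask.sum(1, keepdims=True) > np.arange(a.shape[1] - 1, -1, -1)
    a[flipped_mask] = a[valid_mask]  # copy the non-zeros over in row-major order
    a[~flipped_mask] = 0  # zero out the remaining left-hand slots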
    

    Sample run -

    In [90]: a
    Out[90]: 
    array([[ 2,  1, 23, 32],
           [34,  0,  3,  0],  # <== Added a zero in between for variety
           [ 3, 33,  0,  0],
           [32,  0,  0,  0]])
    
    # After code run -
    
    In [92]: a
    Out[92]: 
    array([[ 2,  1, 23, 32],
           [ 0,  0, 34,  3],
           [ 0,  0,  3, 33],
           [ 0,  0,  0, 32]])
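
To see why the two masks line up, the intermediate arrays for the sample above can be printed. This is only an illustrative sketch: the names `counts`, `positions` and `out` are introduced here and are not part of the original snippet, and the result is written to a fresh array instead of overwriting `a`.

```python
import numpy as np

# Sample input from this post (second row has an interior zero).
a = np.array([[ 2,  1, 23, 32],
              [34,  0,  3,  0],
              [ 3, 33,  0,  0],
              [32,  0,  0,  0]])

valid_mask = a != 0
counts = valid_mask.sum(1, keepdims=True)      # non-zeros per row, shape (4, 1)
positions = np.arange(a.shape[1] - 1, -1, -1)  # [3, 2, 1, 0]: distance from the right edge
flipped_mask = counts > positions              # True in the rightmost counts[i] slots of row i

print(counts.ravel())
# [4 2 2 1]
print(flipped_mask.astype(int))
# [[1 1 1 1]
#  [0 0 1 1]
#  [0 0 1 1]
#  [0 0 0 1]]

# Both masks hold the same number of True values in every row, and boolean
# indexing reads and writes in row-major order, so the k-th non-zero of a row
# lands in the k-th marked slot of that same row.
out = np.zeros_like(a)
out[flipped_mask] = a[valid_mask]
print(out)
# [[ 2  1 23 32]
#  [ 0  0 34  3]
#  [ 0  0  3 33]
#  [ 0  0  0 32]]
```

Writing into `out` leaves `a` untouched, which can be convenient when the input is still needed afterwards.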
    

    One more generic sample run -

    In [94]: a
    Out[94]: 
    array([[1, 1, 2, 3, 1, 0, 3, 0, 2, 1],
           [2, 1, 0, 1, 2, 0, 1, 3, 1, 1],
           [1, 2, 0, 3, 0, 3, 2, 0, 2, 2]])
    
    # After code run -
    
    In [96]: a
    Out[96]: 
    array([[0, 0, 1, 1, 2, 3, 1, 3, 2, 1],
           [0, 0, 2, 1, 1, 2, 1, 3, 1, 1],
           [0, 0, 0, 1, 2, 3, 3, 2, 2, 2]])
    

    Runtime test

    Approaches that work on generic cases -

    # Proposed in this post (note: overwrites a in place and returns it)
    def masking_based(a):
        valid_mask = a != 0
        flipped_mask = valid_mask.sum(1, keepdims=True) > np.arange(a.shape[1] - 1, -1, -1)
        a[flipped_mask] = a[valid_mask]
        a[~flipped_mask] = 0
        return a

    # @Psidom's solution: a stable sort on the non-zero mask puts zeros first
    def sort_based(a):
        return a[np.arange(a.shape[0])[:, None], (a != 0).argsort(1, kind="mergesort")]
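
A quick way to check that the two approaches agree is to run both on the same random input. The sketch below repeats the two functions so it runs on its own; the seed and array shape are arbitrary choices made for illustration. Since `masking_based` overwrites its argument, it is given a copy.

```python
import numpy as np

def masking_based(a):
    valid_mask = a != 0
    flipped_mask = valid_mask.sum(1, keepdims=True) > np.arange(a.shape[1] - 1, -1, -1)
    a[flipped_mask] = a[valid_mask]
    a[~flipped_mask] = 0
    return a

def sort_based(a):
    return a[np.arange(a.shape[0])[:, None], (a != 0).argsort(1, kind="mergesort")]

rng = np.random.default_rng(0)          # arbitrary seed, for reproducibility only
a = rng.integers(0, 4, size=(50, 60))   # values in 0..3, so roughly a quarter are zeros

expected = sort_based(a)                # returns a new array; a is unchanged
result = masking_based(a.copy())        # copy, because masking_based mutates its input

print(np.array_equal(result, expected))
# True
```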
    

    Timings -

    In [205]: a = np.random.randint(0,4,(1000,1000))
    
    In [206]: %timeit sort_based(a)
    10 loops, best of 3: 30.8 ms per loop
    
    In [207]: %timeit masking_based(a)
    100 loops, best of 3: 6.46 ms per loop
    
    In [208]: a = np.random.randint(0,4,(5000,5000))
    
    In [209]: %timeit sort_based(a)
    1 loops, best of 3: 961 ms per loop
    
    In [210]: %timeit masking_based(a)
    1 loops, best of 3: 151 ms per loop
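
The timings above come from IPython's `%timeit`. Outside IPython, a rough equivalent can be put together with the standard `timeit` module, as sketched below. The helper `best_time` and its `number`/`repeat` defaults are made up for this example. Each call receives a fresh copy of the input, so `masking_based` never runs on an already-shifted array; the copy cost is included for both functions, and absolute numbers will differ by machine and NumPy version.

```python
import timeit
import numpy as np

def masking_based(a):
    valid_mask = a != 0
    flipped_mask = valid_mask.sum(1, keepdims=True) > np.arange(a.shape[1] - 1, -1, -1)
    a[flipped_mask] = a[valid_mask]
    a[~flipped_mask] = 0
    return a

def sort_based(a):
    return a[np.arange(a.shape[0])[:, None], (a != 0).argsort(1, kind="mergesort")]

def best_time(fn, a, number=5, repeat=3):
    """Best per-call time in seconds; every call gets its own copy of a."""
    runs = timeit.repeat(lambda: fn(a.copy()), number=number, repeat=repeat)
    return min(runs) / number

a = np.random.randint(0, 4, (1000, 1000))
for fn in (sort_based, masking_based):
    print(f"{fn.__name__}: {best_time(fn, a) * 1e3:.2f} ms per call (copy included)")
```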
    
