Remove numpy rows contained in a list?

无人及你 2021-01-21 08:19

I have a NumPy array and a list of tuples. I want to remove the rows of the array that appear in the list.

import numpy as np

a = np.zeros((3, 2))
a[0, :] = [1, 2]
l = [(1, 2), (3, 4)]
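For small inputs, a direct broadcasting comparison gives the intended result and is handy as a reference for checking faster methods. This is a brute-force sketch, not one of the approaches from the answer below; it builds a temporary boolean array of shape (len(a), len(l), ncols), so memory grows with both inputs.

```python
import numpy as np

a = np.zeros((3, 2))
a[0, :] = [1, 2]
l = [(1, 2), (3, 4)]

# Convert the list to an array with the same dtype as a
rows = np.asarray(l, dtype=a.dtype)
# Compare every row of a with every row of l, then ask:
# does row i of a equal (all columns) any (at least one) row of l?
matches = (a[:, None, :] == rows[None, :, :]).all(axis=2).any(axis=1)
out = a[~matches]
print(out)
```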


1 Answer
  • 2021-01-21 08:55

    Approach #1: Here's one that uses views, treating each row as a single element of an extended (void) dtype:
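To see what "viewing each row as an element" means, here is a small sketch on the sample data: a void dtype whose size equals one row's bytes turns an (n, 2) array into n opaque elements, each holding exactly that row's raw bytes.

```python
import numpy as np

a = np.array([[1, 2], [0, 0], [0, 0]])
a = np.ascontiguousarray(a)  # the view needs each row's bytes to be adjacent
# One void element spans a whole row: itemsize * number of columns bytes
void_dt = np.dtype((np.void, a.dtype.itemsize * a.shape[1]))

v = a.view(void_dt)
print(v.shape)    # (3, 1): the last axis collapses to one element
v1D = v.ravel()
print(v1D.shape)  # (3,)
# Each element carries exactly the raw bytes of its row
print(v1D[0].tobytes() == a[0].tobytes())  # True
```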

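    # https://stackoverflow.com/a/45313353/ @Divakar
    import numpy as np

    def view1D(a, b):  # a, b are array-likes with the same number of columns
        a = np.ascontiguousarray(a)
        # Cast b to a's dtype: the void view compares raw bytes, so e.g. a
        # float64 1.0 and an int64 1 would otherwise never match
        b = np.ascontiguousarray(b, dtype=a.dtype)
        # One void element spans a whole row
        void_dt = np.dtype((np.void, a.dtype.itemsize * a.shape[1]))
        return a.view(void_dt).ravel(), b.view(void_dt).ravel()

    a1D, l1D = view1D(a, l)
    # np.isin replaces np.in1d, which is deprecated as of NumPy 2.0
    out = a[np.isin(a1D, l1D, invert=True)]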
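A self-contained sketch of Approach #1 as a reusable function, applied to the question's float64 array. The name `remove_rows_void` is hypothetical; the key assumption is that `rows` is cast to `a`'s dtype before viewing, since the void elements are compared byte by byte.

```python
import numpy as np

def remove_rows_void(a, rows):
    """Hypothetical helper: drop rows of `a` that also appear in `rows`."""
    a = np.ascontiguousarray(a)
    # Same dtype for both, so equal values also have equal bytes
    rows = np.ascontiguousarray(rows, dtype=a.dtype)
    # One void element per row
    void_dt = np.dtype((np.void, a.dtype.itemsize * a.shape[1]))
    a1D = a.view(void_dt).ravel()
    rows1D = rows.view(void_dt).ravel()
    # Keep rows whose byte pattern is not found among the rows to remove
    return a[~np.isin(a1D, rows1D)]

a = np.zeros((3, 2))  # float64, as in the question
a[0, :] = [1, 2]
result = remove_rows_void(a, [(1, 2), (3, 4)])
print(result)
```

Because the comparison is byte-wise, -0.0 and 0.0 are treated as different values even though they compare equal numerically.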
    

    If you want only unique rows in the output, as a set would give, apply np.unique along axis=0 to the result:

    np.unique(out,axis=0)
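Note that np.unique returns the unique rows sorted lexicographically, not in their original order. If the original order matters, one option (a sketch, not part of the answer) is to take the first-occurrence indices from `return_index=True` and sort them:

```python
import numpy as np

out = np.array([[3, 4], [0, 0], [3, 4], [1, 1]])

# Unique rows come back sorted lexicographically
uniq = np.unique(out, axis=0)
print(uniq.tolist())     # [[0, 0], [1, 1], [3, 4]]

# Index of the first occurrence of each unique row; sorting the
# indices restores the order in which the rows first appeared
_, first = np.unique(out, axis=0, return_index=True)
ordered = out[np.sort(first)]
print(ordered.tolist())  # [[3, 4], [0, 0], [1, 1]]
```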
    

    Sample run:

    In [72]: a
    Out[72]: 
    array([[1, 2],
           [0, 0],
           [0, 0]])
    
    In [73]: l
    Out[73]: [(1, 2), (3, 4)]
    
    In [74]: out
    Out[74]: 
    array([[0, 0],
           [0, 0]])
    
    In [75]: np.unique(out,axis=0)
    Out[75]: array([[0, 0]])
    

    Approach #2: Following the same idea of reducing each row to a single value, here's one using matrix multiplication. It works only for non-negative integer data:
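The dot product with `s` reads each row as the digits of a mixed-radix number: column j uses base shp[j], and its place value s[j] is the product of the bases to its right. Every distinct row of non-negative integers then maps to a distinct integer key. Worked through on the sample data:

```python
import numpy as np

a = np.array([[1, 2], [0, 0], [0, 0]])
l = np.array([(1, 2), (3, 4)])

# Base for each column: one more than the largest value seen in that column
shp = np.maximum(a.max(0) + 1, l.max(0) + 1)  # [4, 5]
# Place value of each column: product of the bases to its right
s = np.r_[shp[::-1].cumprod()[::-1][1:], 1]   # [5, 1]

print(l.dot(s).tolist())  # [7, 19]: 1*5 + 2 = 7, 3*5 + 4 = 19
print(a.dot(s).tolist())  # [7, 0, 0]
```

With negative values the keys can collide (for example, (0, -1) and (-1, 4) both map to -1 with s = [5, 1]), which is why the approach is limited to non-negative integers.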

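    import numpy as np

    l = np.asarray(l)
    # Mixed-radix base per column: one more than the largest value seen
    shp = np.maximum(a.max(0)+1, l.max(0)+1)
    # Place value of each column = product of the bases to its right
    s = np.r_[shp[::-1].cumprod()[::-1][1:], 1]
    l1D = l.dot(s)
    a1D = a.dot(s)
    l1Ds = np.sort(l1D)
    # Positions refer to the *sorted* keys; clip keys beyond the last one
    idx = np.minimum(np.searchsorted(l1Ds, a1D), len(l1Ds) - 1)
    out = a[l1Ds[idx] != a1D]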
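Once each row is a single integer key, the sort + searchsorted lookup can also be expressed with np.isin on the 1-D keys. This is a swapped-in variant, sketched as a hypothetical helper `remove_rows_isin`; it assumes non-negative integer data and a non-empty `rows`.

```python
import numpy as np

def remove_rows_isin(a, rows):
    """Hypothetical helper: same mixed-radix keys as Approach #2,
    with np.isin doing the membership test."""
    a = np.asarray(a)
    rows = np.asarray(rows)
    shp = np.maximum(a.max(0) + 1, rows.max(0) + 1)
    s = np.r_[shp[::-1].cumprod()[::-1][1:], 1]
    # Keep rows of a whose key is not among the keys of `rows`
    return a[~np.isin(a.dot(s), rows.dot(s))]

a = np.array([[1, 2], [0, 0], [5, 6], [3, 4]])
# The list is deliberately unsorted, and [5, 6] has a key larger than any in it
print(remove_rows_isin(a, [(3, 4), (1, 2)]).tolist())
```

The test input covers the two cases where the lookup is easy to get wrong: an unsorted list of rows and a row whose key lies past every key being removed.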
    