Find unique rows in numpy.array

独厮守ぢ 2020-11-21 10:57

I need to find unique rows in a numpy.array.

For example:

>>> a # I have
array([[1, 1, 1, 0, 0, 0],
       [0, 1, 1, 1, 0, 0],

  •  灰色年华
    2020-11-21 11:47

    If you want to avoid the memory expense of converting to a series of tuples or another similar data structure, you can exploit numpy's structured arrays.

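    The trick is to view your original array as a structured array where each item corresponds to a row of the original array. This doesn't make a copy and is quite efficient, as long as each row is contiguous in memory (true for a freshly created, C-ordered array like the one below).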
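    A minimal sketch of that claim, assuming NumPy 1.11 or later (for np.shares_memory): it checks the shape of the structured view and that it reads the original buffer. The transposed array and the name safe are only there to illustrate the contiguity requirement; they are not part of the answer's method.

```python
import numpy as np

data = np.array([[1, 1, 1, 0, 0, 0],
                 [0, 1, 1, 1, 0, 0],
                 [0, 1, 1, 1, 0, 0],
                 [1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 1, 1, 0]])

# One field per column, all with the array's own dtype; empty field
# names are filled in automatically as 'f0', 'f1', ...
row_dtype = data.dtype.descr * data.shape[1]
struct = data.view(row_dtype)

# Each row collapses into a single structured element: shape (5, 1).
print(struct.shape)
# No data was copied: the view reads the same buffer as `data`.
print(np.shares_memory(struct, data))

# The reinterpretation needs each row to be contiguous, so a
# non-contiguous input (here a transposed array) is copied first.
transposed = data.T
safe = np.ascontiguousarray(transposed).view(
    transposed.dtype.descr * transposed.shape[1])
print(safe.shape)
```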

    As a quick example:

    import numpy as np
    
    data = np.array([[1, 1, 1, 0, 0, 0],
                     [0, 1, 1, 1, 0, 0],
                     [0, 1, 1, 1, 0, 0],
                     [1, 1, 1, 0, 0, 0],
                     [1, 1, 1, 1, 1, 0]])
    
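    # Describe one "row" as a record with ncols fields of the original dtype
    ncols = data.shape[1]
    dtype = data.dtype.descr * ncols
    # Reinterpret each row as a single record (no copy; rows must be contiguous)
    struct = data.view(dtype)
    
    # np.unique sorts the records and drops duplicates
    uniq = np.unique(struct)
    # View the records as plain elements again and restore the 2D shape
    uniq = uniq.view(data.dtype).reshape(-1, ncols)
    print(uniq)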
    

    To understand what's going on, have a look at the intermediate results.

    Once we view things as a structured array, each element in the array is a row in your original array. (Basically, it's a similar data structure to a list of tuples.)

    In [71]: struct
    Out[71]:
    array([[(1, 1, 1, 0, 0, 0)],
           [(0, 1, 1, 1, 0, 0)],
           [(0, 1, 1, 1, 0, 0)],
           [(1, 1, 1, 0, 0, 0)],
           [(1, 1, 1, 1, 1, 0)]],
          dtype=[('f0', '
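    To make the list-of-tuples analogy concrete, a small check (the names records and as_tuples are just for illustration) converts the flattened structured view to Python objects and compares it with an explicit tuple conversion of the same data:

```python
import numpy as np

data = np.array([[1, 1, 1, 0, 0, 0],
                 [0, 1, 1, 1, 0, 0],
                 [0, 1, 1, 1, 0, 0],
                 [1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 1, 1, 0]])
struct = data.view(data.dtype.descr * data.shape[1])

# Flatten the (5, 1) view to 5 records; tolist() turns each structured
# record into a plain Python tuple.
records = struct.ravel().tolist()
print(records[0])

# Same content as the explicit list-of-tuples conversion, which the
# structured view lets you avoid materializing.
as_tuples = [tuple(row) for row in data.tolist()]
print(records == as_tuples)
```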

    Running numpy.unique on that gives back a structured array containing only the distinct rows, in sorted order:

    In [73]: np.unique(struct)
    Out[73]:
    array([(0, 1, 1, 1, 0, 0), (1, 1, 1, 0, 0, 0), (1, 1, 1, 1, 1, 0)],
          dtype=[('f0', '
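    Because np.unique sorts, the original row order is lost. If that matters, one option (a sketch, not part of the answer above) is to ask np.unique for first-occurrence indices and per-row counts on the same structured view; first_idx and in_order are illustrative names:

```python
import numpy as np

data = np.array([[1, 1, 1, 0, 0, 0],
                 [0, 1, 1, 1, 0, 0],
                 [0, 1, 1, 1, 0, 0],
                 [1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 1, 1, 0]])
struct = data.view(data.dtype.descr * data.shape[1]).ravel()

# return_index: position of each unique record's first occurrence.
# return_counts: how many times each unique record appears.
uniq, first_idx, counts = np.unique(struct, return_index=True,
                                    return_counts=True)

# Sorting the first-occurrence indices restores the original row order.
in_order = data[np.sort(first_idx)]
print(in_order)
print(counts)
```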

    We then need to view that as a "normal" array again (_ holds the result of the last expression in IPython, which is why the next line uses _.view):

    In [74]: _.view(data.dtype)
    Out[74]: array([0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 0])
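    A quick check of what that view does, assuming the same data as above: the records are split back into plain scalars, giving a flat array of (number of unique rows) * ncols items, again without copying (np.unique itself did allocate the sorted result):

```python
import numpy as np

data = np.array([[1, 1, 1, 0, 0, 0],
                 [0, 1, 1, 1, 0, 0],
                 [0, 1, 1, 1, 0, 0],
                 [1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 1, 1, 0]])
ncols = data.shape[1]
uniq_struct = np.unique(data.view(data.dtype.descr * ncols))

# Viewing with the plain element dtype splits each 6-field record back
# into 6 scalars: 3 unique rows * 6 columns = 18 elements.
flat = uniq_struct.view(data.dtype)
print(flat.shape)

# Like the first view, this reinterprets the existing buffer.
print(np.shares_memory(flat, uniq_struct))
```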
    

    Finally, reshape it back into a 2D array (-1 is a placeholder that tells numpy to work out the number of rows, given the number of columns):

    In [75]: _.reshape(-1, ncols)
    Out[75]:
    array([[0, 1, 1, 1, 0, 0],
           [1, 1, 1, 0, 0, 0],
           [1, 1, 1, 1, 1, 0]])
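    The -1 inference on its own, as a small standalone illustration using the flat result from the previous step:

```python
import numpy as np

flat = np.array([0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 0])
ncols = 6

# -1 asks NumPy to infer that dimension: 18 elements / 6 columns = 3 rows.
rows = flat.reshape(-1, ncols)
print(rows.shape)

# Inference only works when the size divides evenly; 18 is not a
# multiple of 7, so reshape raises ValueError here.
try:
    flat.reshape(-1, 7)
except ValueError as exc:
    print("cannot reshape:", exc)
```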
    

    If you want something more concise, you can wrap it all in a function:

    import numpy as np
    
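    def unique_rows(data):
        # The structured view needs contiguous rows, so copy only if necessary
        data = np.ascontiguousarray(data)
        # One record per row -> unique records -> back to plain 2D rows
        uniq = np.unique(data.view(data.dtype.descr * data.shape[1]))
        return uniq.view(data.dtype).reshape(-1, data.shape[1])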
    
    data = np.array([[1, 1, 1, 0, 0, 0],
                     [0, 1, 1, 1, 0, 0],
                     [0, 1, 1, 1, 0, 0],
                     [1, 1, 1, 0, 0, 0],
                     [1, 1, 1, 1, 1, 0]])
    print(unique_rows(data))
    

    Which results in:

    [[0 1 1 1 0 0]
     [1 1 1 0 0 0]
     [1 1 1 1 1 0]]
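    For what it's worth, NumPy 1.13 and later also accept an axis argument to np.unique, which handles this case directly. A short comparison, assuming NumPy >= 1.13, with the answer's function copied in so the snippet runs on its own (builtin is an illustrative name):

```python
import numpy as np

data = np.array([[1, 1, 1, 0, 0, 0],
                 [0, 1, 1, 1, 0, 0],
                 [0, 1, 1, 1, 0, 0],
                 [1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 1, 1, 0]])

def unique_rows(data):
    # The structured-view approach from the answer.
    uniq = np.unique(data.view(data.dtype.descr * data.shape[1]))
    return uniq.view(data.dtype).reshape(-1, data.shape[1])

# Deduplicate along axis 0 (rows); results come back sorted row-wise.
builtin = np.unique(data, axis=0)
print(builtin)
print(np.array_equal(builtin, unique_rows(data)))
```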
    
