How to remove a column from a structured numpy array *without copying it*?

一个人的身影 2021-01-18 22:14

Given a structured numpy array, I want to remove certain columns by name without copying the array. I know I can do this:

names = list(a.dtype.names)
if name         


        
1 Answer
  • 2021-01-18 22:38

    You can create a new data type that contains just the fields you want, using the same field offsets and the same itemsize as the original array's data type, and then use this new data type to create a view of the original array. The np.dtype constructor accepts arguments in many forms; the relevant one is described in the section of the NumPy documentation titled "Specifying and constructing data types". Scroll down to the subsection that begins with

    {'names': ..., 'formats': ..., 'offsets': ..., 'titles': ..., 'itemsize': ...}
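    As a minimal sketch of that dict form, the block below builds a dtype that keeps only `x` and `y` at their original offsets while preserving the original 32-byte itemsize, and contrasts it with a packed two-field dtype. The record layout mirrors the example session in this answer; the names `full`, `partial`, `packed` and the zero-filled array are illustrative assumptions.

```python
import numpy as np

# Original record layout: four 8-byte fields, 32 bytes per record
full = np.dtype([('x', '<f8'), ('y', '<f8'), ('i', '<i8'), ('j', '<i8')])

# Dict form: keep 'x' and 'y' at their original offsets (0 and 8),
# but keep the full 32-byte itemsize so 'i' and 'j' become unnamed padding
partial = np.dtype({'names': ['x', 'y'],
                    'formats': ['<f8', '<f8'],
                    'offsets': [0, 8],
                    'itemsize': 32})

# A packed dtype with the same two fields is only 16 bytes wide
packed = np.dtype([('x', '<f8'), ('y', '<f8')])

a = np.zeros(5, dtype=full)
print(partial.itemsize, packed.itemsize)  # 32 16
print(a.view(partial).shape)              # (5,): one record per original record
print(a.view(packed).shape)               # (10,): the 160-byte buffer is re-sliced into 16-byte records
```

    Keeping the itemsize equal to the original is what makes the view map one record onto one record. With the 16-byte packed dtype, NumPy reinterprets the same bytes as twice as many records, so the bytes of `i` and `j` would be read as floats.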
    

    Here are a couple of convenience functions that use this idea.

    import numpy as np
    
    
    def view_fields(a, names):
        """
        `a` must be a numpy structured array.
        `names` is the collection of field names to keep.
    
        Returns a view of the array `a` (not a copy).
        """
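        dt = a.dtype
        # Look up each kept field's (dtype, offset) pair in the original dtype
        formats = [dt.fields[name][0] for name in names]
        offsets = [dt.fields[name][1] for name in names]
        # Keep the original record size so each element still spans the
        # whole record, including the bytes of the dropped fields
        itemsize = dt.itemsize
        newdt = np.dtype(dict(names=names,
                              formats=formats,
                              offsets=offsets,
                              itemsize=itemsize))
        # Reinterpret the same buffer with the new dtype; no data is copied
        b = a.view(newdt)
        return b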
    
    
    def remove_fields(a, names):
        """
        `a` must be a numpy structured array.
        `names` is the collection of field names to remove.
    
        Returns a view of the array `a` (not a copy).
        """
        dt = a.dtype
        keep_names = [name for name in dt.names if name not in names]
        return view_fields(a, keep_names)
    

    For example,

    In [297]: a
    Out[297]: 
    array([(10.0, 13.5, 1248, -2), (20.0, 0.0, 0, 0), (30.0, 0.0, 0, 0),
           (40.0, 0.0, 0, 0), (50.0, 0.0, 0, 999)], 
          dtype=[('x', '<f8'), ('y', '<f8'), ('i', '<i8'), ('j', '<i8')])
    
    In [298]: b = remove_fields(a, ['i', 'j'])
    
    In [299]: b
    Out[299]: 
    array([(10.0, 13.5), (20.0, 0.0), (30.0, 0.0), (40.0, 0.0), (50.0, 0.0)], 
          dtype={'names':['x','y'], 'formats':['<f8','<f8'], 'offsets':[0,8], 'itemsize':32})
    

    Verify that b is a view (not a copy) of a by changing b[0]['x']...

    In [300]: b[0]['x'] = 3.14
    

    and seeing that a is also changed:

    In [301]: a[0]
    Out[301]: (3.14, 13.5, 1248, -2)
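    Mutating an element is one way to check. As an additional, non-destructive sketch, `np.shares_memory`, the view's `base` attribute, and its strides can confirm the same thing. The block rebuilds the array and the view inline (equivalent to `remove_fields(a, ['i', 'j'])`, but without calling it) so that it runs on its own.

```python
import numpy as np

# Same data as the session above
a = np.array([(10.0, 13.5, 1248, -2), (20.0, 0.0, 0, 0), (30.0, 0.0, 0, 0),
              (40.0, 0.0, 0, 0), (50.0, 0.0, 0, 999)],
             dtype=[('x', '<f8'), ('y', '<f8'), ('i', '<i8'), ('j', '<i8')])

# Keep every field except 'i' and 'j', reusing the original offsets and itemsize
dt = a.dtype
keep = [n for n in dt.names if n not in ('i', 'j')]
newdt = np.dtype({'names': keep,
                  'formats': [dt.fields[n][0] for n in keep],
                  'offsets': [dt.fields[n][1] for n in keep],
                  'itemsize': dt.itemsize})
b = a.view(newdt)

print(np.shares_memory(a, b))  # True: both arrays use the same buffer
print(b.base is a)             # True: b is a view whose base is a
print(b.strides)               # (32,): records are still 32 bytes apart
```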
    