Given a structured numpy array, I want to remove certain columns by name without copying the array. I know I can do this:
names = list(a.dtype.names)
if name
You can create a new data type that contains just the fields you want, with the same field offsets and the same itemsize as the original array's data type, and then use this new data type to create a view of the original array. The np.dtype constructor accepts arguments in many formats; the relevant one is described in the section of the NumPy documentation called "Specifying and constructing data types". Scroll down to the subsection that begins with
{'names': ..., 'formats': ..., 'offsets': ..., 'titles': ..., 'itemsize': ...}
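As a minimal sketch of that dictionary form (the field names 'x', 'y', 'i', 'j' are just illustrative), here is a reduced dtype that keeps only 'x' and 'y' but preserves their offsets and the full itemsize, so the dropped fields simply become unnamed padding:

```python
import numpy as np

# Original structured dtype: two float64 fields followed by two int64 fields.
dt = np.dtype([('x', '<f8'), ('y', '<f8'), ('i', '<i8'), ('j', '<i8')])

# dt.fields[name] is a (dtype, offset) tuple; reuse both for the kept fields,
# and keep the original itemsize so the memory layout is unchanged.
newdt = np.dtype({'names': ['x', 'y'],
                  'formats': [dt.fields['x'][0], dt.fields['y'][0]],
                  'offsets': [dt.fields['x'][1], dt.fields['y'][1]],
                  'itemsize': dt.itemsize})

print(newdt.names)     # ('x', 'y')
print(newdt.itemsize)  # 32
```

Because newdt.itemsize equals dt.itemsize, an array with dtype dt can be viewed as newdt without touching its data.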
Here are a couple of convenience functions that use this idea.
import numpy as np

def view_fields(a, names):
    """
    `a` must be a numpy structured array.
    `names` is the collection of field names to keep.
    Returns a view of the array `a` (not a copy).
    """
    dt = a.dtype
    # Each entry of dt.fields is a (dtype, offset) tuple for that field.
    formats = [dt.fields[name][0] for name in names]
    offsets = [dt.fields[name][1] for name in names]
    # Keep the original itemsize so every element still spans the same bytes;
    # the omitted fields become unnamed padding.
    itemsize = dt.itemsize
    newdt = np.dtype(dict(names=names,
                          formats=formats,
                          offsets=offsets,
                          itemsize=itemsize))
    b = a.view(newdt)
    return b

def remove_fields(a, names):
    """
    `a` must be a numpy structured array.
    `names` is the collection of field names to remove.
    Returns a view of the array `a` (not a copy).
    """
    dt = a.dtype
    # Keep every field not listed in `names`, in the original order.
    keep_names = [name for name in dt.names if name not in names]
    return view_fields(a, keep_names)
For example,
In [297]: a
Out[297]:
array([(10.0, 13.5, 1248, -2), (20.0, 0.0, 0, 0), (30.0, 0.0, 0, 0),
(40.0, 0.0, 0, 0), (50.0, 0.0, 0, 999)],
dtype=[('x', '<f8'), ('y', '<f8'), ('i', '<i8'), ('j', '<i8')])
In [298]: b = remove_fields(a, ['i', 'j'])
In [299]: b
Out[299]:
array([(10.0, 13.5), (20.0, 0.0), (30.0, 0.0), (40.0, 0.0), (50.0, 0.0)],
dtype={'names':['x','y'], 'formats':['<f8','<f8'], 'offsets':[0,8], 'itemsize':32})
Verify that b is a view (not a copy) of a by changing b[0]['x']...
In [300]: b[0]['x'] = 3.14
...and seeing that a is also changed:
In [301]: a[0]
Out[301]: (3.14, 13.5, 1248, -2)
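A related option, assuming NumPy 1.16 or later: indexing a structured array with a list of field names now returns a view with the original offsets and itemsize, much like view_fields above. The sketch below uses a small zero-filled array as a stand-in for a, and uses np.shares_memory to confirm the result is a view rather than relying on inspection:

```python
import numpy as np

# Stand-in for the array `a` from the example above.
a = np.zeros(5, dtype=[('x', '<f8'), ('y', '<f8'), ('i', '<i8'), ('j', '<i8')])

# Field names to keep, in their original order.
keep = [name for name in a.dtype.names if name not in ('i', 'j')]

# On NumPy >= 1.16 a multi-field index is a view, not a copy.
b = a[keep]

print(b.dtype.names)           # ('x', 'y')
print(b.dtype.itemsize)        # 32, same as a.dtype.itemsize
print(np.shares_memory(a, b))  # True

# Writing through the view modifies the original array.
b['x'][0] = 2.5
print(a['x'][0])               # 2.5
```

If you later need a compact copy without the padding, numpy.lib.recfunctions.repack_fields can produce one; the view itself stays padded so that no data is copied.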