Question
I want to parallelize the `numpy.bincount` function using the `apply_ufunc` API of xarray, and the following code is what I've tried:
```python
import numpy as np
import xarray as xr

da = xr.DataArray(np.random.rand(2,16,32),
                  dims=['time', 'y', 'x'],
                  coords={'time': np.array(['2019-04-18', '2019-04-19'],
                                           dtype='datetime64'),
                          'y': np.arange(16), 'x': np.arange(32)})

f = xr.DataArray(da.data.reshape((2,512)),dims=['time','idx'])
x = da.x.values
y = da.y.values
r = np.sqrt(x[np.newaxis,:]**2 + y[:,np.newaxis]**2)

nbins = 4
if x.max() > y.max():
    ri = np.linspace(0., y.max(), nbins)
else:
    ri = np.linspace(0., x.max(), nbins)

ridx = np.digitize(np.ravel(r), ri)

func = lambda a, b: np.bincount(a, weights=b)
xr.apply_ufunc(func, xr.DataArray(ridx,dims=['idx']), f)
```
But I get the following error:
```
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-203-974a8f0a89e8> in <module>()
     12 
     13 func = lambda a, b: np.bincount(a, weights=b)
---> 14 xr.apply_ufunc(func, xr.DataArray(ridx,dims=['idx']), f)

~/anaconda/envs/uptodate/lib/python3.6/site-packages/xarray/core/computation.py in apply_ufunc(func, *args, **kwargs)
    979             signature=signature,
    980             join=join,
--> 981             exclude_dims=exclude_dims)
    982     elif any(isinstance(a, Variable) for a in args):
    983         return variables_ufunc(*args)

~/anaconda/envs/uptodate/lib/python3.6/site-packages/xarray/core/computation.py in apply_dataarray_ufunc(func, *args, **kwargs)
    208 
    209     data_vars = [getattr(a, 'variable', a) for a in args]
--> 210     result_var = func(*data_vars)
    211 
    212     if signature.num_outputs > 1:

~/anaconda/envs/uptodate/lib/python3.6/site-packages/xarray/core/computation.py in apply_variable_ufunc(func, *args, **kwargs)
    558         raise ValueError('unknown setting for dask array handling in '
    559                          'apply_ufunc: {}'.format(dask))
--> 560     result_data = func(*input_data)
    561 
    562     if signature.num_outputs == 1:

<ipython-input-203-974a8f0a89e8> in <lambda>(a, b)
     11 ridx = np.digitize(np.ravel(r), ri)
     12 
---> 13 func = lambda a, b: np.bincount(a, weights=b)
     14 xr.apply_ufunc(func, xr.DataArray(ridx,dims=['idx']), f)

ValueError: object too deep for desired array
```
I am kind of lost as to where the error is coming from, and any help would be greatly appreciated.
Answer 1:
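The underlying issue is that `np.bincount` only accepts 1D arrays. Without any `input_core_dims`, `apply_ufunc` broadcasts the two inputs against each other, so your lambda receives 2D arrays (dimensions `idx` and `time`), which is what "object too deep for desired array" is complaining about. The natural tool for applying a 1D function along one axis is `np.apply_along_axis`, but it iterates over 1D slices of the first array argument to the applied function and not any of the others. If I understand your use-case correctly, you actually want to iterate over 1D slices of the weights (`weights` in the `np.bincount` signature), not the integer array (`x` in the `np.bincount` signature).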
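Both points can be checked on a toy example. The arrays and the `record` helper below are made up for illustration and are not taken from the question: `np.bincount` rejects 2D weights, and `np.apply_along_axis` slices only its array argument while forwarding extra positional arguments unchanged on every call.

```python
import numpy as np

x = np.array([0, 1, 1, 2])                   # integer bin indices, shape (4,)
weights = np.array([[1.0, 2.0, 3.0, 4.0],    # weights, shape (2, 4)
                    [10.0, 20.0, 30.0, 40.0]])

# 1) np.bincount refuses 2D weights with a ValueError.
try:
    np.bincount(x, weights=weights)
except ValueError as err:
    print("bincount failed:", err)

# 2) np.apply_along_axis slices only `weights` (its array argument);
#    `x` is passed whole, as an extra positional argument, on every call.
calls = []

def record(arr_slice, extra):
    # Keep a copy of each slice and the extra argument for inspection.
    calls.append((arr_slice.copy(), extra))
    return arr_slice.sum()

totals = np.apply_along_axis(record, -1, weights, x)
for arr_slice, extra in calls:
    print("slice:", arr_slice, "| extra:", extra)
# Each call sees one row of `weights` and all of `x`;
# totals holds the row sums [10., 100.]
```

This is why the argument order matters: whichever array should be cut into 1D pieces has to be the one `apply_along_axis` treats as its array argument.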
One way to work around this is to write a thin wrapper function around `np.bincount` that simply switches the order of the arguments:
```python
def wrapped_bincount(weights, x):
    # weights come first so that apply_along_axis slices them; x is passed whole
    return np.bincount(x, weights=weights)
```
We can then use `np.apply_along_axis` with `wrapped_bincount` for your use-case:
```python
def apply_bincount_along_axis(x, weights, axis=-1):
    # bincount every 1D slice of `weights` taken along `axis`, reusing the same x
    return np.apply_along_axis(wrapped_bincount, axis, weights, x)
```
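A small worked example of `apply_bincount_along_axis` on made-up toy arrays, as a sketch assuming `x` holds non-negative integers (which `np.bincount` requires). Each row of `weights` yields one row of bin totals:

```python
import numpy as np

def wrapped_bincount(weights, x):
    # weights first so that apply_along_axis slices them; x is passed whole
    return np.bincount(x, weights=weights)

def apply_bincount_along_axis(x, weights, axis=-1):
    # bincount every 1D slice of `weights` taken along `axis`
    return np.apply_along_axis(wrapped_bincount, axis, weights, x)

x = np.array([0, 1, 1, 2])
weights = np.array([[1.0, 2.0, 3.0, 4.0],
                    [10.0, 20.0, 30.0, 40.0]])

out = apply_bincount_along_axis(x, weights)
# out has shape (2, 3), i.e. x.max() + 1 bins per row:
# [[ 1.,  5.,  4.],
#  [10., 50., 40.]]
```

Bin 1 collects positions 1 and 2 of each row (2 + 3 and 20 + 30), which is exactly what a per-row `np.bincount` would give.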
Finally, we can wrap `apply_bincount_along_axis` for use with xarray using `apply_ufunc`, noting that it can be automatically parallelized with dask. Also note that we do not need to provide an `axis` argument, because xarray automatically moves the input core dimension `dim` to the last position in the `weights` array before applying the function:
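```python
import numpy as np
import xarray as xr

def xbincount(x, weights):
    # x: 1D integer DataArray of bin indices; weights: DataArray sharing x's dimension
    if len(x.dims) != 1:
        raise ValueError('x must be one-dimensional')
    dim, = x.dims
    nbins = int(x.max()) + 1  # np.bincount returns max(x) + 1 bins
    return xr.apply_ufunc(apply_bincount_along_axis, x, weights,
                          input_core_dims=[[dim], [dim]],
                          output_core_dims=[['bin']], dask='parallelized',
                          output_dtypes=[float],
                          # with dask, the size of the new 'bin' dimension must be declared
                          dask_gufunc_kwargs={'output_sizes': {'bin': nbins}})
```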
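The claim that xarray moves the core dimension to the last axis can be checked with a small function that records the shape of the raw array it receives. The array and the `sum_last_axis` helper are hypothetical, chosen only so that `idx` starts out as the first dimension:

```python
import numpy as np
import xarray as xr

# 'idx' is deliberately the FIRST dimension here.
w = xr.DataArray(np.arange(6.0).reshape(3, 2), dims=["idx", "time"])

seen_shapes = []

def sum_last_axis(a):
    # Record the shape of the plain numpy array handed over by apply_ufunc.
    seen_shapes.append(a.shape)
    return a.sum(axis=-1)

out = xr.apply_ufunc(sum_last_axis, w, input_core_dims=[["idx"]])
# seen_shapes == [(2, 3)]: 'idx' was moved to the last axis,
# so axis=-1 in the function always refers to the core dimension.
# out.dims == ("time",) and out.values == [6., 9.]
```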
Applying `xbincount` to your example (with `ridx` wrapped in a `DataArray`, since `xbincount` reads `x.dims`) then looks like:
```python
xbincount(xr.DataArray(ridx, dims=['idx']), f)
```
```
<xarray.DataArray (time: 2, bin: 5)>
array([[ 0. , 7.934821, 34.066872, 51.118065, 152.769169],
       [ 0. , 11.692989, 33.262936, 44.993856, 157.642972]])
Dimensions without coordinates: time, bin
```
As desired, it also works with dask arrays:
```python
xbincount(xr.DataArray(ridx, dims=['idx']), f.chunk({'time': 1}))
```
```
<xarray.DataArray (time: 2, bin: 5)>
dask.array<shape=(2, 5), dtype=float64, chunksize=(1, 5)>
Dimensions without coordinates: time, bin
```
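`np.apply_along_axis` still calls `wrapped_bincount` once per slice in a Python loop. If that ever becomes a bottleneck, one possible replacement is an index-offsetting trick, swapped in here and not part of the original answer: shift the bin indices of slice `k` by `k * nbins` so that a single `np.bincount` call fills a separate block of bins for every slice. The name `bincount_last_axis` is hypothetical; the sketch assumes the counted dimension is the last axis of `weights` (which `apply_ufunc` arranges for core dimensions) and that `x` holds non-negative integers.

```python
import numpy as np

def bincount_last_axis(x, weights):
    """Weighted bincount of every 1D slice of `weights` along its last axis.

    Hypothetical vectorized alternative to apply_bincount_along_axis.
    """
    x = np.asarray(x)
    weights = np.asarray(weights, dtype=float)
    nbins = int(x.max()) + 1
    lead_shape = weights.shape[:-1]
    flat = weights.reshape(-1, weights.shape[-1])   # (nslices, n)
    nslices = flat.shape[0]
    # Bin i of slice k is stored at global bin k * nbins + i.
    offsets = nbins * np.arange(nslices)[:, np.newaxis]
    idx = (x[np.newaxis, :] + offsets).ravel()
    counts = np.bincount(idx, weights=flat.ravel(), minlength=nslices * nbins)
    # Restore the leading dimensions and put the bins last.
    return counts.reshape(lead_shape + (nbins,))
```

It could be passed to `xr.apply_ufunc` in place of `apply_bincount_along_axis` inside `xbincount`. Because `nbins` is computed from the full `x`, every dask chunk would produce the same number of bins; the trade-off is a temporary index array as large as `weights`.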
Source: https://stackoverflow.com/questions/55603803/can-i-parallelize-numpy-bincount-using-xarray-apply-ufunc