Question
I want to parallelize numpy.bincount using xarray's apply_ufunc API. This is what I've tried:
import numpy as np
import xarray as xr
da = xr.DataArray(np.random.rand(2, 16, 32),
                  dims=['time', 'y', 'x'],
                  coords={'time': np.array(['2019-04-18', '2019-04-19'],
                                           dtype='datetime64'),
                          'y': np.arange(16), 'x': np.arange(32)})
f = xr.DataArray(da.data.reshape((2,512)),dims=['time','idx'])
x = da.x.values
y = da.y.values
r = np.sqrt(x[np.newaxis,:]**2 + y[:,np.newaxis]**2)
nbins = 4
if x.max() > y.max():
    ri = np.linspace(0., y.max(), nbins)
else:
    ri = np.linspace(0., x.max(), nbins)
ridx = np.digitize(np.ravel(r), ri)
func = lambda a, b: np.bincount(a, weights=b)
xr.apply_ufunc(func, xr.DataArray(ridx,dims=['idx']), f)
but I get the following error:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-203-974a8f0a89e8> in <module>()
12
13 func = lambda a, b: np.bincount(a, weights=b)
---> 14 xr.apply_ufunc(func, xr.DataArray(ridx,dims=['idx']), f)
~/anaconda/envs/uptodate/lib/python3.6/site-packages/xarray/core/computation.py in apply_ufunc(func, *args, **kwargs)
979 signature=signature,
980 join=join,
--> 981 exclude_dims=exclude_dims)
982 elif any(isinstance(a, Variable) for a in args):
983 return variables_ufunc(*args)
~/anaconda/envs/uptodate/lib/python3.6/site-packages/xarray/core/computation.py in apply_dataarray_ufunc(func, *args, **kwargs)
208
209 data_vars = [getattr(a, 'variable', a) for a in args]
--> 210 result_var = func(*data_vars)
211
212 if signature.num_outputs > 1:
~/anaconda/envs/uptodate/lib/python3.6/site-packages/xarray/core/computation.py in apply_variable_ufunc(func, *args, **kwargs)
558 raise ValueError('unknown setting for dask array handling in '
559 'apply_ufunc: {}'.format(dask))
--> 560 result_data = func(*input_data)
561
562 if signature.num_outputs == 1:
<ipython-input-203-974a8f0a89e8> in <lambda>(a, b)
11 ridx = np.digitize(np.ravel(r), ri)
12
---> 13 func = lambda a, b: np.bincount(a, weights=b)
14 xr.apply_ufunc(func, xr.DataArray(ridx,dims=['idx']), f)
ValueError: object too deep for desired array
I'm not sure where the error is coming from, so any help would be greatly appreciated.
Answer 1:
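The error is raised by np.bincount itself: it only accepts 1-D arrays, but because no input_core_dims are given, apply_ufunc passes the whole 2-D (time, idx) weights array f to your function. Applying np.bincount along the idx axis with np.apply_along_axis is the natural fix, but apply_along_axis iterates over 1D slices of the first argument to the applied function and not any of the others. If I understand your use-case correctly, you actually want to iterate over 1D slices of the weights (weights in the np.bincount signature), not the integer array (x in the np.bincount signature).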
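Both points can be checked directly with plain NumPy. The small sketch below uses made-up arrays (idx, weights_2d and the helper record are illustrative names, not from the question): it shows that np.bincount rejects 2-D weights with a ValueError, and that np.apply_along_axis only slices its array argument while extra positional arguments arrive unsliced.

```python
import numpy as np

idx = np.array([0, 1, 1, 3])                # 1-D integer labels (illustrative)
weights_2d = np.arange(8.0).reshape(2, 4)   # 2-D weights, like f's (time, idx)

# np.bincount only accepts 1-D input, so 2-D weights raise ValueError
try:
    np.bincount(idx, weights=weights_2d)
    bincount_failed = False
except ValueError:
    bincount_failed = True

# Record what np.apply_along_axis hands to the applied function
seen = []

def record(row, labels):
    # `row` is a 1-D slice of weights_2d; `labels` is passed through whole
    seen.append((row.shape, labels.shape))
    return row.sum()

np.apply_along_axis(record, -1, weights_2d, idx)
print(bincount_failed, seen)
```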
One way to work around this is to write a thin wrapper function around np.bincount that simply switches the order of the arguments:
def wrapped_bincount(weights, x):
    return np.bincount(x, weights=weights)
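As a quick sanity check of the argument swap, here is a minimal sketch on one made-up row of weights (the values of x and w are illustrative): each weight is added to the bin named by the matching entry of x.

```python
import numpy as np

def wrapped_bincount(weights, x):
    # weights comes first so np.apply_along_axis will slice it; x is passed whole
    return np.bincount(x, weights=weights)

x = np.array([0, 1, 1, 3])
w = np.array([0.5, 1.0, 2.0, 4.0])

# bin 0 gets 0.5, bin 1 gets 1.0 + 2.0, bin 2 is empty, bin 3 gets 4.0
result = wrapped_bincount(w, x)
print(result)
```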
We can then use np.apply_along_axis with this function for your use-case:
def apply_bincount_along_axis(x, weights, axis=-1):
    return np.apply_along_axis(wrapped_bincount, axis, weights, x)
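To confirm that this matches a plain per-row loop, the sketch below applies it to a small made-up 2-D weights array (values are illustrative). It also shows that the axis parameter works when the binned dimension is not last: transposing the input with axis=0 gives the transposed result.

```python
import numpy as np

def wrapped_bincount(weights, x):
    return np.bincount(x, weights=weights)

def apply_bincount_along_axis(x, weights, axis=-1):
    return np.apply_along_axis(wrapped_bincount, axis, weights, x)

x = np.array([0, 1, 1, 3])
weights = np.array([[1.0, 2.0, 3.0, 4.0],
                    [10.0, 20.0, 30.0, 40.0]])

out = apply_bincount_along_axis(x, weights)

# Reference: explicit loop over the rows of `weights`
expected = np.stack([np.bincount(x, weights=row) for row in weights])

# Same computation with the binned dimension first
out_axis0 = apply_bincount_along_axis(x, weights.T, axis=0)
print(out)
```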
Finally, we can wrap this new function for use with xarray via apply_ufunc, which also lets it be parallelized automatically with dask. Note that we do not need to pass an axis argument: xarray automatically moves the input core dimension dim to the last position of the weights array before calling the function.
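The claim about core dimensions being moved to the end can be checked with a small sketch. It uses a hypothetical recording function (record_shape) on a DataArray whose core dimension idx is deliberately placed first: the function receives the array with idx last, and the output keeps only the remaining time dimension.

```python
import numpy as np
import xarray as xr

# Core dimension 'idx' is deliberately the first dimension here
da = xr.DataArray(np.zeros((3, 2)), dims=['idx', 'time'])

seen_shapes = []

def record_shape(arr):
    # record the shape apply_ufunc actually passes in
    seen_shapes.append(arr.shape)
    return arr.sum(axis=-1)  # reduce over the core dimension, now last

out = xr.apply_ufunc(record_shape, da, input_core_dims=[['idx']])
print(seen_shapes, out.dims)
```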
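def xbincount(x, weights):
    # x: 1-D integer DataArray of bin labels; weights: DataArray sharing x's dimension
    if len(x.dims) != 1:
        raise ValueError('x must be one-dimensional')
    dim, = x.dims
    # output sizes must be a plain int, not a 0-d DataArray
    nbins = int(x.max()) + 1
    return xr.apply_ufunc(apply_bincount_along_axis, x, weights,
                          input_core_dims=[[dim], [dim]],
                          output_core_dims=[['bin']], dask='parallelized',
                          output_dtypes=[float],  # np.float was removed in NumPy 1.24
                          # xarray >= 0.16.1 takes output_sizes via dask_gufunc_kwargs
                          dask_gufunc_kwargs={'output_sizes': {'bin': nbins}})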
Applying this function to your example then looks like:
xbincount(xr.DataArray(ridx, dims=['idx']), f)
<xarray.DataArray (time: 2, bin: 5)>
array([[ 0. , 7.934821, 34.066872, 51.118065, 152.769169],
[ 0. , 11.692989, 33.262936, 44.993856, 157.642972]])
Dimensions without coordinates: time, bin
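Because the weights in the question are random, the printed numbers will differ from run to run. A self-contained way to check the wrapper is to compare it with an explicit loop of np.bincount over time steps. The sketch below uses seeded synthetic stand-ins for ridx and f (labels and weights are illustrative names) and assumes xarray >= 0.16.1 for dask_gufunc_kwargs. It also checks that summing over bins recovers the total weight per time step.

```python
import numpy as np
import xarray as xr

def wrapped_bincount(weights, x):
    return np.bincount(x, weights=weights)

def apply_bincount_along_axis(x, weights, axis=-1):
    return np.apply_along_axis(wrapped_bincount, axis, weights, x)

def xbincount(x, weights):
    if len(x.dims) != 1:
        raise ValueError('x must be one-dimensional')
    dim, = x.dims
    nbins = int(x.max()) + 1
    return xr.apply_ufunc(apply_bincount_along_axis, x, weights,
                          input_core_dims=[[dim], [dim]],
                          output_core_dims=[['bin']], dask='parallelized',
                          output_dtypes=[float],
                          dask_gufunc_kwargs={'output_sizes': {'bin': nbins}})

# Seeded synthetic stand-ins for ridx and f from the question
rng = np.random.default_rng(0)
labels = xr.DataArray(rng.integers(0, 5, size=512), dims=['idx'])
weights = xr.DataArray(rng.random((2, 512)), dims=['time', 'idx'])

result = xbincount(labels, weights)

# Reference: explicit loop over time steps
expected = np.stack([np.bincount(labels.values, weights=row)
                     for row in weights.values])
print(result.dims, np.allclose(result.values, expected))
```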
As desired it also works with dask arrays:
xbincount(xr.DataArray(ridx, dims=['idx']), f.chunk({'time': 1}))
<xarray.DataArray (time: 2, bin: 5)>
dask.array<shape=(2, 5), dtype=float64, chunksize=(1, 5)>
Dimensions without coordinates: time, bin
Source: https://stackoverflow.com/questions/55603803/can-i-parallelize-numpy-bincount-using-xarray-apply-ufunc