Can I parallelize `numpy.bincount` using `xarray.apply_ufunc`?

Submitted by ╄→尐↘猪︶ㄣ on 2019-12-23 04:00:14

Question


I want to parallelize the numpy.bincount function using xarray's apply_ufunc API. The following code is what I've tried:

import numpy as np
import xarray as xr
da = xr.DataArray(np.random.rand(2,16,32),
                  dims=['time', 'y', 'x'],
                  coords={'time': np.array(['2019-04-18', '2019-04-19'],
                                          dtype='datetime64'), 
                         'y': np.arange(16), 'x': np.arange(32)})

f = xr.DataArray(da.data.reshape((2,512)),dims=['time','idx'])
x = da.x.values
y = da.y.values
r = np.sqrt(x[np.newaxis,:]**2 + y[:,np.newaxis]**2)
nbins = 4
if x.max() > y.max():
    ri = np.linspace(0., y.max(), nbins)
else:
    ri = np.linspace(0., x.max(), nbins)

ridx = np.digitize(np.ravel(r), ri)

func = lambda a, b: np.bincount(a, weights=b)
xr.apply_ufunc(func, xr.DataArray(ridx,dims=['idx']), f)

but I get the following error:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-203-974a8f0a89e8> in <module>()
     12 
     13 func = lambda a, b: np.bincount(a, weights=b)
---> 14 xr.apply_ufunc(func, xr.DataArray(ridx,dims=['idx']), f)

~/anaconda/envs/uptodate/lib/python3.6/site-packages/xarray/core/computation.py in apply_ufunc(func, *args, **kwargs)
    979                                      signature=signature,
    980                                      join=join,
--> 981                                      exclude_dims=exclude_dims)
    982     elif any(isinstance(a, Variable) for a in args):
    983         return variables_ufunc(*args)

~/anaconda/envs/uptodate/lib/python3.6/site-packages/xarray/core/computation.py in apply_dataarray_ufunc(func, *args, **kwargs)
    208 
    209     data_vars = [getattr(a, 'variable', a) for a in args]
--> 210     result_var = func(*data_vars)
    211 
    212     if signature.num_outputs > 1:

~/anaconda/envs/uptodate/lib/python3.6/site-packages/xarray/core/computation.py in apply_variable_ufunc(func, *args, **kwargs)
    558             raise ValueError('unknown setting for dask array handling in '
    559                              'apply_ufunc: {}'.format(dask))
--> 560     result_data = func(*input_data)
    561 
    562     if signature.num_outputs == 1:

<ipython-input-203-974a8f0a89e8> in <lambda>(a, b)
     11 ridx = np.digitize(np.ravel(r), ri)
     12 
---> 13 func = lambda a, b: np.bincount(a, weights=b)
     14 xr.apply_ufunc(func, xr.DataArray(ridx,dims=['idx']), f)

ValueError: object too deep for desired array

I am kind of lost as to where the error is coming from, and any help would be greatly appreciated.


Answer 1:


The error arises because np.bincount accepts only one-dimensional arrays, whereas the weights passed here (f) are two-dimensional. np.apply_along_axis looks like the natural fix, but it iterates over 1D slices of the first argument to the applied function only, not any of the others. If I understand your use case correctly, you actually want to iterate over 1D slices of the weights (weights in the np.bincount signature), not the integer array (x in the np.bincount signature).
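A minimal sketch that reproduces the error in isolation, using made-up toy arrays idx and w (they are not from the question): a 2-D weights array is rejected by np.bincount, while one call per 1-D row of weights succeeds.

```python
import numpy as np

# Hypothetical toy data: 4 bin indices shared by 2 rows of weights.
idx = np.array([0, 1, 1, 2])
w = np.array([[1.0, 2.0, 3.0, 4.0],
              [10.0, 20.0, 30.0, 40.0]])

# np.bincount rejects a 2-D weights array.
try:
    np.bincount(idx, weights=w)
    failed = False
except ValueError as err:
    failed = True
    print("ValueError:", err)

# Calling it once per 1-D row of weights works.
per_row = np.stack([np.bincount(idx, weights=row) for row in w])
print(per_row)
```

The explicit loop over rows is what the rest of this answer replaces with np.apply_along_axis and xarray.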

One way to work around this is to write a thin wrapper function around np.bincount that simply switches the order of the arguments:

def wrapped_bincount(weights, x):
    return np.bincount(x, weights=weights)

We can then use np.apply_along_axis with this function for your use-case:

def apply_bincount_along_axis(x, weights, axis=-1):
    return np.apply_along_axis(wrapped_bincount, axis, weights, x)
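As a quick check in plain NumPy, the two helpers can be tried on a small made-up array (idx and w below are illustrative only). The sketch assumes the bin indices are shared by every 1D slice of the weights, as in the question.

```python
import numpy as np

def wrapped_bincount(weights, x):
    # Same as np.bincount, but with the weights as the first argument.
    return np.bincount(x, weights=weights)

def apply_bincount_along_axis(x, weights, axis=-1):
    # np.apply_along_axis slices `weights` along `axis`;
    # `x` is passed unchanged to every call.
    return np.apply_along_axis(wrapped_bincount, axis, weights, x)

# Hypothetical toy data: 4 bin indices and a (2, 3, 4) stack of weights.
idx = np.array([0, 1, 1, 2])
w = np.arange(24, dtype=float).reshape(2, 3, 4)

out = apply_bincount_along_axis(idx, w)
print(out.shape)  # the last axis now holds one value per bin
print(out[0, 0])  # weighted counts for the slice w[0, 0] = [0, 1, 2, 3]
```

The leading axes of the weights are preserved, and only the binned axis changes length.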

Finally, we can wrap this new function for use with xarray using apply_ufunc, noting that it can be automatically parallelized with dask. We do not need to provide an axis argument, because xarray automatically moves the input core dimension dim to the last position in the weights array before applying the function:

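def xbincount(x, weights):
    if len(x.dims) != 1:
        raise ValueError('x must be one-dimensional')

    dim, = x.dims
    # The size of the new 'bin' dimension must be a plain integer.
    nbins = int(x.max()) + 1

    # dask needs the size of the new 'bin' dimension up front. Recent xarray
    # versions take it via dask_gufunc_kwargs; older versions accepted
    # output_sizes={'bin': nbins} as a direct keyword instead.
    return xr.apply_ufunc(apply_bincount_along_axis, x, weights,
        input_core_dims=[[dim], [dim]],
        output_core_dims=[['bin']], dask='parallelized',
        output_dtypes=[np.float64],
        dask_gufunc_kwargs={'output_sizes': {'bin': nbins}})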

Applying xbincount to your example (with ridx wrapped in a DataArray so that it carries a dimension name) then looks like:

xbincount(xr.DataArray(ridx, dims=['idx']), f)

<xarray.DataArray (time: 2, bin: 5)>
array([[  0.      ,   7.934821,  34.066872,  51.118065, 152.769169],
       [  0.      ,  11.692989,  33.262936,  44.993856, 157.642972]])
Dimensions without coordinates: time, bin
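As a sanity check, the xarray result can be compared against an explicit loop over time. The sketch below is self-contained and uses made-up random inputs (x and weights are illustrative, not the question's data). It assumes a recent xarray that accepts dask_gufunc_kwargs.

```python
import numpy as np
import xarray as xr

def apply_bincount_along_axis(x, weights, axis=-1):
    # np.apply_along_axis slices `weights`; `x` is passed through unchanged.
    return np.apply_along_axis(
        lambda w, idx: np.bincount(idx, weights=w), axis, weights, x)

def xbincount(x, weights):
    dim, = x.dims
    nbins = int(x.max()) + 1
    return xr.apply_ufunc(
        apply_bincount_along_axis, x, weights,
        input_core_dims=[[dim], [dim]],
        output_core_dims=[['bin']],
        dask='parallelized',
        output_dtypes=[np.float64],
        # Assumption: recent xarray; older releases took output_sizes directly.
        dask_gufunc_kwargs={'output_sizes': {'bin': nbins}})

# Hypothetical toy inputs: 512 bin indices in [0, 5) and 2 rows of weights.
rng = np.random.default_rng(0)
x = xr.DataArray(rng.integers(0, 5, size=512), dims=['idx'])
weights = xr.DataArray(rng.random((2, 512)), dims=['time', 'idx'])

result = xbincount(x, weights)

# Reference: one np.bincount call per time step.
expected = np.stack([np.bincount(x.values, weights=row)
                     for row in weights.values])

print(result.dims)
print(np.allclose(result.values, expected))
```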

As desired, xbincount also works with dask arrays:

xbincount(xr.DataArray(ridx, dims=['idx']), f.chunk({'time': 1}))

<xarray.DataArray (time: 2, bin: 5)>
dask.array<shape=(2, 5), dtype=float64, chunksize=(1, 5)>
Dimensions without coordinates: time, bin
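The earlier remark that xarray moves the input core dimension to the last axis can be checked separately. In this small sketch, the record_shape helper and the array w are made up for illustration: the core dimension 'idx' is deliberately placed first, and the function records the shape of the raw array it receives.

```python
import numpy as np
import xarray as xr

seen_shapes = []

def record_shape(arr):
    # Record the shape of the raw array that xarray hands to the function,
    # then reduce over the last axis (where the core dimension ends up).
    seen_shapes.append(arr.shape)
    return arr.sum(axis=-1)

# Hypothetical array whose core dimension 'idx' is deliberately first.
w = xr.DataArray(np.ones((4, 2)), dims=['idx', 'time'])

out = xr.apply_ufunc(record_shape, w, input_core_dims=[['idx']])

print(seen_shapes)  # 'idx' (length 4) arrives as the last axis
print(out.dims)
```

This is why the wrapped function can always work along axis=-1, regardless of where the dimension sits in the input.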


Source: https://stackoverflow.com/questions/55603803/can-i-parallelize-numpy-bincount-using-xarray-apply-ufunc
