Grouping by multiple dimensions

后端 未结 2 550
花落未央
花落未央 2021-01-12 20:01

Grouping by a single dimension works fine for xarray DataArrays:

d = xr.DataArray([1, 2, 3], coords={'a': ['x', 'x', 'y']}, dims=['a'])
d.groupby(\         


        
相关标签:
2条回答
  • 2021-01-12 20:26

    I don't know how it would compare speed-wise, and I don't have enough time to put together a full solution for this particular question, but I found this question and its answers helpful when I was searching for ways to iterate over multiple dimensions in xarray, and I want to share the approach I ended up taking. I ultimately used dimension stacking, based on this example code by @RyanAbernathey:

    import xarray as xr
    import numpy as np
    
    # example data: random values for 10 time steps on a 30 x 40 (x, y) grid
    da = xr.DataArray(
        np.random.rand(10, 30, 40),
        dims=['dtime', 'x', 'y'],
    )
    
    # define a function to compute a linear trend of a timeseries
    def linear_trend(x, dim='dtime'):
        """
        Return the slope of a degree-1 least-squares fit of `x` against its `dim` values.

        :param x: one timeseries (the group handed over by xarray's groupby)
        :param dim: name of the time dimension to fit against; defaults to 'dtime',
            the time dimension of the example array `da`
        :return: the fitted slope wrapped in a DataArray
        """
        pf = np.polyfit(x[dim], x, 1)
        # we need to return a dataarray or else xarray's groupby won't be happy
        return xr.DataArray(pf[0])
    
    # stack x and y into a single dimension called allpoints
    stacked = da.stack(allpoints=['x', 'y'])
    # apply the function to each point's timeseries to calculate the trend at each point
    # (GroupBy.map is the current name for the older GroupBy.apply)
    trend = stacked.groupby('allpoints').map(linear_trend)
    # unstack back to x/y coordinates
    trend_unstacked = trend.unstack('allpoints')
    

    in combination with a small wrapper that nests two groupbys (one over all points, then one over time):

    def _calc_allpoints(ds, function, time_dim='dtime', point_dims=('x', 'y')):
        """
        Helper function to do a pixel-wise calculation that requires using the point dimension
        values as inputs. This version does the computation over all available timesteps as well.

        :param ds: xarray object that has the dimensions in `point_dims` and `time_dim`
        :param function: callable applied to every (point, time step) group via GroupBy.map
        :param time_dim: name of the time dimension to group by within each point (default 'dtime')
        :param point_dims: dimensions stacked into the single 'allpoints' dimension
            (default ('x', 'y'))
        :return: the combined result of `function`, unstacked back onto `point_dims`
        """

        def _time_wrapper(gb):
            # within one point, apply `function` separately to every time step
            return gb.groupby(time_dim, squeeze=False).map(function)

        # stack the point dimensions into a single dimension called allpoints
        stacked = ds.stack(allpoints=list(point_dims))
        # group by point, then by time (inside _time_wrapper), applying the function to each group
        newelev = stacked.groupby('allpoints', squeeze=False).map(_time_wrapper)
        # unstack back to the original point coordinates
        return newelev.unstack('allpoints')
    

    where `function` is the function to apply to each group (e.g. `linear_trend`).

    0 讨论(0)
  • 2021-01-12 20:27

    I built a manual solution. To make it efficient, I bypass xarray and rebuild the indices and values by hand. Any change that used more of xarray (e.g. using `sel`, or re-packaging cells into a DataArray; see also https://github.com/pydata/xarray/issues/2452) led to serious losses in speed.

    import itertools
    from collections import defaultdict
    
    import numpy as np
    import xarray as xr
    from xarray import DataArray
    
    class DataAssembly(DataArray):
        """DataArray subclass adding a fast, hand-rolled multi-dimensional groupby."""

        def multi_dim_groupby(self, groups, apply):
            """
            Group by several coordinates at once and apply `apply` to every group.

            Each name in `groups` must be a coordinate living on a single dimension, with one group
            per dimension. For every combination of unique group values, the matching cells of
            `self.values` are collected and passed as `apply(cells, **coords)`. `cells` is a scalar
            when the group has a single cell, otherwise a flat ndarray. `coords` maps each
            coordinate name to its value within the group: a scalar if it is single-valued there,
            otherwise an ndarray of its unique values.

            :param groups: iterable of coordinate names to group by
            :param apply: callable `apply(cells, **coords)` returning one number per group
            :return: a new `type(self)` holding one value per group combination. Coordinates that
                are not single-valued within every group, or that do not live on exactly one
                grouped dimension, are dropped.
            """
            # align: order the groups like self.dims so the result axes follow self's axis order
            groups = sorted(groups, key=lambda group: self.dims.index(self[group].dims[0]))
            # build indices
            groups = {group: np.unique(self[group].values) for group in groups}
            group_dims = {self[group].dims: group for group in groups}
            # indices[group][value]: positions along the group's dimension that hold `value`
            indices = defaultdict(lambda: defaultdict(list))
            # result_indices[group][value]: position of `value` on the result axis (first-seen order)
            result_indices = defaultdict(dict)
            for group in groups:
                value_positions = result_indices[group]
                for index, value in enumerate(self[group].values):
                    indices[group][value].append(index)
                    if value not in value_positions:  # if captured once, it will be "grouped away"
                        value_positions[value] = len(value_positions)

            # only coords on exactly one grouped dimension can be mapped onto the result
            coords = {coord: (dims, value) for coord, dims, value in walk_coords(self)
                      if dims in group_dims}

            def simplify(value):
                return value.item() if value.size == 1 else value

            # group and apply
            # making this a DataArray right away and then inserting through .loc would slow things down
            result = np.zeros([len(positions) for positions in result_indices.values()])
            result_coords = {coord: (dims, [None] * len(result_indices[group_dims[dims]]))
                             for coord, (dims, value) in coords.items()}
            dropped_coords = set()
            for values in itertools.product(*groups.values()):
                group_values = dict(zip(groups.keys(), values))
                self_indices = {group: indices[group][value] for group, value in group_values.items()}
                # np.ix_ selects the outer product of the per-dimension positions, i.e. the whole
                # group (a list of index tuples is a single fancy index in NumPy >= 1.23).
                # Using DataArray would slow things down, thus we pass coords as kwargs.
                cells = simplify(self.values[np.ix_(*self_indices.values())].ravel())
                cell_coords = {coord: (dims, simplify(np.unique(value[self_indices[group_dims[dims]]])))
                               for coord, (dims, value) in coords.items()}

                # ignore dims when passing to function
                passed_coords = {coord: value for coord, (dims, value) in cell_coords.items()}
                merge = apply(cells, **passed_coords)
                result_idx = {group: result_indices[group][value] for group, value in group_values.items()}
                # a tuple addresses exactly one element, one position per axis
                result[tuple(result_idx.values())] = merge
                for coord, (dims, value) in cell_coords.items():
                    if coord in dropped_coords:
                        continue
                    if isinstance(value, np.ndarray):  # multiple values for coord -> drop it for good
                        dropped_coords.add(coord)
                        del result_coords[coord]
                        continue
                    result_coords[coord][1][result_idx[group_dims[dims]]] = value

            # re-package
            result_dims = list(itertools.chain(*group_dims.keys()))
            return type(self)(result, coords=result_coords, dims=result_dims)
    
    def walk_coords(assembly):
        """
        Walk through coords and all MultiIndex levels, just like the `__repr__` function,
        yielding `(name, dims, values)`.

        A dimension coordinate backed by a MultiIndex is replaced by its levels. Each name is
        yielded at most once. NOTE(review): newer xarray versions presumably also list MultiIndex
        levels as coordinates of their own, which would otherwise be yielded twice; confirm per
        xarray version.
        """
        seen = set()
        for name, values in assembly.coords.items():
            # partly borrowed from xarray.core.formatting#summarize_coord
            is_index = name in assembly.dims
            if is_index and values.variable.level_names:
                entries = [(level, assembly.coords[level]) for level in values.variable.level_names]
            else:
                entries = [(name, values)]
            for entry_name, entry in entries:
                if entry_name in seen:
                    continue
                seen.add(entry_name)
                yield entry_name, entry.dims, entry.values
    

    The method `multi_dim_groupby` performs the grouping and the apply in one step. The passed `apply` function receives the group's cells as its first argument and can accept the group's coords as keyword parameters named after the coords (or ignore them by putting `**_` in the function signature).

    It's not particularly pretty and does not cover all possible cases, but it does cover at least the following test cases:

    import DataAssembly
    
    class TestMultiDimGroupby:
        """Behavior checks for DataAssembly.multi_dim_groupby on a fixed 4 x 3 grid."""

        @staticmethod
        def _grid(coords, dims):
            """Build the shared 4 x 3 grid (values 1..12, row-major) with the given coords and dims."""
            return DataAssembly([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]],
                                coords=coords, dims=dims)

        def test_unique_values(self):
            original = self._grid({'a': ['a', 'b', 'c', 'd'], 'b': ['x', 'y', 'z']}, ['a', 'b'])
            grouped = original.multi_dim_groupby(['a', 'b'], lambda x, **_: x)
            assert grouped.equals(original)

        def test_nonunique_singledim(self):
            original = self._grid({'a': ['a', 'a', 'b', 'b'], 'b': ['x', 'y', 'z']}, ['a', 'b'])
            grouped = original.multi_dim_groupby(['a', 'b'], lambda x, **_: x.mean())
            expected = DataAssembly([[2.5, 3.5, 4.5], [8.5, 9.5, 10.5]],
                                    coords={'a': ['a', 'b'], 'b': ['x', 'y', 'z']},
                                    dims=['a', 'b'])
            assert grouped.equals(expected)

        def test_nonunique_adjacentcoord(self):
            original = self._grid({'a': ('adim', ['a', 'a', 'b', 'b']),
                                   'aa': ('adim', ['a', 'b', 'a', 'b']),
                                   'b': ['x', 'y', 'z']},
                                  ['adim', 'b'])
            grouped = original.multi_dim_groupby(['a', 'b'], lambda x, **_: x.mean())
            expected = DataAssembly([[2.5, 3.5, 4.5], [8.5, 9.5, 10.5]],
                                    coords={'adim': ['a', 'b'], 'b': ['x', 'y', 'z']},
                                    dims=['adim', 'b'])
            assert grouped.equals(expected), \
                "adjacent coord aa should be discarded due to non-mappability"

        def test_unique_values_swappeddims(self):
            original = self._grid({'a': ['a', 'b', 'c', 'd'], 'b': ['x', 'y', 'z']}, ['a', 'b'])
            grouped = original.multi_dim_groupby(['b', 'a'], lambda x, **_: x)
            assert grouped.equals(original)
    
    0 讨论(0)
提交回复
热议问题