Grouping by multiple dimensions

后端 未结 2 548
花落未央
花落未央 2021-01-12 20:01

Grouping by a single dimension works fine for xarray DataArrays:

d = xr.DataArray([1, 2, 3], coords={'a': ['x', 'x', 'y']}, dims=['a'])
d.groupby('a')
[The rest of this question was truncated in this copy of the page.]


        
2条回答
  •  孤街浪徒
    2021-01-12 20:27

    I built a manual solution. To make it efficient, I bypass xarray entirely and rebuild the indices and values by hand. Every change that used more of xarray (e.g. using `sel`, or re-packaging the cells into a DataArray; see also https://github.com/pydata/xarray/issues/2452) led to serious losses in speed.

    import itertools
    from collections import defaultdict
    
    import numpy as np
    import xarray as xr
    from xarray import DataArray
    
    class DataAssembly(DataArray):
        """DataArray subclass adding a hand-rolled, fast multi-dimensional groupby."""

        def multi_dim_groupby(self, groups, apply):
            """Group by several coordinates at once and reduce every group with ``apply``.

            Each name in ``groups`` must be a 1-D coordinate. Groups are ordered by the
            position of their dimension in ``self.dims``. For every combination of unique
            group values, the cross-product block of cells at those positions (one axis
            per group) is passed to ``apply``, together with every coordinate's value(s)
            over that block as keyword arguments.

            :param groups: names of 1-D coordinates to group by, at most one per dimension.
                NOTE(review): cells are selected along the leading axes in group order, so
                this presumably expects the groups to cover every dimension of ``self``.
            :param apply: ``apply(cells, **coords) -> scalar``. ``cells`` is a plain scalar
                when the block holds a single element, otherwise an ndarray. A coordinate
                that takes several distinct values within the block is passed as an array.
            :return: an instance of ``type(self)`` with one axis per group, holding the
                (float) results of ``apply``. Coordinates that are not constant within
                every block are dropped from the result.
            """
            # align groups with the order of the dimensions they live on
            groups = sorted(groups, key=lambda group: self.dims.index(self[group].dims[0]))
            groups = {group: np.unique(self[group]) for group in groups}
            # NOTE(review): two groups on the same dimension would collide here
            group_dims = {self[group].dims: group for group in groups}

            # indices[group][value]        -> positions of `value` along the group's dimension
            # result_indices[group][value] -> position of `value` along the result axis,
            #                                 assigned in order of first appearance
            indices = defaultdict(lambda: defaultdict(list))
            result_indices = defaultdict(dict)
            for group in groups:
                for position, value in enumerate(self[group].values):
                    indices[group][value].append(position)
                    # len() is the next free slot; setdefault keeps the first assignment
                    result_indices[group].setdefault(value, len(result_indices[group]))

            coords = {coord: (dims, value) for coord, dims, value in walk_coords(self)}

            def simplify(value):
                # unwrap single-element arrays into plain Python scalars
                return value.item() if value.size == 1 else value

            # group and apply
            # making this a DataArray right away and then inserting through .loc would slow things down
            result = np.zeros([len(group_indices) for group_indices in result_indices.values()])
            result_coords = {coord: (dims, [None] * len(result_indices[group_dims[dims]]))
                             for coord, (dims, value) in coords.items()}
            dropped_coords = set()
            for values in itertools.product(*groups.values()):
                group_values = dict(zip(groups.keys(), values))
                self_indices = {group: indices[group][value] for group, value in group_values.items()}
                # np.ix_ selects the full cross-product block. A list of index tuples is
                # rejected by NumPy >= 1.23, and would otherwise pair indices elementwise.
                # Using DataArray here would slow things down, thus coords are passed as kwargs.
                cells = simplify(self.values[np.ix_(*self_indices.values())])
                cell_coords = {coord: (dims, simplify(np.unique(value[self_indices[group_dims[dims]]])))
                               for coord, (dims, value) in coords.items()}

                # ignore dims when passing to function
                passed_coords = {coord: value for coord, (dims, value) in cell_coords.items()}
                merge = apply(cells, **passed_coords)
                result_idx = {group: result_indices[group][value] for group, value in group_values.items()}
                # scalar tuple index: writes exactly one element of `result`
                result[tuple(result_idx.values())] = merge
                for coord, (dims, value) in cell_coords.items():
                    if coord in dropped_coords:
                        continue
                    if isinstance(value, np.ndarray):  # multiple values for coord -> drop it for good
                        dropped_coords.add(coord)
                        del result_coords[coord]
                        continue
                    assert dims == result_coords[coord][0]
                    result_coords[coord][1][result_idx[group_dims[dims]]] = value

            # re-package
            return type(self)(result, coords=result_coords, dims=list(itertools.chain(*group_dims.keys())))
    
    def walk_coords(assembly):
        """
        Walk through coords and all index levels, like xarray's ``__repr__``.

        Yields ``(name, dims, values)`` for every coordinate of ``assembly``. A dimension
        coordinate whose variable reports ``level_names`` is not yielded itself. Instead,
        one entry per level is yielded, each looked up in ``assembly.coords``.
        """
        for name, values in assembly.coords.items():
            # partly borrowed from xarray.core.formatting#summarize_coord
            is_index = name in assembly.dims
            if is_index and values.variable.level_names:
                for level in values.variable.level_names:
                    level_values = assembly.coords[level]
                    yield level, level_values.dims, level_values.values
            else:
                yield name, values.dims, values.values
    

    The method `multi_dim_groupby` performs grouping and applying in one step. The passed `apply` function receives each coord's value(s) within the group as keyword arguments named after the coords (to ignore them, add `**_` to the function signature).

    It's not particularly pretty and does not cover all possible cases, but it does cover at least the following test cases:

    import DataAssembly
    
    class TestMultiDimGroupby:
        """Tests for ``DataAssembly.multi_dim_groupby`` on a fixed 4x3 grid holding 1..12."""

        @staticmethod
        def _grid(coords, dims):
            # every test starts from the same values; only coords and dims vary
            rows = [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]
            return DataAssembly(rows, coords=coords, dims=dims)

        def test_unique_values(self):
            original = self._grid({'a': ['a', 'b', 'c', 'd'], 'b': ['x', 'y', 'z']}, ['a', 'b'])
            grouped = original.multi_dim_groupby(['a', 'b'], lambda x, **_: x)
            # all group values are unique, so an identity apply must reproduce the input
            assert grouped.equals(original)

        def test_nonunique_singledim(self):
            source = self._grid({'a': ['a', 'a', 'b', 'b'], 'b': ['x', 'y', 'z']}, ['a', 'b'])
            grouped = source.multi_dim_groupby(['a', 'b'], lambda x, **_: x.mean())
            expected = DataAssembly([[2.5, 3.5, 4.5], [8.5, 9.5, 10.5]],
                                    coords={'a': ['a', 'b'], 'b': ['x', 'y', 'z']},
                                    dims=['a', 'b'])
            assert grouped.equals(expected)

        def test_nonunique_adjacentcoord(self):
            source = self._grid({'a': ('adim', ['a', 'a', 'b', 'b']),
                                 'aa': ('adim', ['a', 'b', 'a', 'b']),
                                 'b': ['x', 'y', 'z']},
                                ['adim', 'b'])
            grouped = source.multi_dim_groupby(['a', 'b'], lambda x, **_: x.mean())
            expected = DataAssembly([[2.5, 3.5, 4.5], [8.5, 9.5, 10.5]],
                                    coords={'adim': ['a', 'b'], 'b': ['x', 'y', 'z']},
                                    dims=['adim', 'b'])
            assert grouped.equals(expected), \
                "adjacent coord aa should be discarded due to non-mappability"

        def test_unique_values_swappeddims(self):
            original = self._grid({'a': ['a', 'b', 'c', 'd'], 'b': ['x', 'y', 'z']}, ['a', 'b'])
            # listing the groups in reverse dimension order must not change the result
            grouped = original.multi_dim_groupby(['b', 'a'], lambda x, **_: x)
            assert grouped.equals(original)
    

提交回复
热议问题