group argmax/argmin over partitioning indices in numpy

小蘑菇 2021-01-14 02:30

NumPy's ufuncs have a reduceat method which runs them over contiguous partitions within an array. So instead of writing:

import numpy as np
a =         


        
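The question's example is cut off above. As a rough illustration of what it describes, here is a minimal sketch that uses a made-up array (not the question's data) and split points [4, 5], which give the partitions a[0:4], a[4:5] and a[5:10]. It compares a loop over np.split with a single np.maximum.reduceat call.

```python
import numpy as np

# Illustrative data only: the question's own array is cut off above.
a = np.array([1, 7, 3, 2, 5, 9, 4, 9, 0, 6])
split_at = [4, 5]  # partitions a[0:4], a[4:5], a[5:10]

# Loop version: split into contiguous pieces and reduce each one.
loop_maxima = np.array([piece.max() for piece in np.split(a, split_at)])

# reduceat version: one call, given the start offset of every partition.
starts = np.hstack([0, split_at])
maxima = np.maximum.reduceat(a, starts)

print(maxima)  # [7 5 9]
```

One caveat from the reduceat documentation: if an offset is not smaller than the next one (an empty partition), that slot gets a[offset] instead of an empty reduction, so the split points here are assumed to be strictly increasing and inside (0, len(a)).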
2 Answers
  • 2021-01-14 02:54

    This solution builds an index giving the group number of every element ([0, 0, 0, 0, 1, 2, 2, 2, 2, 2] in the example above).

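    # Length of each contiguous group, from the start offsets plus len(a)
    group_lengths = np.diff(np.hstack([0, split_at, len(a)]))
    n_groups = len(group_lengths)
    # Group number of every element, e.g. [0, 0, 0, 0, 1, 2, 2, 2, 2, 2]
    index = np.repeat(np.arange(n_groups), group_lengths)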
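The index construction above can be wrapped in a small helper for checking. The names group_index and group_index_cumsum are hypothetical; the second function is an alternative formulation (not part of the answer) that marks each split point and takes a cumulative sum, assuming distinct split points strictly inside (0, n).

```python
import numpy as np


def group_index(n, split_at):
    """Group number of each of n positions, following the answer's construction."""
    # Lengths of the groups delimited by 0, the split points, and n.
    group_lengths = np.diff(np.hstack([0, split_at, n]))
    # Repeat each group number as many times as its group is long.
    return np.repeat(np.arange(len(group_lengths)), group_lengths)


def group_index_cumsum(n, split_at):
    """Alternative sketch: mark each split point, then count marks seen so far."""
    marks = np.zeros(n, dtype=np.intp)
    marks[split_at] = 1  # assumes distinct split points in (0, n)
    return np.cumsum(marks)


index = group_index(10, [4, 5])
print(index)  # [0 0 0 0 1 2 2 2 2 2]
```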
    

    Then we can use:

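    # Maximum of each group
    maxima = np.maximum.reduceat(a, np.hstack([0, split_at]))
    # Every position holding its group's maximum (ties give several per group)
    all_argmax = np.flatnonzero(np.repeat(maxima, group_lengths) == a)
    result = np.empty(len(group_lengths), dtype='i')
    # Assign in reverse so each group's first argmax is written last
    result[index[all_argmax[::-1]]] = all_argmax[::-1]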
    

    This gives [3, 4, 5] in result. The [::-1] reversals make each group's first argmax, rather than its last, the value that ends up in result.
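Put together as one function, the answer's steps look roughly like the sketch below. The name reduceat_argmax, the illustrative data and the brute-force comparison are assumptions added for checking. The sketch also assumes non-empty groups and no NaN values (NaN never compares equal to the group maximum, so such a group would get no candidate and its slot would stay uninitialized).

```python
import numpy as np


def reduceat_argmax(a, split_at):
    """First argmax of each contiguous partition of a (sketch of the answer's steps)."""
    a = np.asarray(a)
    starts = np.hstack([0, split_at])
    group_lengths = np.diff(np.hstack([starts, len(a)]))
    # Group number of every element, e.g. [0, 0, 0, 0, 1, 2, 2, 2, 2, 2]
    index = np.repeat(np.arange(len(group_lengths)), group_lengths)

    maxima = np.maximum.reduceat(a, starts)
    # Every position that holds its own group's maximum (ties included)
    all_argmax = np.flatnonzero(np.repeat(maxima, group_lengths) == a)
    result = np.empty(len(group_lengths), dtype=np.intp)
    # Reverse so each group's first argmax is written last. This relies on
    # "last write wins" for repeated indices, which NumPy does not guarantee.
    result[index[all_argmax[::-1]]] = all_argmax[::-1]
    return result


a = [1, 7, 3, 2, 5, 9, 4, 9, 0, 6]  # illustrative data, not the question's
split_at = [4, 5]

# Brute-force reference: np.argmax returns the first maximum of each piece.
offsets = [0] + split_at
pieces = np.split(np.asarray(a), split_at)
reference = [off + int(np.argmax(p)) for off, p in zip(offsets, pieces)]

print(reduceat_argmax(a, split_at))  # [1 4 5]
print(reference)                     # [1, 4, 5]
```

The last group contains the maximum 9 at positions 5 and 7; both methods report 5, the first one.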

    This relies on the last index in a fancy assignment determining the value assigned, which @seberg says one shouldn't rely on. A safer alternative is result = all_argmax[np.unique(index[all_argmax], return_index=True)[1]], which sorts len(all_argmax) elements (about n_groups when ties are rare).
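A sketch of the safer np.unique variant, under the same assumptions as before (hypothetical function name, illustrative data, non-empty groups, no NaN). Because np.unique(..., return_index=True) reports the first occurrence of each value and all_argmax is sorted, it picks each group's first argmax without depending on assignment order.

```python
import numpy as np


def reduceat_argmax_unique(a, split_at):
    """First argmax per contiguous partition, without relying on assignment order."""
    a = np.asarray(a)
    starts = np.hstack([0, split_at])
    group_lengths = np.diff(np.hstack([starts, len(a)]))
    index = np.repeat(np.arange(len(group_lengths)), group_lengths)

    maxima = np.maximum.reduceat(a, starts)
    all_argmax = np.flatnonzero(np.repeat(maxima, group_lengths) == a)
    # all_argmax is sorted, so the first occurrence of each group number in
    # index[all_argmax] points at that group's first argmax.
    _, first = np.unique(index[all_argmax], return_index=True)
    return all_argmax[first]


print(reduceat_argmax_unique([1, 7, 3, 2, 5, 9, 4, 9, 0, 6], [4, 5]))  # [1 4 5]
```

Since index[all_argmax] is already sorted, a linear-time mask such as g = index[all_argmax]; all_argmax[np.r_[True, g[1:] != g[:-1]]] should select the same positions without a sort; that shortcut is an aside, not part of the answer.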

  • 2021-01-14 02:59

    Inspired by this question, I've added argmin/argmax functionality to the numpy_indexed package. Here is what the corresponding test looks like. Note that the keys may be in any order (and of any kind supported by numpy_indexed):

    import numpy.testing as npt
    from numpy_indexed import group_by

    def test_argmin():
        keys   = [2, 0, 0, 1, 1, 2, 2, 2, 2, 2]
        values = [4, 5, 6, 8, 0, 9, 8, 5, 4, 9]
        unique, amin = group_by(keys).argmin(values)
        npt.assert_equal(unique, [0, 1, 2])
        npt.assert_equal(amin,   [1, 4, 0])
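If numpy_indexed is not available, the same grouped argmin can be approximated with plain NumPy. The sketch below is a hypothetical helper (group_argmin), not numpy_indexed's implementation: it sorts indirectly by key, then value, then original position with np.lexsort, and takes the first element of each run of equal keys. With the keys and values from the test it reproduces the expected unique keys and argmins.

```python
import numpy as np


def group_argmin(keys, values):
    """Unique keys and the index of the first minimum of values within each key."""
    keys = np.asarray(keys)
    values = np.asarray(values)
    # np.lexsort sorts by its LAST key first: group by key, then by value,
    # then by original position so ties resolve to the earliest index.
    order = np.lexsort((np.arange(len(keys)), values, keys))
    sorted_keys = keys[order]
    # The first element of each run of equal keys is that group's minimum.
    first = np.flatnonzero(np.r_[True, sorted_keys[1:] != sorted_keys[:-1]])
    return sorted_keys[first], order[first]


unique, amin = group_argmin([2, 0, 0, 1, 1, 2, 2, 2, 2, 2],
                            [4, 5, 6, 8, 0, 9, 8, 5, 4, 9])
print(unique, amin)  # [0 1 2] [1 4 0]
```

For argmax, one option is to pass the negated values to np.lexsort instead, which assumes a signed integer or floating dtype.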
    