Extend numpy mask by n cells to the right for each bad value, efficiently

Posted by 愿得一人, 2021-02-15 15:39

Let's say I have a length-30 array with 4 bad values in it. I want to create a mask for those bad values, but since I will be using rolling window functions, I'd also like a f…
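Going by the title (mark each bad value plus the next few cells to its right), here is a minimal sketch of the setup. The array is a hypothetical length-30 float array with NaNs at made-up positions, and `n` counts the bad cell itself. The extended mask comes from a convolution, a technique picked here only as a simple reference, not one taken from this thread; `extend_mask_convolve` is a hypothetical helper name.

```python
import numpy as np

def extend_mask_convolve(mask, n):
    """Mark each True in `mask` plus the n-1 cells to its right (reference version)."""
    # Full convolution with n ones: entry i counts the Trues in mask[i-n+1 : i+1]
    counts = np.convolve(mask.astype(int), np.ones(n, dtype=int), mode="full")
    # Keep only the first len(mask) entries so the window reaches to the right only
    return counts[: len(mask)] > 0

# Hypothetical data: length-30 array with 4 bad values (positions are made up)
a = np.arange(30, dtype=float)
a[[3, 10, 11, 27]] = np.nan

bad = np.isnan(a)                        # mask of the bad values only
extended = extend_mask_convolve(bad, 3)  # each bad value plus 2 cells to its right
print(np.flatnonzero(extended))          # indices 3, 4, 5, 10, 11, 12, 13, 27, 28, 29
```

Overlapping runs (the NaNs at 10 and 11) simply merge, which is what a rolling-window mask needs.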

7 answers
  • 2021-02-15 16:07

    Yet another answer!
    It just takes the mask you already have and ORs it with shifted versions of itself. Nicely vectorized and insanely fast! :D
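The idea is easiest to see in its plain, linear form. A minimal sketch (`naive_shift_or` is a hypothetical name, not from the answer) that ORs the NaN mask with each of its right-shifts by 1 … n-1, one full pass per shift:

```python
import numpy as np

def naive_shift_or(a, n=4):
    """Mark each NaN in `a` plus the n-1 cells to its right, one shift at a time."""
    base = np.isnan(a)
    out = base.copy()
    # Shifts of len(base) or more would move every True off the end, so skip them
    for shift in range(1, min(n, len(base))):
        # out[i] |= base[i - shift] for every i >= shift
        out[shift:] |= base[: len(base) - shift]
    return out

a = np.array([np.nan, 1.0, 2.0, 3.0, np.nan, 5.0, 6.0, 7.0])
print(naive_shift_or(a, n=3).astype(int))  # [1 1 1 0 1 1 1 0]
```

This costs n-1 passes over the array; `repeat_or` gets the same mask with far fewer passes by reusing runs it has already built.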

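    import numpy as np
    
    def repeat_or(a, n=4):
        """Mark every NaN in `a` plus the n-1 cells to its right."""
        if n < 1:
            raise ValueError("n must be at least 1")
        m = np.isnan(a)
        k = m.copy()
    
        # lenM and lenK say, for each mask, how many consecutive Trues
        # (starting at every NaN) it is guaranteed to contain
        lenM, lenK = 1, 1
    
        # run until combining both masks yields n or more
        # consecutive Trues
        while lenM + lenK < n:
            # append what we have in k to the end of what we have in m
            m[lenM:] |= k[:-lenM]
    
            # swap so that m is again the shorter one
            m, k = k, m
    
            # update the lengths
            lenM, lenK = lenK, lenM + lenK
    
        # shift m just far enough to append the missing Trues to k;
        # the explicit slice length keeps n == 1 (shift 0) and shifts
        # longer than the array from producing mismatched shapes
        shift = n - lenM
        k[shift:] |= m[:max(len(m) - shift, 0)]
    
        return k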
    

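    Unfortunately, I couldn't get m[i:] |= m[:-i] to work; it's probably a bad idea to modify the mask while also reading from it. It does work for m[:-i] |= m[i:], but that shifts in the wrong direction. (Since NumPy 1.13, in-place ufunc operations whose input and output overlap are computed as if the input had been copied first, so on current versions m[i:] |= m[:-i] ORs in a clean shifted copy.)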
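For comparison, a sketch of the doubling approach the answer was aiming for (`repeat_or_doubling` is a hypothetical name). It sidesteps the self-modification problem by writing `m[step:] = m[step:] | m[:-step]`: the right-hand side is evaluated into a new array before anything is written back, so the overlap between source and destination does not matter on any NumPy version. The step is capped so the runs end up exactly n cells long.

```python
import numpy as np

def repeat_or_doubling(a, n=4):
    """Mark each NaN in `a` plus the n-1 cells to its right by doubling run lengths."""
    if n < 1:
        raise ValueError("n must be at least 1")
    m = np.isnan(a)
    covered = 1  # m currently marks runs of `covered` cells starting at each NaN
    while covered < n:
        # Never shift further than needed, so no run grows past n cells
        step = min(covered, n - covered)
        # The OR on the right produces a fresh array before assignment,
        # so reading and writing overlapping slices of m is safe
        m[step:] = m[step:] | m[:-step]
        covered += step
    return m

a = np.array([0.0, np.nan, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, np.nan])
print(repeat_or_doubling(a, n=5).astype(int))  # [0 1 1 1 1 1 0 0 0 1]
```

The extra temporary per pass costs some memory traffic, which is one reason the answer's two-buffer version can be faster in practice.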
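    Anyway, instead of doubling the run length on every pass, the run lengths now grow like the Fibonacci numbers. That is still exponential growth, so only O(log n) passes are needed rather than the n-1 passes of the plain linear approach.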
    (I never thought I'd actually write an algorithm that is genuinely related to the Fibonacci sequence outside of some contrived math problem.)
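To see the Fibonacci growth concretely, a small sketch (the helper `run_length_schedule` is hypothetical) that replays only the `lenM`/`lenK` bookkeeping of `repeat_or`'s loop for a given n, without touching any array, and returns the pair after each pass plus the shift used by the final OR:

```python
def run_length_schedule(n):
    """Replay the (lenM, lenK) updates of repeat_or's loop for a given n."""
    lenM, lenK = 1, 1
    pairs = [(lenM, lenK)]
    while lenM + lenK < n:
        # one OR pass: m grows to lenM + lenK, then the two masks swap roles
        lenM, lenK = lenK, lenM + lenK
        pairs.append((lenM, lenK))
    # the final OR shifts m by n - lenM to complete the run of n Trues
    return pairs, n - lenM

pairs, final_shift = run_length_schedule(100)
print(pairs)
print(final_shift)
```

For n = 100 the pairs run (1, 1), (1, 2), (2, 3), (3, 5), …, (55, 89), and the final shift is 45. That is 9 loop passes plus the final OR, 10 OR operations in total, versus 99 for the naive one-shift-per-pass version.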

    Testing under "real" conditions with an array of 10**6 values, 10**5 of them NaN:

    In [5]: a = np.random.random(size=10**6)
    
    In [6]: a[np.random.choice(len(a), 10**5, replace=False)] = np.nan
    
    In [7]: %timeit reduceat(a)
    10 loops, best of 3: 65.2 ms per loop
    
    In [8]: %timeit index_expansion(a)
    100 loops, best of 3: 12 ms per loop
    
    In [9]: %timeit cumsum_trick(a)
    10 loops, best of 3: 17 ms per loop
    
    In [10]: %timeit repeat_or(a)
    1000 loops, best of 3: 1.9 ms per loop
    
    In [11]: %timeit agml_indexing(a)
    100 loops, best of 3: 6.91 ms per loop
    
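To make the benchmark input reproducible, a sketch of the same setup using NumPy's Generator API. It assumes NumPy 1.17 or later, where `np.random.default_rng` exists; the seed 42 is arbitrary.

```python
import numpy as np

rng = np.random.default_rng(42)  # seeded, so every run gets the same data
a = rng.random(10**6)            # 10**6 uniform floats in [0, 1)
# 10**5 distinct positions to overwrite with NaN
nan_idx = rng.choice(a.size, size=10**5, replace=False)
a[nan_idx] = np.nan
```

Timing each candidate against this fixed `a` (for example with `%timeit` as above) means the numbers can be compared across machines and reruns.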

    I'll leave further benchmarks to Thomas.
