Memory consumption of NumPy function for standard deviation

Asked by 误落风尘 · 2021-01-11 11:39

I'm currently using the Python bindings of GDAL to work with quite large raster data sets (> 4 GB). Since loading them into memory at once is not a feasible solution for me, I read them in smaller blocks and compute the statistics incrementally.
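
A sketch of the block-wise pattern this implies, merging per-block statistics with the pairwise-combination formula of Chan et al. (the file name and block height below are placeholders, not part of the original question):

from osgeo import gdal
import numpy as np

def blockwise_std(path, block_rows=512):
    """One-pass standard deviation over a raster read in horizontal strips.

    Per-block (count, mean, M2) aggregates are merged with the
    pairwise-combination formula of Chan et al., so at most one
    block is held in memory at a time.
    """
    band = gdal.Open(path).GetRasterBand(1)
    n, mean, M2 = 0, 0.0, 0.0
    for yoff in range(0, band.YSize, block_rows):
        rows = min(block_rows, band.YSize - yoff)
        block = band.ReadAsArray(0, yoff, band.XSize, rows).astype(np.float64).ravel()
        # Per-block statistics, computed vectorized by NumPy.
        nb, mb = block.size, block.mean()
        M2b = np.sum((block - mb) ** 2)
        # Merge the block into the running aggregate.
        delta = mb - mean
        total = n + nb
        mean += delta * nb / total
        M2 += M2b + delta * delta * n * nb / total
        n = total
    return np.sqrt(M2 / (n - 1)) if n > 1 else float("nan")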

2 Answers
  •  北荒 (OP) · 2021-01-11 12:10

    Cython to the rescue! This achieves a nice speed-up:

    %%cython
    cimport cython
    cimport numpy as np
    import numpy as np  # Python-level import, needed for np.nan
    from libc.math cimport sqrt
    
    @cython.boundscheck(False)
    def std_welford(np.ndarray[np.float64_t, ndim=1] a):
        """Single-pass sample standard deviation via Welford's algorithm."""
        cdef Py_ssize_t n = 0
        cdef Py_ssize_t a_len = a.shape[0]
        cdef Py_ssize_t i
        # Use double accumulators to match the float64 input;
        # float32 accumulators would silently lose precision.
        cdef double mean = 0.0
        cdef double M2 = 0.0
        cdef double delta
        for i in range(a_len):
            n += 1
            delta = a[i] - mean
            mean += delta / n
            # Uses the updated mean, which keeps the update numerically stable.
            M2 += delta * (a[i] - mean)
        if n < 2:
            return np.nan
        return sqrt(M2 / (n - 1))
    

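    For reference, the loop implements Welford's recurrence: after seeing the n-th value x_n,

        mean_n = mean_{n-1} + (x_n - mean_{n-1}) / n
        M2_n   = M2_{n-1}  + (x_n - mean_{n-1}) * (x_n - mean_n)

    and the sample standard deviation is sqrt(M2_n / (n - 1)). Each update touches only a few scalars, so memory use stays constant no matter how long the input is.
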
    Using this to test:

    a = np.random.rand(10000)  # already float64; np.float is deprecated/removed
    print(std_welford(a))
    %timeit -n 10 -r 10 std_welford(a)
    

    Cython code

    0.288327455521
    10 loops, best of 10: 59.6 µs per loop
    

    Original code

    0.289605617397
    10 loops, best of 10: 18.5 ms per loop
    

    Numpy std

    0.289493223504
    10 loops, best of 10: 29.3 µs per loop
    

    So that's a speed-up of roughly 300x over the original pure-Python loop (18.5 ms → 59.6 µs), though still about 2x slower than NumPy's built-in std.
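
    As a quick sanity check, the streaming result can be compared against NumPy's two-pass estimate. Note that np.std defaults to the population estimator (ddof=0), while std_welford divides by n - 1, so pass ddof=1 for an apples-to-apples comparison:

    import numpy as np
    
    a = np.random.rand(10000)
    # std_welford uses the sample estimator (n - 1), hence ddof=1 here.
    assert np.isclose(std_welford(a), a.std(ddof=1))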
