Convolution of two three-dimensional arrays with padding on one side is too slow

离开以前 2021-02-05 18:56

In my current project I need to "convolve" two three-dimensional arrays in a slightly unusual way:

Assume we have two three-dimensional arrays A and B

3 Answers
  •  礼貌的吻别
    2021-02-05 19:25

    The general rule is to use the right algorithm for the job: unless the convolution kernel is short compared to the data, that means an FFT-based convolution ("short" roughly means shorter than log2(n), where n is the length of the data).
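    To illustrate the equivalence the FFT approach relies on (the convolution theorem), here is a minimal 1-D sketch using only NumPy. The function names direct_convolve and fft_convolve are hypothetical. Both inputs are zero-padded to the full output length len(x) + len(h) - 1, so the FFT's circular convolution equals the ordinary linear one.

```python
import numpy as np


def direct_convolve(x, h):
    """Direct (schoolbook) linear convolution: O(len(x) * len(h))."""
    out = np.zeros(len(x) + len(h) - 1)
    for i, xi in enumerate(x):
        # Each input sample adds a scaled, shifted copy of the kernel.
        out[i:i + len(h)] += xi * h
    return out


def fft_convolve(x, h):
    """Linear convolution via the convolution theorem: O(n log n)."""
    n = len(x) + len(h) - 1  # full output length
    # rfft(..., n) zero-pads to n, which prevents circular wrap-around.
    X = np.fft.rfft(x, n)
    H = np.fft.rfft(h, n)
    return np.fft.irfft(X * H, n)


rng = np.random.default_rng(0)
x = rng.random(1000)
h = rng.random(1000)
print(np.allclose(direct_convolve(x, h), fft_convolve(x, h)))
```

    The direct loop costs O(len(x) * len(h)), while the FFT route costs O(n log n) for n = len(x) + len(h) - 1, which is why the FFT wins once the kernel is no longer short.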

    In your case, since you're convolving two datasets of equal size, you should probably be considering an FFT-based convolution.
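    A rough operation count shows why. The six nested loops of the direct approach perform N^6 multiply-adds for two N x N x N arrays, whereas an FFT over the padded (2N - 1)^3 volume costs on the order of M log2(M) per transform. The helper names below are hypothetical and constant factors are deliberately ignored; this is an order-of-magnitude sketch, not a benchmark.

```python
import math


def direct_ops(n):
    """Multiply-adds in the six nested loops for two n x n x n arrays."""
    return n ** 6


def fft_ops(n, transforms=3):
    """Rough FFT cost: transforms * M * log2(M) for M padded points.

    Constant factors are ignored; M = (2n - 1)**3 is the full output size.
    Three transforms: forward A, forward B, inverse of the product.
    """
    m = (2 * n - 1) ** 3
    return transforms * m * math.log2(m)


for n in (4, 8, 16, 33):
    ratio = direct_ops(n) / fft_ops(n)
    print(f"n={n:2d}  direct={direct_ops(n):>13,}  "
          f"fft~{fft_ops(n):>13,.0f}  ratio~{ratio:.1f}")
```

    For tiny arrays (n = 4) the direct loop needs fewer operations, consistent with the rule of thumb above; for n = 33 the estimate already favours the FFT by close to two orders of magnitude.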

    Clearly, scipy.signal.fftconvolve is somewhat deficient in this regard. By rolling your own convolution routine on top of a faster FFT library, you can do much better. It doesn't help that fftconvolve forces the transform size to a power of two; otherwise it could simply be monkey-patched.
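    To see how much power-of-two padding can cost in 3-D, compare it with padding to the next 5-smooth size (a number whose only prime factors are 2, 3 and 5, which FFT libraries handle efficiently). SciPy exposes its own choice of fast sizes as scipy.fft.next_fast_len; the helpers next_pow2 and next_5_smooth below are hand-rolled, hypothetical illustrations only.

```python
def next_pow2(n):
    """Smallest power of two >= n."""
    p = 1
    while p < n:
        p *= 2
    return p


def next_5_smooth(n):
    """Smallest m >= n whose only prime factors are 2, 3 and 5."""
    m = max(n, 1)
    while True:
        r = m
        for p in (2, 3, 5):
            while r % p == 0:
                r //= p
        if r == 1:
            return m
        m += 1


for N in (17, 33):
    full = 2 * N - 1  # full convolution length per axis
    p2, s5 = next_pow2(full), next_5_smooth(full)
    # In 3-D the padded volume grows with the cube of the per-axis size.
    print(N, full, p2, s5, round((p2 / s5) ** 3, 2))
```

    Because the volume scales with the cube of the per-axis length, even a modest per-axis overshoot turns into several times more data to transform.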

    The following code uses pyFFTW, my wrappers around FFTW, to create a custom convolution class, CustomFFTConvolution:

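    import numpy as np
    import pyfftw
    import pyfftw.builders
    
    class CustomFFTConvolution(object):
        """Full N-D linear convolution of A and B using pre-planned FFTW transforms."""
    
        def __init__(self, A, B, threads=1):
            # Full (linear) convolution size along each axis: len_A + len_B - 1.
            shape = (np.array(A.shape) + np.array(B.shape)) - 1
    
            if np.iscomplexobj(A) or np.iscomplexobj(B):
                # At least one input is complex: use full complex transforms.
                self.fft_A_obj = pyfftw.builders.fftn(
                        A, s=shape, threads=threads)
                self.fft_B_obj = pyfftw.builders.fftn(
                        B, s=shape, threads=threads)
                self.ifft_obj = pyfftw.builders.ifftn(
                        self.fft_A_obj.get_output_array(), s=shape,
                        threads=threads)
    
            else:
                # Both inputs are real: real-to-complex transforms do about
                # half the work and store only half of the spectrum.
                self.fft_A_obj = pyfftw.builders.rfftn(
                        A, s=shape, threads=threads)
                self.fft_B_obj = pyfftw.builders.rfftn(
                        B, s=shape, threads=threads)
                self.ifft_obj = pyfftw.builders.irfftn(
                        self.fft_A_obj.get_output_array(), s=shape,
                        threads=threads)
    
        def __call__(self, A, B):
            # The builder objects zero-pad the inputs up to `shape` internally.
            fft_padded_A = self.fft_A_obj(A)
            fft_padded_B = self.fft_B_obj(B)
    
            # Pointwise product in frequency space == convolution in real space.
            # The result is the FFTW object's internal output array, so copy it
            # if you need to keep it across calls.
            return self.ifft_obj(fft_padded_A * fft_padded_B)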
    

    This is used as:

    custom_fft_conv = CustomFFTConvolution(A, B)
    C = custom_fft_conv(A, B)  # A and B may hold different values than at construction, but must have the same shapes
    

    The constructor also takes an optional threads argument. The purpose of wrapping this in a class is to take advantage of FFTW's ability to plan the transform in advance: planning happens once, in __init__, and every subsequent call reuses those plans.
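    If pyFFTW isn't available, the same idea (FFTs zero-padded to the full convolution size A.shape + B.shape - 1, real transforms when both inputs are real) can be sketched with numpy.fft alone. NumpyFFTConvolution is a hypothetical name; numpy.fft does no planning, so this loses the advantage described above and is mostly useful as a correctness reference.

```python
import numpy as np


class NumpyFFTConvolution:
    """Full N-D linear convolution via numpy.fft (no FFTW planning).

    Hypothetical numpy-only counterpart of CustomFFTConvolution: the
    transform shape is fixed at construction, as in the pyFFTW version.
    """

    def __init__(self, A, B):
        # Full convolution length along each axis: len_A + len_B - 1.
        self.shape = tuple(int(n) for n in
                           np.array(A.shape) + np.array(B.shape) - 1)
        self.is_complex = np.iscomplexobj(A) or np.iscomplexobj(B)

    def __call__(self, A, B):
        axes = tuple(range(len(self.shape)))
        if self.is_complex:
            fa = np.fft.fftn(A, s=self.shape, axes=axes)
            fb = np.fft.fftn(B, s=self.shape, axes=axes)
            return np.fft.ifftn(fa * fb, s=self.shape, axes=axes)
        # Real inputs: rfftn keeps only half the spectrum on the last axis.
        fa = np.fft.rfftn(A, s=self.shape, axes=axes)
        fb = np.fft.rfftn(B, s=self.shape, axes=axes)
        return np.fft.irfftn(fa * fb, s=self.shape, axes=axes)
```

    Unlike the pyFFTW class, construction here is essentially free and every call pays the full transform cost.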

    The full demo code below simply extends @Kelsey's answer for the timing and so on.

    The speedup over both the Numba solution and the vanilla fftconvolve solution is substantial: for N = 33, it's about 40-45x faster than either.

    from timeit import Timer
    import numpy as np
    import matplotlib.pyplot as plt
    from scipy.signal import fftconvolve
    from numba import jit, double
    import pyfftw
    import pyfftw.builders
    
    # Original code
    def custom_convolution(A, B):
    
        dimA = A.shape[0]
        dimB = B.shape[0]
        dimC = dimA + dimB
    
        C = np.zeros((dimC, dimC, dimC))
        for x1 in range(dimA):
            for x2 in range(dimB):
                for y1 in range(dimA):
                    for y2 in range(dimB):
                        for z1 in range(dimA):
                            for z2 in range(dimB):
                                x = x1 + x2
                                y = y1 + y2
                                z = z1 + z2
                                C[x, y, z] += A[x1, y1, z1] * B[x2, y2, z2]
        return C
    
    # Numba'ing the function with the JIT compiler
    numba_convolution = jit(double[:, :, :](double[:, :, :],
                            double[:, :, :]))(custom_convolution)
    
    def fft_convolution(A, B):
        return fftconvolve(A, B, mode='full')
    
    class CustomFFTConvolution(object):
    
        def __init__(self, A, B, threads=1):
    
            shape = (np.array(A.shape) + np.array(B.shape))-1
    
            if np.iscomplexobj(A) or np.iscomplexobj(B):
                self.fft_A_obj = pyfftw.builders.fftn(
                        A, s=shape, threads=threads)
                self.fft_B_obj = pyfftw.builders.fftn(
                        B, s=shape, threads=threads)
                self.ifft_obj = pyfftw.builders.ifftn(
                        self.fft_A_obj.get_output_array(), s=shape,
                        threads=threads)
    
            else:
                self.fft_A_obj = pyfftw.builders.rfftn(
                        A, s=shape, threads=threads)
                self.fft_B_obj = pyfftw.builders.rfftn(
                        B, s=shape, threads=threads)
                self.ifft_obj = pyfftw.builders.irfftn(
                        self.fft_A_obj.get_output_array(), s=shape,
                        threads=threads)
    
        def __call__(self, A, B):
    
            fft_padded_A = self.fft_A_obj(A)
            fft_padded_B = self.fft_B_obj(B)
    
            return self.ifft_obj(fft_padded_A * fft_padded_B)
    
    def run_test():
        reps = 10
        nt, ft, cft, cft2 = [], [], [], []
        x = range(2, 34)
    
        for N in x:
            print(N)
            A = np.random.rand(N, N, N)
            B = np.random.rand(N, N, N)
    
            custom_fft_conv = CustomFFTConvolution(A, B)
            custom_fft_conv_nthreads = CustomFFTConvolution(A, B, threads=2)
    
            C1 = numba_convolution(A, B)
            C2 = fft_convolution(A, B)
            C3 = custom_fft_conv(A, B)
            C4 = custom_fft_conv_nthreads(A, B)
    
            assert np.allclose(C1[:-1, :-1, :-1], C2)
            assert np.allclose(C1[:-1, :-1, :-1], C3)
            assert np.allclose(C1[:-1, :-1, :-1], C4)
    
            t = Timer(lambda: numba_convolution(A, B))
            nt.append(t.timeit(number=reps))
            t = Timer(lambda: fft_convolution(A, B))
            ft.append(t.timeit(number=reps))
            t = Timer(lambda: custom_fft_conv(A, B))
            cft.append(t.timeit(number=reps))
            t = Timer(lambda: custom_fft_conv_nthreads(A, B))
            cft2.append(t.timeit(number=reps))
    
        plt.plot(x, ft, label='scipy.signal.fftconvolve')
        plt.plot(x, nt, label='custom numba convolve')
        plt.plot(x, cft, label='custom pyfftw convolve')
        plt.plot(x, cft2, label='custom pyfftw convolve with threading')        
        plt.legend()
        plt.show()
    
    if __name__ == '__main__':
        run_test()
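
    Note why the asserts compare C1[:-1, :-1, :-1]: the original custom_convolution allocates dimA + dimB points per axis, one more than the full convolution size dimA + dimB - 1, so its last plane along every axis is never written and stays zero (presumably the "padding on one side" in the title). The sketch below, with the hypothetical name shift_add_convolution, reproduces that output exactly while looping only over the elements of B and updating a whole shifted slice of C at a time. Like the original, it assumes cubic arrays.

```python
import numpy as np


def shift_add_convolution(A, B):
    """Same result as the demo's custom_convolution, vectorized over A.

    Output has A.shape[0] + B.shape[0] points per axis, like the original,
    so the final plane along each axis is never written and stays zero.
    """
    dimA = A.shape[0]
    dimC = dimA + B.shape[0]
    C = np.zeros((dimC, dimC, dimC))
    # One vectorized slice update per element of B instead of one
    # scalar update per (A element, B element) pair.
    for x2, y2, z2 in np.ndindex(B.shape):
        C[x2:x2 + dimA, y2:y2 + dimA, z2:z2 + dimA] += B[x2, y2, z2] * A
    return C
```

    This cuts the Python-level iterations from N^6 to N^3 (for N = 33, about 36,000 slice updates), but it is still a direct convolution, so the FFT-based versions remain the better choice for arrays of this size.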
    

    EDIT: More recent versions of SciPy no longer always pad to a power-of-two length, so their performance is much closer to the pyFFTW case.
