Keras/Tensorflow - fourier pointwise multiplication implementation of conv2d running 4x slower than spatial convolution

自作多情 submitted on 2020-08-26 13:48:38

Question


According to the convolution theorem, convolution becomes pointwise multiplication in the Fourier domain. Many previous works, such as https://arxiv.org/abs/1312.5851, have shown that the overhead of computing the Fourier transforms is outweighed by the savings from replacing the convolution with a pointwise multiplication.
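
Below is a small check of the convolution theorem. It uses NumPy as a stand-in for TensorFlow, and the helper names circular_conv2d and fft_conv2d are illustrative only. One caveat matters here: the pointwise product of DFTs corresponds to circular convolution. Reproducing an ordinary (linear) convolution requires zero-padding both inputs first.

```python
import numpy as np

def circular_conv2d(x, k):
    """Direct 2-D circular convolution of two equally sized arrays (O(N^4) loops)."""
    h, w = x.shape
    out = np.zeros((h, w))
    for i in range(h):
        for j in range(w):
            for m in range(h):
                for n in range(w):
                    # Indices wrap around: this is circular, not linear, convolution
                    out[i, j] += x[m, n] * k[(i - m) % h, (j - n) % w]
    return out

def fft_conv2d(x, k):
    """Same result via the convolution theorem: pointwise product of rfft2 spectra."""
    return np.fft.irfft2(np.fft.rfft2(x) * np.fft.rfft2(k), s=x.shape)

rng = np.random.default_rng(0)
x = rng.standard_normal((8, 8))
k = rng.standard_normal((8, 8))
print(np.allclose(circular_conv2d(x, k), fft_conv2d(x, k)))  # True
```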

To replicate this, I tried replacing the keras.layers.Conv2D() layer with a custom layer. The layer accepts the rfft of the input data (I took the rfft of the data before feeding it into the model, to reduce training time). It initialises no_of_kernels kernels of the same size as the image, takes their rfft, multiplies the input and the kernels pointwise, and returns the product. It does this without taking the irfft, since I want to keep training the network in the Fourier domain.

The layer's call function is implemented as follows. (Note: in my dataset, MNIST, image height equals width, so the transpose works fine.)

def call(self, x):
    fft_x = x  # (batch_size, height, width, in_channels), already rfft-transformed
    fft_kernel = tf.spectral.rfft2d(self.kernel)  # (in_channels, height, width, out_channels // 2 + 1)
    fft_kernel = tf.transpose(fft_kernel, perm=[2, 1, 0, 3])  # (width, height, in_channels, out_channels // 2 + 1)
    output = tf.einsum('ijkl,jklo->ijko', fft_x, fft_kernel)
    return output

This code matches the accuracy of the Keras Conv2D layer, but it runs about 4 times slower than Conv2D, which defeats the purpose of moving to the Fourier domain. Could anyone please explain why this happens, and how I can reproduce the fast Fourier-domain convolutions reported in previous work?

(Note: for anyone who suspects that tf.spectral.rfft2d(self.kernel) is the overhead, I have verified that it is not.

Also, I think Conv2D might be flattening the 4D input and kernel tensors to reduce the convolution to a matrix multiplication. I could not think of any clever flattening or similar trick for the pointwise multiplication, other than viewing it as a dot product, as I have done with tf.einsum. Is there a smarter way to do the pointwise multiplication?) Thanks.
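
One way to view the einsum 'ijkl,jklo->ijko' is as one small matrix product per frequency bin: (batch x in_channels) @ (in_channels x out_channels). The sketch below checks that equivalence in NumPy, which stands in for TensorFlow, using made-up sizes. The same transposed layout should also work with tf.matmul on batched complex64 tensors. Whether that is faster than tf.einsum is untested here, and tf.einsum may already lower to a batched matmul internally.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, h, wf, c_in, c_out = 4, 28, 15, 3, 8  # wf = 28 // 2 + 1 rfft bins (sizes are illustrative)

# Complex spectra with the same layouts as the question's einsum operands:
# x: (batch, h, wf, c_in), k: (h, wf, c_in, c_out)
x = rng.standard_normal((batch, h, wf, c_in)) + 1j * rng.standard_normal((batch, h, wf, c_in))
k = rng.standard_normal((h, wf, c_in, c_out)) + 1j * rng.standard_normal((h, wf, c_in, c_out))

# The question's contraction: pointwise over (h, wf), summed over c_in
ref = np.einsum('ijkl,jklo->ijko', x, k)

# Same computation as a batched matmul with the frequency axes as batch dims:
# (h, wf, batch, c_in) @ (h, wf, c_in, c_out) -> (h, wf, batch, c_out)
x_f = np.transpose(x, (1, 2, 0, 3))
out = np.matmul(x_f, k)
out = np.transpose(out, (2, 0, 1, 3))  # back to (batch, h, wf, c_out)

print(np.allclose(ref, out))  # True
```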

Edit: the entire implementation of the layer, for reference:

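import tensorflow as tf
from keras.layers import Layer

class Fourier_Conv2D(Layer):
    def __init__(self, no_of_kernels, **kwargs):
        self.no_of_kernels = no_of_kernels
        super(Fourier_Conv2D, self).__init__(**kwargs)

    def build(self, input_shape):
        # One real-valued, image-sized kernel per output kernel and input channel:
        # (in_channels, height, width, no_of_kernels)
        self.kernel_shape = (int(input_shape[3]), int(input_shape[1]), int(input_shape[2]), self.no_of_kernels)
        self.kernel = self.add_weight(name='kernel',
                                      shape=self.kernel_shape,
                                      initializer='uniform', trainable=True)
        super(Fourier_Conv2D, self).build(input_shape)

    def call(self, x):
        fft_x = x  # input is already in the Fourier domain (complex)
        # rfft2d transforms the two innermost axes, here (width, no_of_kernels)
        fft_kernel = tf.spectral.rfft2d(self.kernel)
        fft_kernel = tf.transpose(fft_kernel, perm=[2, 1, 0, 3])
        # Pointwise product over the two spatial axes, summed over in_channels
        output = tf.einsum('ijkl,jklo->ijko', fft_x, fft_kernel)
        return output

    def compute_output_shape(self, input_shape):
        # The last axis shrinks to no_of_kernels // 2 + 1 because of the rfft
        return (input_shape[0], input_shape[1], input_shape[2], self.no_of_kernels // 2 + 1)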
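
A note on the axes: tf.spectral.rfft2d (tf.signal.rfft2d in newer TensorFlow) transforms the innermost two dimensions, just like NumPy's rfft2 with its default axes. With kernel_shape = (in_channels, height, width, no_of_kernels), that means the width and kernel axes rather than height and width. This is consistent with compute_output_shape returning no_of_kernels/2 + 1 channels. The NumPy sketch below, with illustrative sizes and NumPy standing in for TensorFlow, shows the two resulting shapes. In TensorFlow, which has no axes argument here, one would transpose height and width to the innermost positions before calling rfft2d.

```python
import numpy as np

in_ch, h, w, out_ch = 1, 28, 28, 32  # illustrative MNIST-like sizes
kernel = np.random.default_rng(0).standard_normal((in_ch, h, w, out_ch))

# Default axes (-2, -1): transforms over (width, out_channels), like rfft2d in the layer
default_shape = np.fft.rfft2(kernel).shape
print(default_shape)   # (1, 28, 28, 17)

# Transforming over the spatial axes (height, width) instead
spatial_shape = np.fft.rfft2(kernel, axes=(1, 2)).shape
print(spatial_shape)   # (1, 28, 15, 32)
```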

Answer 1:


I don't think your result is surprising at all. The implementation of Conv2D in Keras is left to the backend, and most backends (like TensorFlow) have highly optimized versions of the convolution operations, especially if you use cuDNN. So your own version, which should be faster than a naive implementation, is slower than a highly optimized one.

It's possible that, to make a meaningful comparison, you will have to implement a baseline Conv2D that performs convolution naively, without any kind of optimization.
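
Here is a minimal sketch of such a naive baseline, written in NumPy rather than TensorFlow. The function name naive_conv2d is hypothetical. It assumes the same layout Keras Conv2D uses by default: channels-last input, kernel shape (kh, kw, in_channels, out_channels), 'valid' padding, and cross-correlation without kernel flipping.

```python
import numpy as np

def naive_conv2d(x, k):
    """Hypothetical naive 'valid' 2-D cross-correlation with explicit loops.

    x: (batch, height, width, in_channels)
    k: (kh, kw, in_channels, out_channels)
    returns: (batch, height - kh + 1, width - kw + 1, out_channels)
    """
    b, h, w, c_in = x.shape
    kh, kw, _, c_out = k.shape
    out = np.zeros((b, h - kh + 1, w - kw + 1, c_out))
    for n in range(b):
        for i in range(h - kh + 1):
            for j in range(w - kw + 1):
                patch = x[n, i:i + kh, j:j + kw, :]  # (kh, kw, c_in) window
                for o in range(c_out):
                    # Multiply-accumulate over the window and all input channels
                    out[n, i, j, o] = np.sum(patch * k[:, :, :, o])
    return out

x = np.ones((1, 3, 3, 1))
k = np.ones((2, 2, 1, 2))
result = naive_conv2d(x, k)
print(result.shape)  # (1, 2, 2, 2); every entry is 4.0 (sum of a 2x2 window of ones)
```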



Source: https://stackoverflow.com/questions/55453119/keras-tensorflow-fourier-pointwise-multiplication-implementation-of-conv2d-run
