The documentation for the conv2d_transpose() operation does not clearly explain what it does:

"The transpose of conv2d. This operation is sometimes called 'deconvolution' after Deconvolutional Networks, but is actually the transpose (gradient) of conv2d rather than an actual deconvolution."
Here's another viewpoint from the "gradients" perspective, i.e. why the TensorFlow documentation says conv2d_transpose() is "actually the transpose (gradient) of conv2d rather than an actual deconvolution". For more details on the actual computation done in conv2d_transpose, I would highly recommend this article, starting from page 19.
In tf.nn, there are 4 closely related and rather confusing functions for 2d convolution:
tf.nn.conv2d
tf.nn.conv2d_backprop_filter
tf.nn.conv2d_backprop_input
tf.nn.conv2d_transpose
One-sentence summary: they are all just 2d convolutions. They differ in the ordering of their input arguments, input rotation or transposition, strides (including fractional strides), padding, and so on. With tf.nn.conv2d in hand, one can implement all three other ops by transforming inputs and adjusting the conv2d arguments.
# forward
out = conv2d(x, w)
# backward, given d_out
=> find d_x?
=> find d_w?
In the forward computation, we compute the convolution of the input image x with the filter w, and the result is out.
In the backward computation, assume we're given d_out, which is the gradient w.r.t. out. Our goal is to find d_x and d_w, which are the gradients w.r.t. x and w respectively.
For ease of discussion, we assume:

stride of 1
in_channels and out_channels are both 1
VALID padding

Conceptually, with the assumptions above, we have the following relations:
out = conv2d(x, w, padding='VALID')
d_x = conv2d(d_out, rot180(w), padding='FULL')
d_w = conv2d(x, d_out, padding='VALID')
where rot180(w) is the 2d matrix w rotated 180 degrees (a left-right flip followed by a top-down flip), and FULL means "apply the filter wherever it partly overlaps with the input" (see the Theano docs). Note that this is only valid with the above assumptions; however, one can change the conv2d arguments to generalize it.
The key takeaways:

d_x is the convolution of the output gradient d_out and the weight w, with some modifications.
d_w is the convolution of the input x and the output gradient d_out, with some modifications.

Now, let's give an actual working code example of how to use the 4 functions above to compute d_x and d_w given d_out. This shows how conv2d, conv2d_backprop_filter, conv2d_backprop_input, and conv2d_transpose are related to each other.
Please find the full scripts here.
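For context, here is a minimal setup in the spirit of those scripts (the shapes and the scalar loss f are illustrative assumptions of mine; the linked scripts define their own). It uses the TF1-style graph API that the snippets below assume:

import numpy as np
import tensorflow as tf  # TF1-style API (tf.compat.v1 in TF2)

x_shape = [1, 5, 5, 1]   # NHWC input
w_shape = [3, 3, 1, 1]   # HWIO filter
w_size = w_shape[0]      # filter height/width, used for FULL padding below
strides = [1, 1, 1, 1]

x = tf.constant(np.random.rand(*x_shape), dtype=tf.float32)
w = tf.constant(np.random.rand(*w_shape), dtype=tf.float32)

out = tf.nn.conv2d(x, w, strides=strides, padding='VALID')
f = tf.reduce_sum(out)           # a scalar loss so that gradients are defined
d_out = tf.gradients(f, out)[0]  # gradient w.r.t. out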
Computing d_x in 4 different ways:
# Method 1: TF's autodiff
d_x = tf.gradients(f, x)[0]

# Method 2: manually using conv2d
d_x_manual = tf.nn.conv2d(input=tf_pad_to_full_conv2d(d_out, w_size),
                          filter=tf_rot180(w),
                          strides=strides,
                          padding='VALID')

# Method 3: conv2d_backprop_input
d_x_backprop_input = tf.nn.conv2d_backprop_input(input_sizes=x_shape,
                                                 filter=w,
                                                 out_backprop=d_out,
                                                 strides=strides,
                                                 padding='VALID')

# Method 4: conv2d_transpose
d_x_transpose = tf.nn.conv2d_transpose(value=d_out,
                                       filter=w,
                                       output_shape=x_shape,
                                       strides=strides,
                                       padding='VALID')
Computing d_w in 3 different ways:
# Method 1: TF's autodiff
d_w = tf.gradients(f, w)[0]

# Method 2: manually using conv2d
d_w_manual = tf_NHWC_to_HWIO(tf.nn.conv2d(input=x,
                                          filter=tf_NHWC_to_HWIO(d_out),
                                          strides=strides,
                                          padding='VALID'))

# Method 3: conv2d_backprop_filter
d_w_backprop_filter = tf.nn.conv2d_backprop_filter(input=x,
                                                   filter_sizes=w_shape,
                                                   out_backprop=d_out,
                                                   strides=strides,
                                                   padding='VALID')
Please see the full scripts for the implementation of tf_rot180, tf_pad_to_full_conv2d, and tf_NHWC_to_HWIO. In the scripts, we check that the final output values of the different methods are the same; a numpy implementation is also available.
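For reference, here is one possible implementation of those helpers, a sketch consistent with how they are used above (the linked scripts may differ in details):

def tf_rot180(w):
    # Rotate each filter 180 degrees (reverse both spatial dimensions)
    return tf.reverse(w, axis=[0, 1])

def tf_pad_to_full_conv2d(x, w_size):
    # Zero-pad the input so that a VALID conv2d behaves like a FULL convolution
    return tf.pad(x, [[0, 0],
                      [w_size - 1, w_size - 1],
                      [w_size - 1, w_size - 1],
                      [0, 0]])

def tf_NHWC_to_HWIO(out):
    # Reinterpret an NHWC tensor as an HWIO filter (batch dim -> in_channels slot)
    return tf.transpose(out, perm=[1, 2, 0, 3])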
Any linear transform, including convolution, can be represented as a matrix, and a transposed convolution can be interpreted as transposing the convolution matrix before applying it. For example, consider a simple 1D convolution with a kernel size of 3 and a stride of 2: sliding the kernel over the input can be written as a short, wide matrix acting on the input vector.
If we transpose that convolution matrix and apply it to a 3-element vector, we get the transposed convolution operation. At first, this doesn't look like a convolution operation anymore, but if we first insert some zeros into the y vector, we can rewrite it equivalently as a convolution.
This example demonstrates that the transpose of a strided convolution operator is equivalent to upsampling by a factor of the stride (by inserting zeros), adding some additional padding, and finally performing an unstrided (i.e. stride=1) convolution.
For higher-dimensional transposed convolutions, the same upsampling-by-inserting-zeros method is applied to each dimension before performing an unstrided convolution.
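Here is a NumPy sketch of this 1D example (the 7-element input length is my own choice for illustration). C is the 3x7 convolution matrix for kernel size 3 and stride 2; applying C.T is then shown to equal zero-insertion plus padding plus a stride-1 convolution with the flipped kernel (the flip appears because, as in the gradient relations above, correlation with a rot180 kernel is what matches the matrix transpose):

import numpy as np

k = np.array([1.0, 2.0, 3.0])  # kernel of size 3
n_in, n_out = 7, 3             # stride-2 VALID conv: 7 inputs -> 3 outputs

# Build the convolution matrix C (one shifted copy of k per output element)
C = np.zeros((n_out, n_in))
for i in range(n_out):
    C[i, 2 * i:2 * i + 3] = k

y = C @ np.arange(1.0, 8.0)    # forward convolution

# Transposed convolution: apply C.T to the 3-element vector y
y_t = C.T @ y                  # shape (7,)

# Equivalent: insert zeros between the elements of y, pad by
# kernel_size - 1 on both sides, then run a stride-1 convolution
# with the flipped kernel
y_up = np.zeros(2 * n_out - 1)
y_up[::2] = y                  # [y0, 0, y1, 0, y2]
y_pad = np.pad(y_up, 2)        # two zeros on each side
k_flip = k[::-1]
manual = np.array([y_pad[i:i + 3] @ k_flip for i in range(n_in)])

assert np.allclose(y_t, manual)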
One application of conv2d_transpose is upscaling; here is an example that shows how it works:

import numpy as np
import tensorflow as tf

a = np.array([[0, 0, 1.5],
              [0, 1, 0],
              [0, 0, 0]]).reshape(1, 3, 3, 1)
filt = np.array([[1, 2],
                 [3, 4.0]]).reshape(2, 2, 1, 1)
b = tf.nn.conv2d_transpose(a,
                           filt,
                           output_shape=[1, 6, 6, 1],
                           strides=[1, 2, 2, 1],
                           padding='SAME')
print(tf.squeeze(b))
tf.Tensor(
[[0. 0. 0. 0. 1.5 3. ]
[0. 0. 0. 0. 4.5 6. ]
[0. 0. 1. 2. 0. 0. ]
[0. 0. 3. 4. 0. 0. ]
[0. 0. 0. 0. 0. 0. ]
[0. 0. 0. 0. 0. 0. ]], shape=(6, 6), dtype=float64)
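Since the 2x2 filter and the stride of 2 tile the 6x6 output without overlap, each input pixel simply stamps a copy of the filter, scaled by its value, into its own 2x2 output block. A quick NumPy check of the tensor above (reusing a and filt from the example):

out = np.zeros((6, 6))
for i in range(3):
    for j in range(3):
        # input pixel (i, j) scales the filter and lands in output block (2i, 2j)
        out[2*i:2*i + 2, 2*j:2*j + 2] += a[0, i, j, 0] * filt[:, :, 0, 0]
print(out)  # same values as tf.squeeze(b) above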
Here's a simple explanation of what is going on in a special case used in U-Net, one of the main use cases for transposed convolution.
We're interested in the following layer:
Conv2DTranspose(64, (2, 2), strides=(2, 2))
What does this layer do exactly? Can we reproduce its work?
Here's an example input and output (see details here and here):
In [15]: X.reshape(n, m)
Out[15]:
array([[ 0, 1, 2, 3, 4],
[ 5, 6, 7, 8, 9],
[10, 11, 12, 13, 14]])
In [16]: y_resh
Out[16]:
array([[ 0., 0., 1., 1., 2., 2., 3., 3., 4., 4.],
[ 0., 0., 1., 1., 2., 2., 3., 3., 4., 4.],
[ 5., 5., 6., 6., 7., 7., 8., 8., 9., 9.],
[ 5., 5., 6., 6., 7., 7., 8., 8., 9., 9.],
[10., 10., 11., 11., 12., 12., 13., 13., 14., 14.],
[10., 10., 11., 11., 12., 12., 13., 13., 14., 14.]], dtype=float32)
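To reproduce this with a single output channel, one can set the kernel weights to ones (an assumption for illustration; the real layer has 64 learned filters). With a 2x2 all-ones kernel and stride 2, the layer just repeats each pixel into a 2x2 block, which is exactly np.kron with a 2x2 matrix of ones:

import numpy as np
import tensorflow as tf

X = np.arange(15, dtype=np.float32).reshape(1, 3, 5, 1)

# one filter instead of 64, with all-ones weights, to mimic the output above
layer = tf.keras.layers.Conv2DTranspose(
    1, (2, 2), strides=(2, 2), use_bias=False, kernel_initializer='ones')
y = layer(X)  # shape (1, 6, 10, 1)

assert np.allclose(y.numpy().reshape(6, 10),
                   np.kron(X.reshape(3, 5), np.ones((2, 2))))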
This slide from Stanford's cs231n is also useful for our question.
The best explanation I've seen online of how transposed convolution works is here.
I'll give my own short description: it applies convolution with a fractional stride. In other words, it spaces out the input values (with zeros) so that the filter is applied over regions potentially smaller than the filter size.
As for why one would want to use it: it can serve as a form of upsampling with learned weights, as opposed to bilinear interpolation or some other fixed form of upsampling.
conv2d_transpose() simply transposes the weights and flips them by 180 degrees. Then it applies the standard conv2d(). "Transposes" practically means that it changes the order of the "columns" in the weights tensor. Please check the example below.
Here is an example that uses convolutions with stride=1 and padding='SAME'. It is a simple case, but the same reasoning can be applied to the other cases.
Say we have a 28x28 grayscale input image, fed as a flattened x of shape [1, 784].
If we convolve the input, the activations will have shape [1, 28, 28, 32]:
activations = sess.run(h_conv1,feed_dict={x:np.reshape(image,[1,784])})
Where:
W_conv1 = weight_variable([7, 7, 1, 32])
b_conv1 = bias_variable([32])
h_conv1 = conv2d(x, W_conv1, strides=[1, 1, 1, 1], padding='SAME') + b_conv1
To obtain the "deconvolution" or "transposed convolution", we can use conv2d_transpose() on the convolution activations in this way:
deconv = conv2d_transpose(activations, W_conv1, output_shape=[1,28,28,1], strides=[1,1,1,1], padding='SAME')
Or, using conv2d(), we need to transpose and flip the weights:
transposed_weights = tf.transpose(W_conv1, perm=[0, 1, 3, 2])
Here we change the order of the "columns" from [0,1,2,3] to [0,1,3,2]. So from [7, 7, 1, 32] we obtain a tensor with shape [7, 7, 32, 1]. Then we flip the weights:
# Flip each 7x7 kernel by 180 degrees (reverse both spatial axes);
# equivalent to flipping the 32 kernels one by one in a loop
transposed_and_flipped_weights = tf.reverse(transposed_weights, axis=[0, 1])
Then we can compute the convolution with conv2d() as:
strides = [1,1,1,1]
deconv = conv2d(activations,transposed_and_flipped_weights,strides=strides,padding='SAME')
And we will obtain the same result as before. Also the very same result can be obtained with conv2d_backprop_input() using:
deconv = conv2d_backprop_input([1,28,28,1],W_conv1,activations, strides=strides, padding='SAME')
The results are shown here:
Test of the conv2d(), conv2d_transpose() and conv2d_backprop_input()
We can see that the results are the same. To see it more clearly, please check my code at:
https://github.com/simo23/conv2d_transpose
Here I replicate the output of the conv2d_transpose() function using the standard conv2d().