I have a question about PyTorch's backward function. I don't think I'm getting the right output:
import numpy as np
import torch
from torch.autograd imp
Please read the documentation on backward() carefully to better understand it.
By default, PyTorch expects backward() to be called on the last output of the network: the loss function. The loss function always outputs a scalar, and therefore the gradients of the scalar loss w.r.t. all other variables/parameters are well defined (using the chain rule).
Thus, by default, backward() is called on a scalar tensor and expects no arguments.
For example:
a = torch.tensor([[1,2,3],[4,5,6]], dtype=torch.float, requires_grad=True)
for i in range(2):
    for j in range(3):
        out = a[i,j] * a[i,j]
        out.backward()
print(a.grad)
yields
tensor([[ 2., 4., 6.], [ 8., 10., 12.]])
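As expected: d(a^2)/da = 2a. Each backward() call adds its gradient into a.grad, and each call touches a different entry of a, so the accumulated result is the full 2-by-3 gradient.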
However, when you call backward on the 2-by-3 out tensor (no longer a scalar function), what do you expect a.grad to be? You would actually need a 2-by-3-by-2-by-3 output: d out[i,j] / d a[k,l] (!)
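To make that 2-by-3-by-2-by-3 shape concrete, here is a minimal sketch. backward() itself will not produce this object; the sketch assumes a recent PyTorch version that ships the separate helper torch.autograd.functional.jacobian, and the function name square is only illustrative.

```python
import torch
from torch.autograd.functional import jacobian

a = torch.tensor([[1., 2., 3.], [4., 5., 6.]])

def square(x):
    # Same elementwise operation as out = a * a
    return x * x

# J[i, j, k, l] = d out[i, j] / d a[k, l]
J = jacobian(square, a)
print(J.shape)  # torch.Size([2, 3, 2, 3])

# The operation is elementwise, so only entries with (i, j) == (k, l) are nonzero
print(J[0, 1, 0, 1])  # d(a[0,1]^2) / d a[0,1] = 2 * 2 = 4
print(J[0, 1, 1, 2])  # out[0,1] does not depend on a[1,2], so this is 0
```

Most of this tensor is zero, which is one reason backward() works with a 2-by-3 "upstream" gradient instead of materializing it.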
PyTorch's backward() does not compute this full derivative of a non-scalar function. Instead, it assumes out is only an intermediate tensor and that somewhere "upstream" there is a scalar loss function which, through the chain rule, provides d loss / d out[i,j]. This "upstream" gradient is of size 2-by-3, and it is the argument you pass to backward in this case: out.backward(g), where g_ij = d loss / d out_ij.
The gradients are then calculated by the chain rule. Because out[i,j] depends only on a[i,j], a single term survives for each entry: d loss / d a[i,j] = (d loss / d out[i,j]) * (d out[i,j] / d a[i,j])
Since you provided a as the "upstream" gradient, you got a.grad[i,j] = 2 * a[i,j] * a[i,j]
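The question's code is cut off above, so the following is only a sketch of what it presumably did, assuming out = a * a and a call equivalent to out.backward(a). Detaching a before passing it as the gradient is my own choice, made so the upstream gradient is a plain tensor.

```python
import torch

a = torch.tensor([[1., 2., 3.], [4., 5., 6.]], requires_grad=True)
out = a * a

# Use the values of a as the "upstream" gradient g
g = a.detach()
out.backward(g)

# Chain rule per entry: g[i,j] * 2 * a[i,j] = 2 * a[i,j]^2
print(a.grad)
# tensor([[ 2.,  8., 18.],
#         [32., 50., 72.]])
```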
If you instead provide all ones as the "upstream" gradients (starting from a fresh a and out = a * a, since gradients accumulate in a.grad):
out.backward(torch.ones(2,3))
print(a.grad)
yields
tensor([[ 2., 4., 6.], [ 8., 10., 12.]])
As expected.
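Passing all ones is the same as pretending the scalar loss is the sum of all entries of out. A small sketch comparing the two routes (the variable names grad_from_scalar and grad_from_ones are just for illustration):

```python
import torch

a = torch.tensor([[1., 2., 3.], [4., 5., 6.]], requires_grad=True)

# Route 1: an explicit scalar loss, so d loss / d out[i,j] = 1 for every entry
loss = (a * a).sum()
loss.backward()
grad_from_scalar = a.grad.clone()

# Gradients accumulate across backward() calls, so reset before the second route
a.grad.zero_()

# Route 2: non-scalar out with an all-ones "upstream" gradient
out = a * a
out.backward(torch.ones_like(out))
grad_from_ones = a.grad.clone()

print(torch.equal(grad_from_scalar, grad_from_ones))  # True
```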
It's all in the chain rule.