Backward function in PyTorch

北恋 2020-12-05 05:41

I have a question about PyTorch's backward function. I don't think I'm getting the right output:

import numpy as np
import torch
from torch.autograd imp         


        
1 Answer
  • 2020-12-05 06:23

    Please read the documentation on backward() carefully to understand it better.

    By default, PyTorch expects backward() to be called on the last output of the network: the loss function. The loss function always outputs a scalar, and therefore the gradients of the scalar loss w.r.t. all other variables/parameters are well defined (using the chain rule).

    Thus, by default, backward() is called on a scalar tensor and expects no arguments.
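    A minimal sketch of the difference, assuming the same 2-by-3 tensor a as in the example that follows: calling backward() with no argument works on a scalar, but PyTorch raises a RuntimeError when the output is not a scalar and no gradient argument is supplied.

```python
import torch

a = torch.tensor([[1., 2., 3.], [4., 5., 6.]], requires_grad=True)

# Scalar output: backward() needs no argument.
loss = (a * a).sum()
loss.backward()
print(a.grad)  # d(sum of a^2)/da = 2a

# Non-scalar output: backward() with no argument raises a RuntimeError,
# because PyTorch cannot implicitly create the upstream gradient.
out = a * a
try:
    out.backward()
    scalar_only_error = False
except RuntimeError as err:
    scalar_only_error = True
    print("RuntimeError:", err)
```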

    For example:

    import torch

    a = torch.tensor([[1,2,3],[4,5,6]], dtype=torch.float, requires_grad=True)
    for i in range(2):
        for j in range(3):
            out = a[i,j] * a[i,j]  # a scalar output
            out.backward()         # no argument needed; gradients accumulate into a.grad
    print(a.grad)
    

    yields

    tensor([[ 2.,  4.,  6.],
            [ 8., 10., 12.]])
    

    As expected: d(a^2)/da = 2a.
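    The loop gives 2a because each backward() call adds into a.grad instead of overwriting it, and each a[i,j] * a[i,j] only touches entry (i, j). A small sketch (same tensor a, nothing beyond the standard autograd behaviour assumed) shows that repeating the loop without clearing the gradient doubles the result:

```python
import torch

a = torch.tensor([[1., 2., 3.], [4., 5., 6.]], requires_grad=True)

def run_loop():
    # Same per-element loop as in the answer: one scalar backward() per entry.
    for i in range(2):
        for j in range(3):
            (a[i, j] * a[i, j]).backward()

run_loop()
after_loop = a.grad.clone()          # 2a

run_loop()                           # gradients accumulate again
after_second_loop = a.grad.clone()   # 4a, not 2a

# Clear before the next experiment (a.grad.zero_() also works).
a.grad = None
```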

    However, when you call backward on the 2-by-3 out tensor (no longer a scalar function), what do you expect a.grad to be? You would actually need a 2-by-3-by-2-by-3 output: d out[i,j] / d a[k,l] (!)
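    For illustration, that full 2-by-3-by-2-by-3 object can be computed with torch.autograd.functional.jacobian (a separate utility, not what backward() does; this sketch assumes out = a * a and PyTorch 1.5 or later). Because the operation is element-wise, only the entries with (i, j) == (k, l) are non-zero:

```python
import torch
from torch.autograd.functional import jacobian

a = torch.tensor([[1., 2., 3.], [4., 5., 6.]])

# Full derivative of out = a * a; shape is out.shape + a.shape.
J = jacobian(lambda x: x * x, a)
print(J.shape)  # torch.Size([2, 3, 2, 3])

# "Diagonal" entries d out[i,j] / d a[i,j] equal 2 * a[i,j].
diag = torch.stack([J[i, j, i, j] for i in range(2) for j in range(3)]).reshape(2, 3)

# Everything off the diagonal is zero.
off_diag_sum = J.abs().sum() - diag.abs().sum()
```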

    PyTorch's backward() does not compute this full derivative of a non-scalar function. Instead, PyTorch assumes out is only an intermediate tensor and that somewhere "upstream" there is a scalar loss function which, through the chain rule, provides d loss / d out[i,j]. This "upstream" gradient has size 2-by-3, and it is exactly the argument you pass to backward in this case: out.backward(g), where g_ij = d loss / d out_ij.
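    A sketch of that equivalence, using a hypothetical upstream loss, loss = sum(w * out), with made-up weights w: for this loss d loss / d out is simply w, so calling loss.backward() and calling out.backward(w) should leave the same gradient in a.grad.

```python
import torch

a = torch.tensor([[1., 2., 3.], [4., 5., 6.]], requires_grad=True)
# Hypothetical weights standing in for whatever happens "upstream".
w = torch.tensor([[1., 0., 2.], [0.5, 1., 3.]])

# Route 1: build the scalar loss and call backward() with no argument.
out = a * a
loss = (w * out).sum()
loss.backward()
grad_from_loss = a.grad.clone()
a.grad = None  # reset the accumulated gradient

# Route 2: stop at the intermediate out and supply g = d loss / d out = w.
out = a * a
out.backward(w)
grad_from_g = a.grad.clone()
```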

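    The gradients are then calculated by the chain rule, d loss / d a[k,l] = sum over i,j of (d loss / d out[i,j]) * (d out[i,j] / d a[k,l]). Because out[i,j] depends only on a[i,j], every term except (i,j) = (k,l) vanishes, leaving d loss / d a[i,j] = (d loss / d out[i,j]) * (d out[i,j] / d a[i,j]).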
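    To check this numerically, the sketch below (assuming out = a * a and an arbitrary, made-up upstream gradient g) contracts g with the full Jacobian, compares it with the element-wise shortcut, and compares both with what out.backward(g) produces without ever forming the Jacobian:

```python
import torch
from torch.autograd.functional import jacobian

a = torch.tensor([[1., 2., 3.], [4., 5., 6.]], requires_grad=True)
g = torch.tensor([[1., 0., 2.], [0.5, 1., 3.]])  # arbitrary upstream gradient

# General chain rule: d loss / d a[k,l] = sum_{i,j} g[i,j] * J[i,j,k,l].
J = jacobian(lambda x: x * x, a.detach())
vjp = torch.einsum("ij,ijkl->kl", g, J)

# Element-wise shortcut: only the (i,j) == (k,l) term survives.
elementwise = g * 2 * a.detach()

# What backward() computes (a vector-Jacobian product).
(a * a).backward(g)
```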

    Since you passed a itself as the "upstream" gradient, you got

    a.grad[i,j] = 2 * a[i,j] * a[i,j]
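    A minimal reproduction, assuming the question computed out = a * a and then called out.backward(a); here a is detached before being passed as the gradient, which is an assumption made so the gradient argument is a plain tensor:

```python
import torch

a = torch.tensor([[1., 2., 3.], [4., 5., 6.]], requires_grad=True)
out = a * a

# Upstream gradient g = a, so a.grad = g * 2a = 2 * a * a.
out.backward(a.detach())
print(a.grad)  # values [[2, 8, 18], [32, 50, 72]]
```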
    

    If you instead provide an "upstream" gradient of all ones

    out.backward(torch.ones(2,3))
    print(a.grad)
    

    yields

    tensor([[ 2.,  4.,  6.],
            [ 8., 10., 12.]])
    

    As expected.
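    An all-ones upstream gradient is exactly d(out.sum()) / d out, so a sketch like the following (same a, with a.grad cleared between runs so accumulation does not mix the two results) should give identical gradients for out.backward(torch.ones(2,3)) and out.sum().backward():

```python
import torch

a = torch.tensor([[1., 2., 3.], [4., 5., 6.]], requires_grad=True)

# Route 1: non-scalar out with an all-ones upstream gradient.
(a * a).backward(torch.ones(2, 3))
grad_ones = a.grad.clone()
a.grad = None  # clear before the second run

# Route 2: reduce to a scalar first, then call backward() with no argument.
(a * a).sum().backward()
grad_sum = a.grad.clone()
```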

    It's all in the chain rule.
