Tensor
Tensor(张量)是Pytorch中的基础计算单位,和numpy中的ndarray一样都是表示一个多维矩阵,不同的地方在于:Tensor既可以在CPU上运算也可以在GPU上进行运算,而ndarray只能在CPU上进行计算。
Tensor有三个重要的属性:
data
:保存张量的值;grad
:保存该张量的梯度;grad_fn
:指向一个用于反向传播计算输入梯度的Function对象;
在创建Tensor默认是不使用梯度的,如果需要进行梯度计算需要设置属性:requires_grad = True
Autograd
Autograd是Pytorch的核心模块,该模块实现了深度学习算法的反向传播(BP)求导过程。在Pytorch中所有作用在Tensor上的操作,Autograd都能为其提供自动微分求导的操作。
在创建Tensor时,同设置属性requires_grad
为True
声明该Tensor需要计算梯度。
用户手动创建的Tensor的grand_fn
属性默认是None
。
在张量进行操作之后grad_fn
就回被赋值为一个新的函数,该函数指向创建了该Tensor的Function对象。
Tensor 和 Function 共同组成一个非循环图,通过grad_fn
属性记录了Tensor完整的计算历史。
Autograd过程解析
可以使用python的 dir()
函数来查看变量z
的详情:
['__abs__',
...
'__array_priority__',
...
'__lshift__',
'__lt__',
...
'__ne__',
'__neg__',
...
'__reversed__',
'__rfloordiv__',
'__rmul__',
'__rpow__',
...
'data',
'data_ptr',
'dense_dim',
...
'diag',
'diag_embed',
...
'grad',
'grad_fn',
'gt',
'gt_',
'half',
'hardshrink',
'histc',
'ifft',
...
'is_coalesced',
'is_complex',
'is_contiguous',
'is_cuda',
'is_distributed',
'is_floating_point',
'is_leaf',
'is_nonzero',
...
'is_signed',
'is_sparse',
'isclose',
'item',
...
'var',
'view',
...
'zero_']
除去python中特有的属性和方法之外,有一个属性is_leaf
比较特殊。在pytorch中手动创建的Tensor(如:x,y)是叶子节点
即 is_leaf
的属性是True
;通过运算得到的Tensor(如:z)是非叶子节点
即is_leaf
的属性为:False
。
当调用 backward()
方法时,Autograd通过 grad_fn
更新Tensor的 grad
。
如上图所示,z.grad_fn
是 <AddBackward0 at 0x23f3275e7f0>
,从名字可以看出这是一个关于 Add
运算的反向传播求梯度函数。
可以使用python的 dir()
函数来查看变量z.grad_fn
的详情:
可以看到 z.grad_fn
有一个属性 next_functions
:
从上图可以看到:z.grad_fn.next_functions
是一个长度为2的元组(因为 z = x**2 + y**3
),内容是关于 Pow
运算的反向传播函数。
继续深挖:
发现 z.grad_fn.next_functions[0][0].next_functions
的类型是 AccumulateGrad
,在Pytorch中 AccumulateGrad
类型是 叶子节点
类型,也就是 计算图的终点
。AccumulateGrad
类中有一个 .variable
属性指向叶子节点。
观察发现 x_leaf.variable
和 x
的值相同,那么它们会不会是同一个变量?可以通过id()
函数查看下。
测试发现它们的 id
一模一样。
结论
所以可以总结出Pytorch自动求导的过程:
- 当反向求导执行
z.backward()
时会调用z.grad_fn
引用的Function
对象进行反向传播求导; - 这个操作会遍历
grad_fn
的next_functions
,然后分别遍历其中的Function
执行求导操作,该递归操作直到AccumulateGrad(叶子节点)
结束; - 计算出的结果保存在他们对应的
variable
所引用的变量(x, y)的grad
属性里面; - 求导结束,所有的叶子节点的
grad
都得到相应的更新;
引用
- https://github.com/zergtant/pytorch-handbook
来源:https://blog.csdn.net/dendi_hust/article/details/100016611