《Distilling the Knowledge in a Neural Network》论文笔记

妖精的绣舞 提交于 2019-12-21 15:37:48

1. 概述

这篇文章是比较早的文章,但是很多后序的工作都是源自于此,文章中提出了使用soft target(文章中说的是softmax层的输出)来实现大的模型到小的模型之间的知识迁移,从而提升小模型的性能。对于这里使用soft target而非hard target(如分类的类别目标),其原因是软目标能够提供更多可供训练的信息,而硬目标则会造成梯度上方差的减小。有了软目标的帮助小的模型能够更少的参数与更高的学习率

对于大模型中软目标的值可能相差很大的情况,如果直接使用大模型中的软目标进行训练将会使得小模型过多关注软目标中的较大值,而忽略了其它的相似性信息,因而文章引入了temperature的概念使得软目标的分布区间变小,从而能够学习到更多的信息。

2. 温度参数

上文说到使用temperature参数加在softmax层上对输出的概率分布进行软化,其对应的数学表达为:
qi=exp(zi/T)jexp(zj/T)q_i=\frac{exp(z_i/T)}{\sum_{j}exp(z_j/T)}
在小的模型训练的时候是加大参数TT的,而在训练完成之后是将T设置回正常值T=1T=1的。

在训练的过程中小目标的目标函数是由两部分组成的:

  • 1)大模型用较大TT构造软目标与小模型对应TT参数输出的预测结果的交叉熵损失,这部分的损失由于参数TT的引入导致梯度下降1T2\frac{1}{T^2},因而需要在计算梯度的时候进行补偿,乘以系数T2T^2
  • 2)小目标在参数T=1T=1的时候计算输出与真实硬目标的交叉熵损失,对于这里的提到的两个损失是使用一个参数α\alpha来进行调整的,一般来讲带温度参数TT的损失是占主导的;

对于带温度参数TT的损失,其梯度计算为:
Czi=1T(qipi)=1T(ezj/Tjezj/Tevj/Tjevj/T)\frac{\partial C}{\partial z_i}=\frac{1}{T}(q_i-p_i)=\frac{1}{T}(\frac{e^{z_j/T}}{\sum_je^{z_j/T}}-\frac{e^{v_j/T}}{\sum_je^{v_j/T}})
如果上式中的温度参数比较高那么,那么上面的梯度计算可以近似描述为:
Czi1T(1+zj/TN+zj/T1+vj/TN+vj/T)\frac{\partial C}{\partial z_i}\approx\frac{1}{T}(\frac{1+z_j/T}{N+z_j/T}-\frac{1+v_j/T}{N+v_j/T})
再进一步假设logits是经过0均值的那么jzj=jvj=0\sum_jz_j=\sum_jv_j=0,那么上面的梯度计算就可以描述为:
Czi1NT2(zivi)\frac{\partial C}{\partial z_i}\approx\frac{1}{NT^2}(z_i-v_i)

3. 总结

这篇文章的原理比较简单,也很容易理解,后序还有很多人基于此在各个领域运用知识蒸馏方法获得小模型。这里推荐一份代码帮助熟悉知识蒸馏的运用:knowledge-distillation-pytorch

标签
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!