Distilling the Knowledge in Neural Network
Geoffrey Hinton, Oriol Vinyals, Jeff Dean
preprint arXiv:1503.02531, 2015
NIPS 2014 Deep Learning Workshop
简单总结
主要工作(What)
- “蒸馏”(distillation):把大网络的知识压缩成小网络的一种方法
- “专用模型”(specialist models):对于一个大网络,可以训练多个专用网络来提升大网络的模型表现
具体做法(How)
- 蒸馏:先训练好一个大网络,在最后的softmax层使用合适的温度参数T,最后训练得到的概率称为“软目标”。以这个软目标和真实标签作为目标,去训练一个比较小的网络,训练的时候也使用在大模型中确定的温度参数T
- 专用模型:对于一个已经训练好的大网络,可以训练一系列的专用模型,每个专用模型只训练一部分专用的类以及一个“不属于这些专用类的其它类”,比如专用模型1训练的类包括“显示器”,“鼠标”,“键盘”,...,“其它”;专用模型2训练的类包括“玻璃杯”,“保温杯”,“塑料杯”,“其它“。最后以专用模型和大网络的预测输出作为目标,训练一个最终的网络来拟合这个目标。
意义(Why)
- 蒸馏把大网络压成小网络,这样就可以先在训练阶段花费大精力训练一个大网络,然后在部署阶段以较小的计算代价来产生一个较小的网络,同时保持一定的网络预测表现。
- 对于一个已经训练好的大网络,如果要去做集成的话计算开销是很大的,可以在这个基础上训练一系列专用模型,因为这些模型通常比较小,所以训练会快很多,而且有了这些专用模型的输出可以得到一个软目标,实验证明使用软目标训练可以减小过拟合。最后根据这个大网络和一系列专用模型的输出作为目标,训练一个最终的网络,可以得到不错的表现,而且不需要对大网络做大量的集成计算。
Abstract
提高机器学习算法表现的一个简单方法就是,训练不同模型然后对预测结果取平均。
但是要训练多个模型会带来过高的计算复杂度和部署难度。
可以将集成的知识压缩在单一的模型中。
论文使用这种方法在MNIST上做实验,发现取得了不错的效果。
论文还介绍了一种新型的集成,包括一个或多个完整模型和专用模型,能够学习区分完整模型容易混淆的细粒度的类别。
1 Introduction
昆虫有幼虫期和成虫期,幼虫期主要行为是吸收养分,成虫期主要行为是生长繁殖。
类似地,大规模机器学习应用可以分为训练阶段和部署阶段,训练阶段不要求实时操作,允许训练一个复杂缓慢的模型,这个模型可以是分别训练多个模型的集成,也可以是单独的一个很大的带有强正则比如dropout的模型。
一旦模型训练好,可以用不同的训练,这里称为“蒸馏”,去把知识转移到更适合部署的小模型上。
复杂模型学习区分大量的类,通常的训练目标是最大化正确答案的平均log概率,这么做有一个副作用就是训练模型同时也会给所有的错误答案分配概率,即使这个概率很小,而有一些概率会比其它的大很多。错误答案的相对概率体现了复杂模型的泛化能力。举个例子,宝马的图像被错认为垃圾箱的概率很低,但是这被个错认为垃圾桶的概率相比于被错认为胡萝卜的概率来说,是很大的。(可以认为模型不止学到了训练集中的宝马图像特征,还学到了一些别的特征,比如和垃圾桶共有的一些特征,这样就可能捕捉到在新的测试集上的宝马出现这些的特征,这就是泛化能力的体现)
将复杂模型转为小模型需要保留模型的泛化能力,一个方法就是用复杂模型产生的分类概率作为“软目标”来训练小模型。
当软目标的熵值较高时,相对于硬目标,每个训练样本提供更多的信息,训练样本之间会有更小的梯度方差。
所以小模型经常可以被训练在小数据集上,而且可以使用更高的学习率。
像MNIST这种分类任务,复杂模型可以产生很好的表现,大部分信息分布在小概率的软目标中。
为了规避这个问题,Caruana和他的合作者们使用softmax输出前的units值,而不是softmax后的概率,最小化复杂模型和简单模型的units的平方误差来训练小模型。
而更通用的方法,蒸馏法,先提高softmax的温度参数直到模型能产生合适的软目标。然后在训练小模型匹配软目标的时候使用相同的温度T。
被用于训练小模型的转移训练集可以包括未打标签的数据(可以没有原始的实际标签,因为可以通过复杂模型获取一个软目标作为标签),或者使用原始的数据集,使用原始数据集可以得到更好的表现。
2 Distillation
softmax公式: $ q_{i} = \frac{exp(z_{i}/T)}{\sum_{j}^{ }exp(z_{j}/T)} $
其中温度参数T通常设置为1,T越大可以得到更“软”的概率分布。
(T越大,不同激活值的概率差异越小,所有激活值的概率趋于相同;T越小,不同激活值的概率差异越大)
(在蒸馏训练的时候使用较大的T的原因是,较小的T对于那些远小于平均激活值的单元会给予更少的关注,而这些单元是有用的,使用较高的T能够捕捉这些信息)
最简单的蒸馏形式就是,训练小模型的时候,以复杂模型得到的“软目标”为目标,采用复杂模型中的较高的T,训练完之后把T改为1。
当部分或全部转移训练集的正确标签已知时,蒸馏得到的模型会更优。一个方法就是使用正确标签来修改软目标。
但是我们发现一个更好的方法,简单对两个不同的目标函数进行权重平均,第一个目标函数是和复杂模型的软目标做一个交叉熵,使用的复杂模型的温度T;第二个目标函数是和正确标签的交叉熵,温度设置为1。我们发现第二个目标函数被分配一个低权重时通常会取得最好的结果。
3 Preliminary experiments on MNIST
net | layers | units of each layer | activation | regularization | test errors |
---|---|---|---|---|---|
single net1 | 2 | 1600 | relu | dropout | 67 |
single net2 | 2 | 800 | relu | no | 146 |
(防止表格黏在一起)
net | large net | small net | temperature | test errors |
---|---|---|---|---|
distilled net | single net1 | single net2 | 20 | 74 |
(第一个表格中是两个单独的网络,一个大网络和一个小网络。)
(第二个表格是使用了蒸馏的方法,先训练大网络,然后根据大网络的“软目标”结果和温度T来训练小网络。)
(可以看到,通过蒸馏的方法将大网络中的知识压缩到小网络中,取得了不错的效果。)
4 Experiments on speech recognition
system | Test Frame Accuracy | Word Error Rate on dev set |
---|---|---|
baseline | 58.9% | 10.9% |
10XEnsemble | 61.1% | 10.7% |
Distilled model | 60.8% | 10.7% |
其中basline的配置为
- 8 层,每层2560个relu单元
- softmax层的单元数为14000
- 训练样本大小约为 700M,2000个小时的语音文本数据
10XEnsemble是对baseline训练10次(随机初始化为不同参数)然后取平均
蒸馏模型的配置为
- 使用的候选温度为{1, 2, 5, 10}, 其中T为2时表现最好
- hard target 的目标函数给予0.5的相对权重
可以看到,相对于10次集成后的模型表现提升,蒸馏保留了超过80%的效果提升
5 Training ensembles of specialists on very big datasets
训练一个大的集成模型可以利用并行计算来训练,训练完成后把大模型蒸馏成小模型,但是另一个问题就是,训练本身就要花费大量的时间,这一节介绍的是如何学习专用模型集合,集合中的每个模型集中于不同的容易混淆的子类集合,这样可以减小计算需求。专用模型的主要问题是容易集中于区分细粒度特征而导致过拟合,可以使用软目标来防止过拟合。
5.1 JFT数据集
JFT是一个谷歌的内部数据集,有1亿的图像,15000个标签。google用一个深度卷积神经网络,训练了将近6个月。
我们需要更快的方法来提升baseline模型。
5.2 专用模型
将一个复杂模型分为两部分,一部分是一个用于训练所有数据的通用模型,另一部分是很多个专用模型,每个专用模型训练的数据集是一个容易混淆的子类集合。这些专用模型的softmax结合所有不关心的类为一类来使模型更小。
为了减少过拟合,共享学习到的低水平特征,每个专用模型用通用模型的权重进行初始化。另外,专用模型的训练样本一半来自专用子类集合,另一半从剩余训练集中随机抽取。
5.3 将子类分配到专用模型
专用模型的子类分组集中于容易混淆的那些类别,虽然计算出了混淆矩阵来寻找聚类,但是可以使用一种更简单的办法,不需要使用真实标签来构建聚类。对通用模型的预测结果计算协方差,根据协方差把经常一起预测的类作为其中一个专用模型的要预测的类别。几个简单的例子如下。
JFT 1: Tea party; Easter; Bridal shower; Baby shower; Easter Bunny; ...
JFT 2: Bridge; Cable-stayed bridge; Suspension bridge; Viaduct; Chimney; ...
JFT 3: Toyota Corolla E100; Opel Signum; Opel Astra; Mazda Familia; ...
5.4 实验表现
system | Conditional Test Accuracy | Test Accuracy |
---|---|---|
baseline | 43.1% | 25.0% |
61 specialist models | 45.9% | 26.1% |
6 Soft Targets as Regularizers
对于前面提到过的,对于大量数据训练好的语音baseline模型,用更少的数据去拟合这个模型的时候,使用软目标可以达到更好的效果,减小过拟合。实验结果如下。
system & training set | Train Frame Accuracy | Test Frame Accuracy |
---|---|---|
baseline(100% training set) | 63.4% | 58.9% |
baseline(3% training set) | 67.3% | 44.5% |
soft targets(3% training set) | 65.4% | 57.0% |
来源:https://www.cnblogs.com/liaohuiqiang/p/9170582.html