《Batch Spectral Shrinkage for Safe Transfer Learning》论文解析

久未见 提交于 2020-02-05 00:32:16

文章全名为《Catastrophic Forgetting Meets Negative Transfer:Batch Spectral Shrinkage for Safe Transfer Learning》

1、摘要

这篇文章主要针对模型的fine-tune问题进行优化。众所周知,在许多模型的训练中,使用预训练好的模型进行fine-tune可以使模型的训练更加容易且结果更好。但是因为两个原因:灾难性遗忘(catastrophic forgetting)和负面迁移(negative transfer),使得fine-tune的效果降低了。本文提出了一种方法Batch Spectral Shrinkage (BSS),暂且翻译为批量光谱收缩,来克服这一情况。

 

2、介绍

主要介绍这两种导致fine-tune效果变差的原因。

首先是灾难性遗忘,即模型在学习与目标任务相关的信息时,容易突然失去之前所学的知识,导致过度拟合。第二个是负迁移,并非所有预先训练得到的知识都可以跨领域迁移,且不加选择地迁移所有知识对模型是有害的。

这里作者提到增量学习,并指出本文所提出的算法与增量学习的不同之处。增量学习可以学习新数据中的新知识,同时保证旧知识不被遗忘。但是与本文算法的目的不同,增量学习最终的目标是使得模型可以应用于新旧两个任务上,而BSS的目标是只作用于新的任务上。

当然作者也提到了在优化fine-tune的先驱者们提出的算法,L2 指的是在训练时直接正则化参数,L2-SP认为应该使用预训练模型中的信息来正则化某些参数,而非直接正则化所有参数。DELTA使用了注意力机制,使用了feature map中的知识。

基于奇异值分解,作者提出了Batch Spectral Shrinkage (BSS),来指引模型中参数和特征的可迁移性,进而增强模型迁移后的性能。

3、灾难性遗忘和负迁移

3.1 迁移学习中的正则化

这里就不介绍相关工作了,直接介绍算法。

首先明确定义,F是预训练模型与目标任务的共享网络,G是目标任务特定的网络。举例分类任务的话,就是说F用来提取特征,G用来分类,在预训练网络下,这两部分定义为F0和G0。

所以一般网络的损失函数可以定义如下:

其中L为网络本身的损失函数,Ω为对参数或者特征的正则化。

则L2 ,L2-SP,DELTA三种方法的损失函数如下图所示:

因与本算法无关,暂不多介绍。

3.2 负迁移

首先要证明一件事,负迁移是不是真实存在的?

作者在这里设计了实验,在一个数据集上分别用15%,30%,50%,100%的数据来训练,看看效果。

如上图图A,红色的柱是使用了L2-SP的错误率,蓝色的是使用了L2 的错误率。可以看到红色更高,且在数据量越低的时候相比蓝色错误率越严重。既然L2-SP和L2 都是为了解决灾难性遗忘,且L2-SP约束神经网络的参数更接近预训练模型,由此可见负迁移确实存在。

3.2 为什么会有负迁移

作者设计了实验来说明负迁移是如何存在的。

他用同位角来计算矩阵之间的相似度。

也就是说这个角度指的是在两个矩阵中同样重要的两个特征向量之间的角度,同样重要的定义是指在奇异值矩阵中下标一样。当角度越小时,也就是cos值越接近1,说明二者相似性越大,可迁移性也越高。

作者计算了预训练网络(W0)和fine-tune后的网络(W)不同层的参数的同位角。由上图b可以看出,在网络浅层,可迁移性高。网络越深则可迁移性越低。

同时作者计算了网络特征表示fi = F(xi)的奇异值,如上图c、d。实验证明目标域数据量越少,奇异值越高,因此想办法降低奇异值就可以增强网络的泛化能力。

4、算法

其实说来该算法非常简单。

δ是网络提出的特征的奇异值,这里用了-i是因为从后往前数。思想就是把奇异值小(意味着相关性小)的部分给压制住。

因此网络总体的loss就是:

接下来是实验部分,实验部分就不写了,反正是state-of-the-art。

易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!