基于度量的元学习和基于优化的元学习

☆樱花仙子☆ 提交于 2020-08-11 13:20:44

在介绍两种主流的元学习方法之前,先简要概括下元学习与监督学习的区别。

监督学习:

  1. 只在一个任务上做训练;
  2. 只有一个训练集和一个测试集;
  3. 学习的是样本之间的泛化能力;
  4. 要求训练数据和测试数据满足独立同分布;
  5. 监督学习的训练和测试过程分别为train和test;

小样本条件下监督学习的过程如图所示。

监督学习过程

元学习:

  1. 元学习是面向多个任务做联合训练;
  2. 每个任务都有训练集和测试集;
  3. 学习到的是任务之间的泛化能力;
  4. 要求新任务与训练任务再分布上尽可能一致;
  5. 元学习的训练和测试过程分别叫做Meta-train和Meta-test;

小样本条件下元学习过程如下图所示。更多元学习的细节内容可以参见:小样本学习方法专栏

元学习过程

介绍两种类型的元学习方法

在元学习过程中元学习器Meta-learner是学习过程的重要一环,承接着Meta-train和Meta-test两个阶段。针对元学习器的搭建,近年来主流的方法模型如下图所示。

主流元学习模型

从基于度量和基于优化的角度可以将上述模型分为两大类,基于度量的元学习模型:MatchingNet、ProtoNet、RelationNet等,基于优化的元学习模型:MAML、Reptile、LEO等。下面对上述模型一一的做个简要的介绍:

基于度量的三个元学习模型

基于度量的三个元学习模型的示意图如下图所示:

  • MatchingNet:Support set经过特征提取后,在embedding空间中利用Cosine来度量,通过对测试样本进行计算匹配程度来实现分类;
  • ProtoNet:利用聚类思想,将Support set投影到一个度量空间,在欧式距离度量的基础上获取向量均值,对测试样本计算到每个原型的距离,实现分类;
  • RelationNet:关系网络提出的relation module结构替换了MatchingNet和ProtoNet中的Cosine和欧式距离度量,使其成为一种学习的非线性分类器用于判断关系,实现分类。
基于度量的元学习模型

基于优化的三个元学习模型

  • 基于优化的三个元学习模型示意图如图所示(LEO示意图省略)
  • MAML:Model Agnostic Meta-Learning通过优化二阶梯度的方式,学习一个通用的初始化模型,使得模型面对新任务时,进行少次迭代便可收敛;
  • Reptile:命名可能单纯为了和MAML对应(Reptile:爬行动物;mammal:哺乳动物),由示意图可知,与 MAML不同的是,Reptile是一个一阶的基于梯度的元学习算法,每次base model的参数更新是在每个task的一阶梯度上做的,不过进行了多次;
  • LEO:Latent Embedding Optimization方法针对模型的超高维参数空间,小样本情况下几步梯度下降导致的过拟合的问题,将模型高维参数空间学习到一个低维嵌入,在低维空间实施梯度下降来实现对问题进行改善;
  • Fine-tune:除了上述三个典型的模型之外,基于优化的元学习还包括预训练方法(fine-tune),与MAML等学习一个通用的初始模型的思想很接近,基于预训练的是在一批已有的任务上或者一个大型数据集上学习到的模型,面对新任务时,将其模型参数做为初始点,在新任务进行微调。而上述三中元学习方法是学习一个通用的模型,使得这个模型在面对旧任务和新任务时都可以在几步梯度下降后达到相应任务的较优解。此外,由于预训练模型在面对新任务时更新了参数,让原先在旧任务上训练好的参数被新的信息覆盖,容易产生灾难性遗忘问题,此部分内容参考《百面深度学习》。
基于优化的元学习模型

两种元学习方法存在的问题

非参数方法是指在Meta-test阶段,每个新任务的训练数据没有使用训练带参函数的方法,即在新任务上没有学习过程。参数方法则需要在新任务上继续调参,比如使用梯度下降进行权重参数更新。请允许我在这里不严谨的将基于度量的方法称为非参数方法,将基于优化的方法称为参数化方法。

基于优化的元学习存在的问题

基于优化的元学习方法即参数化方法在使用梯度下降法更新权重时,由于优化器选择(如:SGD、Adam等)和学习率lr设定的限制,通常需要更新多步达到较优的点,使得当模型在面对新任务时,学习过程缓慢;此外在训练小样本情况下,更新权重的过程中容易过拟合。

基于度量的元学习存在的问题

基于度量的元学习方法即非参数化方法虽然没有上述苦恼,但也绝对不是完美的,matrix learning的核心在于损失函数的设置,然而A Metric Learning Reality Check这篇文章却给这个领域泼了一盆冷水,该领域十三年来无进展,这也说明了度量学习领域正处于发展的瓶颈期,某些创新提出的loss在人脸数据集上涨了很多,但在其他任务上效果可能会变得更差,这也说明度量学习方法的鲁棒性不高,对数据集相对来说比较挑剔。

以上内容仅代表个人观点,欢迎交流。

易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!