本文同步发布于公众号阿黎投喂社:一只沉迷吃喝渴望变强的代码仔,关注阿黎看阿黎是怎么变强(秃)的吧~
论文链接:https://arxiv.org/pdf/1901.02860.pdf
导语
NLP领域中为了捕捉单词之间的依赖关系,比较常用的结构是RNN和Transformer,但是他们都受限于固定长度的上下文依赖,当信息量超过特定长度,模型的能力非常有限。为了解决这个问题,文中收到Vanilla Transformer的启发,引入segment机制,结合RNN和Transformer的特点,将循环机制引入Transformer,提高了Transformer捕获信息的能力,并且在多个公开数据集上超过了SOTA的效果。那么我们来看一下作者是如何改进的吧~
背景
在NLP领域中,Language Model需要捕捉到单词之间的依赖关系,比较常见的结构是RNN,但是由于梯度爆炸和梯度消失的原因RNN比较难训练。LSTM在RNN的基础上加了遗忘门,使得隐藏状态由联乘变成了累加状态,解决了梯度爆炸和梯度消失的问题。但是由于遗忘门的限制LSTM对上下文的记忆能力有限,部分研究表示LSTM对上下文的记忆仅限于200个词语。
2017年Google提出Transformer架构,Transformer架构摒弃了RNN,引入Position Embedding和self-attention机制对序列进行全局处理,使得输入的句子可以根据其所在位置进行相互影响,效果超过了被大家广泛使用的LSTM和RNN。具体有关Transformer的解释The Annotated Transformer非常详细的介绍了Transformer的原理和源码,大家可以参考这篇文章。然而Transformer在预训练阶段设置了固定序列长度max_len的上下文,finetune阶段捕捉的信息无法超过max_len的上下文依赖。 而且Transformer在确定max_len时是根据连续token的长度来确定的,不考虑句子的语义边界,因此模型在预测句子刚开始的几个token的时候会因为上下文的信息较少而出现效果较差的情况。
为了解决这些问题,文中提出了Transformer-XL,下面我们来看看Transformer-XL是怎么解决这些问题的吧~
Vanilla Transformer
为什么要讲这个模型呢?因为Transformer-XL是在这个模型的基础上改进的。Vanilla Transformer是2018年由Al-Rfou等人提出,目的是为了解决Transformer可以捕捉的信息无法超过max_len的问题。
Vanilla Transformer模型训练的时候按照前n-1个字符预测第n个字符,因此训练的时候会将文本划分为多个segments,每个segments单独处理,如图a所示。预测时候采用滑动窗口方式,使用前n-1个字符预测第n个字符,然后左移窗口,如图b所示。
这样做的优点是可以可以保证位置i可以获取上下文的信息,使得可以捕获的长度增加了N。但是因为segment的划分没有考虑到文本的语义,导致分割出来的segment可能是不完整的 。而且滑动窗口的方式也会导致预测的时候效率比较低。
Transformer-XL
1. 引入循环机制
和Vanilla Transformer的思路一样,Transformer-XL同样引入了segment机制,只不过和Vanilla Transformer不同的是,Transformer-XL引入了段之间循环的机制。训练过程如下图a所示,当前段隐藏层的输入不仅包含当前层的隐藏层(灰色部分)还包括上一段隐藏层的输出(绿色部分)。这两段输出会被拼接用于计算当前段的QKV矩阵。
例如两段相邻的segment,那么第t+1段中第n个位置的隐藏状态的计算方式如下:
是拼接操作,SG表示停止计算梯度,是我们缓存的上一个segment的隐藏状态,第t+1个segment的第n-1个位置的隐藏状态是由第t个segment的第n-1个位置的隐藏状态和该位置的隐藏状态拼接而成的。而第n个的位置的QKV矩阵分别是由该位置的隐藏状态,拼接后的隐藏状态进行线性变换得到的,这表示了第t+1个segment处计算Token之间的相关性也考虑到了上一个segment的信息, 将Transformer的信息感知域扩大了一倍,并且当前segment可以接收到上一个segment的信息, 这也是我们题目中的循环机制的体现。
因为文中在计算隐藏状态的时候缓存了上一个segment的隐藏状态,在evaluate阶段,我们不需要计算每一个位置的隐藏状态,因此Transformer-XL在evaluate阶段的速度比Vanilla Transformer快1800倍。实际上,理论上我们可以根据GPU的空间,缓存尽可能多的隐藏状态,这样可以进一步扩宽Transformer的信息感知域。
2. 相对位置编码
在传统的Transformer中位置编码是输入句子中的绝对位置编码,然后将word embedding和position embedding相加。但是引入了segment概念之后,随着滑动窗口的滑动,不同的segment在同一个位置的Token的position embedding将会完全一样,为了避免这样的情况发生,文中采用了相对位置编码R其中表示相对位置差i的相对位置编码。
在传统Transformer中query和key的attention score如下所示:
上述计算方式是将将进行了分解,其中分别是i和j的position embedding。因为文中引入了相对位置编码,那么attention score的计算方式将要发生相应的改变。
- 首先是将绝对位置编码替换了相对位置变量,
- 由于没有绝对位置的概念,在传统的Transformer中表示位置encoding的线性变化,文中公式c中的被替换为,d中的被替换为,u和v都是通过训练可以学习到的。
那么相对位置的attention score计算方式将如下所示:
在新的计算方式中,公式a表示上下文之间的文本相关性信息,公式b表示文本的文本偏置,即相对于当前位置的位置偏差,公式c表示全局的内容偏差,用于衡量key的重要程度,公式d表示全局的位置偏差,用来衡量query和key的位置偏差的重要程度。
3. 整体计算公式
有了以上的改进,Transformer-XL模型整体的计算公式如下所示,这里是只有single-head sttention的情况,其中表示初始化的隐藏状态就是word embedding。
需要注意的是在计算attention score的时候需要计算,这个是O(length)2的复杂度,但是实际上i-j的范围在0~length之间,因此可以预先计算好,然后在计算的时候直接用就可以了,这样就可以把时间复杂度降到O(length)。具体如何计算呢?
假设M是memory的长度,L是当前segment的长度,那么i-j的范围就是0~M+L-1,那么Query矩阵的计算方式如下所示:
那么对于attention score中公式b的结果将会是一个L*(M+L)的矩阵,那么B的结果如下所示,文中的公式稍微有点错误,我做了一丢丢修正。
我们进一步定义矩阵B为:
那么对于attention_score中公式d的计算也可以推导如下所示:
实验效果
文中在WikiText-2003,enwiki8,text8,One Billion Word和Penn Treebank上分别进行了试验,试验效果如下图所示,在各数据集上都超过了state-of-art的效果。
结论和思考
Transformer-XL是XLNet的基础,相比于传统的Transformer,Transformer-XL引入了循环机制和相对位置编码,同时缓存了隐藏状态,使得Inference的速度很快,并且在多个数据集上也超过了state-of-art的效果。本文从结构和理论上对transformer-xl进行了讲解,但是因为文中需要缓存上一个segment的隐藏状态,这样提高了inference的速度,但是想必也要增加很多内存,对于已经属于比较大型的transformer来说,感觉很多贫穷的玩家应该是玩不起的了。但是文中从上下文的长度和信息丰富程度出发,给transformer中引入循环机制,还是很有启发意义的。
来源:oschina
链接:https://my.oschina.net/u/4269975/blog/4320336