元学习系列(六):神经图灵机详细分析

女生的网名这么多〃 提交于 2020-01-25 02:08:57

神经图灵机是LSTM、GRU的改进版本,本质上依然包含一个外部记忆结构、可对记忆进行读写操作,主要针对读写操作进行了改进,或者说提出了一种新的读写操作思路。

神经图灵机之所以叫这个名字是因为它通过深度学习模型模拟了图灵机,但是我觉得如果先去介绍图灵机的概念,就会搞得很混乱,所以这里主要从神经图灵机改进了LSTM的哪些方面入手进行讲解,同时,由于模型的结构比较复杂,为了让思路更清晰,这次也会分开几个部分进行讲解。

概述

首先我们来看一下神经图灵机是怎么运作的:

在这里插入图片描述
神经图灵机和LSTM一样,在每个时刻接受输入并返回输出,输入首先会通过controller处理,controller把处理过的输入以及一系列参数传给读写头,读写头会根据这些东西,计算权重,并对记忆矩阵memory进行消除、写入、读取操作,最后读头返回读取的记忆给controller,controller就可以根据这个记忆计算该时刻的输出,然后就可以等待下一时刻的输入了。

记忆权重

个人觉得神经图灵机主要改进了LSTM的门结构,以前,我们通过对上一时刻的输出和当前时刻的输入分别进行线性变换并相加,再用sigmoid函数进行处理,得到一个用于记忆或者遗忘的权重向量,再和长期记忆按位相乘。这种权重的计算机制是通过神经网络基于数据进行学习,当然可行。而神经图灵机则从注意力的角度,用更接近人的思维,提出分别从content-based和location-based的角度进行注意力计算。

事实上,神经图灵机比起lstm取得更好的效果,主要是因为他其中的location-based的注意力计算,这个location-based的注意力计算有什么用呢,举个例子来说,以前的lstm,你给他一杯奶茶,他可能不能一下子回忆起奶茶是什么,只能勉强回忆起这东西有点像茶、有点像牛奶,然后通过神经网络计算出奶茶或许就是0.5的茶加0.5的奶,换句话说,lstm只能通过相似性或相关性进行模糊匹配访问信息,并不能做到精准的检索信息。

而现在神经图灵机则改进了这种回忆能力,使得模型能够快速回忆起指定内容,而不仅仅是相似的内容,所以如果你给神经图灵机一杯奶茶,他可能一下子就回忆起这个奶茶就是三天前喝过的那个饮料,三天前这个时间就是通过location-based的注意力计算得出的。

content-based addressing

接下来就详细看看神经图灵机两种注意力计算。我们要知道,权重其实就是起到一个聚焦的作用,在读写过程中,告诉模型应该注意记忆中哪个位置的信息。神经图灵机的权重计算结合了两种注意力计算,第一种就是常见的通过输入和记忆的相似度计算权重,称为content-based addressing(为什么叫addressing,我觉得这可以理解成注意力机制中的聚焦):

K[u,v]=uvuvK[u,v] = \frac{u*v}{||u||*||v||}

wtc(i)exp(βtK[kt,Mt(i)])jexp(βtK[kt,Mt(j)])w_t^c(i) \leftarrow \frac{exp(\beta_t K[k_t, M_t(i)])}{\sum_j exp(\beta_t K[k_t, M_t(j)])}

比如目前输入为kt,记忆矩阵为Mt,然后我们就计算kt和Mt中每一个记忆片段之间的cosine距离作为相似度,最后再进行归一化。式子看着复杂,但实质上就是用cosine距离作为相似度的衡量指标计算权重。

location-based addressing

content-based addressing可以帮助模型实现基于相似性的回忆能力,但如果希望模型能够回忆起特定的内容,还需要location-based addressing,也就是准确回忆起某一时刻(位置)的记忆的能力。

我们看看location-based addressing的公式:

w~t(i)=j=0N1wtg(j)st(ij)\widetilde w_t(i) = \sum_{j=0} ^{N-1} w_t^g(j) s_t(i-j)

这里的意思其实就是,新的权重,等于整个权重向量各个元素的重新线性组合。上式中的st是controller输出的,但是最主要的问题是,权重的重新线性组合是如何反映出location-based的location呢,这才是关键,这里我说一下我个人的理解。

比如说目前记忆矩阵中有四段记忆,时刻t的权重分别为0.1、0.1、0.1、0.7,这意味着在时刻t模型更关注第四段记忆,如果我现在希望模型在这一刻更关注第二段记忆,对第二段记忆可以进行以下线性变换:

w~t(2)=00.1+00.1+00.1+10.7=0.7\widetilde w_t(2) = 0*0.1 + 0*0.1 + 0*0.1 + 1*0.7 = 0.7

对第一、三、四段记忆可以进行以下线性变换:

w~t(1,3,4)=10.1+00.1+00.1+00.7=0.1\widetilde w_t(1,3,4) = 1*0.1 + 0*0.1 + 0*0.1 + 0*0.7 = 0.1

通过这样的变换,焦点就回到第二段记忆,而变换使用的权重是由controller给出的,所以controller就可以通过数据学习到这种所谓的基于位置的回忆能力。

权重计算完整过程

在这里插入图片描述
现在我们可以来看看完整的权重计算过程了。首先,controller接受输入,输出的kt与记忆矩阵进行content-based addressing。接下来是一个关键的处理,有时候我们需要相似度进行聚焦,有时候则不需要,所以模型会通过controller输出的gt去判断到底需要什么程度content-based addressing:

wtg=gtwtc+(1gt)wt1w_t^g = g_t w_t^c + (1-g_t) w_{t-1}

如果gt接近1,就意味着content-based addressing是有用的,如果gt接近0,那么就相当于跳过了content-based addressing,直接对kt计算location-based addressing。

利用controller输出的st,可以继续进行location-based addressing,但得到的仍然不是最终的输出,因为考虑到一种情况,如果location-based addressing的结果比较平均,就会导致权重的差异性不够突出,或者这样理解,考虑一种最极端的情况,每个位移权重都等于相邻几个位置的均值,就会导致最终的权重差异过小,所以最后我们还需要增强权值之间的差异程度,也就是sharping,主要利用controller输出的gamma_t,对权重进行指数运算后再归一化:

wt(i)=wt(i)γtjwt(i)γtw_t(i) = \frac{w'_t(i)^{\gamma _t}}{\sum_j w'_t(i)^ {\gamma _t} }

以上就是计算权重的全过程。

读写操作

计算了权重之后,就相当于模型知道每次处理输入的时候,应该把注意力放在哪些地方,所以接下来就可以结合这个权重进行读写操作了。

神经图灵机的读写顺序不重要,我们可以先来看看神经图灵机的读操作,在LSTM中,模型是根据当前的输入,读取更新后的长期记忆中有用的信息,而神经图灵机也是差不多的思路,首先在时间t,记忆矩阵Mt是一个N*M的矩阵,N是每个记忆的位置,M是记忆向量(或者说具体的记忆信息),wt就是时间t的关于这N个记忆的权重向量,那么在时间t,读取的记忆就是:

rt=iNwt(i)Mt(i)r_t = \sum_i^N w_t(i) M_t(i)

本质上就是根据我们之前求出来的焦点,从记忆中提取关注的部分。

然后我们看看写操作,主要包括两个过程,先消除后写入:

M~t(i)Mt1(i)[1wt(i)et]\widetilde M_t(i) \leftarrow M_{t-1}(i) [1-w_t(i)e_t]

Mt(i)M~t(i)+wt(i)atM_t(i) \leftarrow \widetilde M_{t}(i) + w_t(i) a_t

我是这样理解这个过程的,首先是消除记忆的过程,写头(write head)接受了controller的输出后会输出一个消除向量et,长度为N范围从0到1(表示N个记忆每个要遗忘的程度),对第i段记忆,如果要遗忘的程度是0.2,它对应的权重为0.4,那么就对这段记忆进行(1-0.2*0.4)的缩放,从而弱化这段记忆在整体的作用,达到一定程度的消除效果。

至于写入记忆的过程,也是通过写头产生的向量at对记忆进行追加,只是这个at则没有限制在01之间,可以理解为根据输入得到的新记忆,根据这一段记忆的权重,在原记忆的基础上叠加。

最后再总结一下,整个神经图灵机结构包括四部分:controller、读写头和记忆矩阵,controller的作用最重要,包括对输入的处理以及输出各种参数辅助权重的计算,读头则主要负责根据controller的输出从记忆矩阵中读取关注的部分记忆,写头则基于controller的输出,产生擦除向量和写入向量,对记忆矩阵进行消除和写入,经过这一轮的读写操作后,controller再基于读头返回的信息分析模型在该时刻的最终输出。

在github写的自然语言处理入门教程,持续更新:NLPBeginner

在github写的机器学习入门教程,持续更新:MachineLearningModels

想浏览更多关于数学、机器学习、深度学习的内容,可浏览本人博客

易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!