深度学习算法原理——LSTM

断了今生、忘了曾经 提交于 2020-09-27 04:04:31

1. 概述

循环神经网络RNN一文中提及到了循环神经网络RNN存在长距离依赖的问题,长短期记忆(Long Short-Term Memory,LSTM)网络便是为了解决RNN中存在的梯度爆炸的问题而提出。在LSTM网络中,主要依靠引入“门”机制来控制信息的传播。

2. 算法原理

2.1. LSTM的网络结构

LSTM的网络结构如下所示(图片来自参考文献):

在这里插入图片描述
与循环神经网络RNN相比,LSTM的网络结构要复杂的多。

在LSTM网络中,通过引入三个门来控制信息的传递,这三个门分别为遗忘门(forget gate),输入门(input gate)和输出门(output gate)。门机制是LSTM中重要的概念,那么什么是“门”以及门机制在LSTM中是如何解决长距离依赖的问题的。

2.2. 门机制

现实中的“门”通常解释为出入口,在LSTM网络的门也是一种出入口,但是是控制信息的出入口。门的状态通常有三种状态,分别为全开(信息通过概率为1),全闭(信息通过概率为0)以及半开(信息通过概率介于0和1之间)。在这里,我们发现对于全开,全闭以及半开三种状态下的信息通过可以通过概率来表示,在神经网络中,sigmoid函数也是一个介于0和1之间的表示,可以应用到LSTM中门的计算中。

2.3. LSTM的计算过程

如下是LSTM的网络结构的具体形态,如下所示(图片来自邱锡鹏老师的课件):

在这里插入图片描述

其中, c t − 1 c_{t-1} ct1表示的是 t − 1 t-1 t1时刻的cell state(注:关于cell state,查了多个版本的中文翻译,有翻译为“细胞状态”,有翻译成“单元状态”,邱老师使用的是内部状态,没有一个明确的中文翻译,故在此使用英文), h t − 1 h_{t-1} ht1表示的是 t − 1 t-1 t1时刻的hidden state(注:与前面的cell state对应), x t x_t xt表示的是 t t t时刻的输入, f t f_t ft表示的是遗忘门, i t i_t it表示的是输入门, c ~ t \tilde{c}_t c~t表示的是候选值(candidate values), o t o_t ot表示的是输出门。

从图中的数据流向得到的计算流程如下所示:

  1. 利用 t − 1 t-1 t1时刻的hidden state h t − 1 h_{t-1} ht1计算遗忘门 f t f_t ft的结果, f t f_t ft的计算公式如下所示

f t = σ ( W f x t + U f h t − 1 + b f ) f_t=\sigma \left ( W_fx_t+U_fh_{t-1}+b_f \right ) ft=σ(Wfxt+Ufht1+bf)

  1. 利用 t − 1 t-1 t1时刻的hidden state h t − 1 h_{t-1} ht1计算输入门 i t i_t it的结果, i t i_t it的计算公式如下所示

i t = σ ( W i x t + U i h t − 1 + b i ) i_t=\sigma \left ( W_ix_t+U_ih_{t-1}+b_i \right ) it=σ(Wixt+Uiht1+bi)

  1. 利用 t − 1 t-1 t1时刻的hidden state h t − 1 h_{t-1} ht1计算候选值 c ~ t \tilde{c}_t c~t的结果, c ~ t \tilde{c}_t c~t的计算公式如下所示

c ~ t = t a n h ( W c x t + U c h t − 1 + b c ) \tilde{c}_t=tanh \left ( W_cx_t+U_ch_{t-1}+b_c \right ) c~t=tanh(Wcxt+Ucht1+bc)

  1. 利用 t − 1 t-1 t1时刻的hidden state h t − 1 h_{t-1} ht1计算输出门 o t o_t ot的结果, o t o_t ot的计算公式如下所示

o t = σ ( W o x t + U o h t − 1 + b o ) o_t=\sigma \left ( W_ox_t+U_oh_{t-1}+b_o \right ) ot=σ(Woxt+Uoht1+bo)

  1. 根据 t t t时刻的cell state c t c_t ct,这里会使用到 t − 1 t-1 t1时刻的cell state c t − 1 c_{t-1} ct1,遗忘门 f t f_t ft,输入门 i t i_t it和候选值 c ~ t \tilde{c}_t c~t c t c_t ct的计算公式如下所示

c t = f t ⊙ c t − 1 + i t ⊙ c ~ t c_t=f_t\odot c_{t-1}+i_t\odot \tilde{c}_t ct=ftct1+itc~t

上述的公式是由前面的1,2,3部分的公式组成,也是LSTM网络中的关键的部分,对该公式,我们从如下的几个部分来理解:

  • f t ⊙ c t − 1 f_t\odot c_{t-1} ftct1,使用遗忘门 f t f_t ft t − 1 t-1 t1时刻下的cell state c t − 1 c_{t-1} ct1遗忘;
  • i t ⊙ c ~ t i_t\odot \tilde{c}_t itc~t,首先是 c ~ t \tilde{c}_t c~t表示的是通过 t t t时刻的输入和 t − 1 t-1 t1时刻的hidden state h t − 1 h_{t-1} ht1需要增加的信息,与输入门 i t i_t it结合起来就表示整体需要增加的信息;
  • 两部分结合表示的是 t t t时刻下的cell state下需要从 t − 1 t-1 t1时刻下的cell state中保留的部分信息以及 t t t时刻下新增信息的总和。
  1. 根据输出门 o t o_t ot和cell state c t c_t ct计算外部状态 h t h_t ht h t h_t ht的计算公式如下所示

h t = o t ⊙ t a n h ( c t ) h_t=o_t\odot tanh\left ( c_t \right ) ht=ottanh(ct)

参考文献

易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!