Transformer翻译模型Decoder详解(Masking)

戏子无情 提交于 2020-02-11 19:39:03

写这个博客的原因在于:大部分解释Transformer的文章都只注重讲解Encoder部分,在Encoder中又侧重讲解self-attention原理。为了读者更好地理解整个Transformer的训练过程,我决定结合代码写一篇在理解了Encoder部分怎么理解Decoder模块的博文。

参考文章:https://jalammar.github.io/illustrated-transformer/
参考代码:https://github.com/Kyubyong/transformer

pre: Encoder

根据以上参考文章及代码理解Encoder的self-attention原理非常容易,这里不再赘述。需要说明的是以下维度:
德文输入X.shape:[batch_size, max_len]
英文标注Y.shape:[batch_size, max_len]
Encoder输出维度
[batch_size, max_len, hidden_units]
也就是代码里的[N, T_q, C]

Decoder

在训练过程中,Transformer同所有seq2seq模型一样,会用到source data以及不断生成的target data的部分数据(理解就是RNN的因果关系,训练过程中不像BiRNN一样使用未来数据,因此需要Masking)。

需要说明的是代码中的key maskingquery masking是对于文本padding部分的掩盖,目的是使Encoder不过多的关注于padding这种无效信息。

causality
代码中的causality部分是是对未来信息的掩盖。这部分代码位于modules.py中。

if causality:
      diag_vals = tf.ones_like(outputs[0, :, :]) # (T_q, T_k)
      tril = tf.linalg.LinearOperatorLowerTriangular(diag_vals).to_dense() # (T_q, T_k)
      masks = tf.tile(tf.expand_dims(tril, 0), [tf.shape(outputs)[0], 1, 1]) # (h*N, T_q, T_k)
   
      paddings = tf.ones_like(masks)*(-2**32+1)
      outputs = tf.where(tf.equal(masks, 0), paddings, outputs) # (h*N, T_q, T_k)

下面我通过对比Decoder中的self-attentionEncoder-Decoder attention两个模块说明Decoder在代码中是如何具体同时attention源数据及生成数据的。这对理解Decoder如何使用数据很关键。
同Encoder一样,使用多个block叠加:

with tf.variable_scope("num_blocks_{}".format(i)):

block中包含使用源数据的self-attention【目标数据自身关注,因此需要掩盖未来数据来模拟逐词生成、类似于单向RNN】,
和使用生成数据的vanilla attention【目标数据关注于源数据,也就是en关注于de,由于源数据是存在的,因此没有属于未来的数据,不需要进行掩盖未来数据的操作,类似于BiRNN】。

self-attention(自身关注,需掩盖未来数据)

self.dec = multihead_attention(queries=self.dec, 
                                keys=self.dec, 
                                num_units=hp.hidden_units, 
                                num_heads=hp.num_heads, 
                                dropout_rate=hp.dropout_rate,
                                is_training=is_training,
                                causality=True, 
                                scope="self_attention")

vanilla attention(关注源数据,causality=False)

self.dec = multihead_attention(queries=self.dec, 
                                keys=self.enc, 
                                num_units=hp.hidden_units, 
                                num_heads=hp.num_heads,
                                dropout_rate=hp.dropout_rate,
                                is_training=is_training, 
                                causality=False,
                                scope="vanilla_attention")

在这个对比中,主要的输入参数不同是:

  • keys
  • causality

keys输入用来计算关注的权重,在代码中key=value,同时用来计算权重以及attention之后的结果。
self-attention:关注self.dec,也就是自身关注,设置causality=True掩盖训练数据集中的未来数据。
vanilla attention:关注self.enc,也就是关注数据集中的源数据,设置causality=False来取消掩盖未来数据(因为训练集的X是已知的)。

causality的不同,具体代码如本文的第一段代码所示,在此复制过来进行分析:

if causality:
      diag_vals = tf.ones_like(outputs[0, :, :]) # (T_q, T_k)
      tril = tf.linalg.LinearOperatorLowerTriangular(diag_vals).to_dense() # (T_q, T_k)
      masks = tf.tile(tf.expand_dims(tril, 0), [tf.shape(outputs)[0], 1, 1]) # (h*N, T_q, T_k)
   
      paddings = tf.ones_like(masks)*(-2**32+1)
      outputs = tf.where(tf.equal(masks, 0), paddings, outputs) # (h*N, T_q, T_k)

这里主要是使用了
tf.linalg.LinearOperatorLowerTriangular().to_dense()
这个函数生成mask,该函数的作用是将:

1 1 1 1
1 1 1 1
1 1 1 1
1 1 1 1

变成:

1 0 0 0
1 1 0 0
1 1 1 0
1 1 1 1

而后通过:

paddings = tf.ones_like(masks)*(-2**32+1)
outputs = tf.where(tf.equal(masks, 0), paddings, outputs) # (h*N, T_q, T_k)

将未来数据的权重设置为无穷小,以达到在训练过程中不关注未来数据的作用。也就是生成第一个词时关注第0个token,生成第二个词时关注第0及第1个token,如上表格所示。

而在vanilla attention中设置causality=False关注源数据的所有token。

易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!