写这个博客的原因在于:大部分解释Transformer的文章都只注重讲解Encoder部分,在Encoder中又侧重讲解self-attention原理。为了读者更好地理解整个Transformer的训练过程,我决定结合代码写一篇在理解了Encoder部分怎么理解Decoder模块的博文。
参考文章:https://jalammar.github.io/illustrated-transformer/
参考代码:https://github.com/Kyubyong/transformer
pre: Encoder
根据以上参考文章及代码理解Encoder的self-attention原理非常容易,这里不再赘述。需要说明的是以下维度:
德文输入X.shape:[batch_size, max_len]
英文标注Y.shape:[batch_size, max_len]
Encoder输出维度
[batch_size, max_len, hidden_units]
也就是代码里的[N, T_q, C]
Decoder
在训练过程中,Transformer同所有seq2seq模型一样,会用到source data以及不断生成的target data的部分数据(理解就是RNN的因果关系,训练过程中不像BiRNN一样使用未来数据,因此需要Masking)。
需要说明的是代码中的key masking和query masking是对于文本padding部分的掩盖,目的是使Encoder不过多的关注于padding这种无效信息。
causality
代码中的causality部分是是对未来信息的掩盖。这部分代码位于modules.py中。
if causality:
diag_vals = tf.ones_like(outputs[0, :, :]) # (T_q, T_k)
tril = tf.linalg.LinearOperatorLowerTriangular(diag_vals).to_dense() # (T_q, T_k)
masks = tf.tile(tf.expand_dims(tril, 0), [tf.shape(outputs)[0], 1, 1]) # (h*N, T_q, T_k)
paddings = tf.ones_like(masks)*(-2**32+1)
outputs = tf.where(tf.equal(masks, 0), paddings, outputs) # (h*N, T_q, T_k)
下面我通过对比Decoder中的self-attention和Encoder-Decoder attention两个模块说明Decoder在代码中是如何具体同时attention源数据及生成数据的。这对理解Decoder如何使用数据很关键。
同Encoder一样,使用多个block叠加:
with tf.variable_scope("num_blocks_{}".format(i)):
block中包含使用源数据的self-attention【目标数据自身关注,因此需要掩盖未来数据来模拟逐词生成、类似于单向RNN】,
和使用生成数据的vanilla attention【目标数据关注于源数据,也就是en关注于de,由于源数据是存在的,因此没有属于未来的数据,不需要进行掩盖未来数据的操作,类似于BiRNN】。
self-attention(自身关注,需掩盖未来数据)
self.dec = multihead_attention(queries=self.dec,
keys=self.dec,
num_units=hp.hidden_units,
num_heads=hp.num_heads,
dropout_rate=hp.dropout_rate,
is_training=is_training,
causality=True,
scope="self_attention")
vanilla attention(关注源数据,causality=False)
self.dec = multihead_attention(queries=self.dec,
keys=self.enc,
num_units=hp.hidden_units,
num_heads=hp.num_heads,
dropout_rate=hp.dropout_rate,
is_training=is_training,
causality=False,
scope="vanilla_attention")
在这个对比中,主要的输入参数不同是:
- keys
- causality
keys输入用来计算关注的权重,在代码中key=value,同时用来计算权重以及attention之后的结果。
self-attention:关注self.dec,也就是自身关注,设置causality=True掩盖训练数据集中的未来数据。
vanilla attention:关注self.enc,也就是关注数据集中的源数据,设置causality=False来取消掩盖未来数据(因为训练集的X是已知的)。
causality的不同,具体代码如本文的第一段代码所示,在此复制过来进行分析:
if causality:
diag_vals = tf.ones_like(outputs[0, :, :]) # (T_q, T_k)
tril = tf.linalg.LinearOperatorLowerTriangular(diag_vals).to_dense() # (T_q, T_k)
masks = tf.tile(tf.expand_dims(tril, 0), [tf.shape(outputs)[0], 1, 1]) # (h*N, T_q, T_k)
paddings = tf.ones_like(masks)*(-2**32+1)
outputs = tf.where(tf.equal(masks, 0), paddings, outputs) # (h*N, T_q, T_k)
这里主要是使用了
tf.linalg.LinearOperatorLowerTriangular().to_dense()
这个函数生成mask,该函数的作用是将:
1 | 1 | 1 | 1 |
---|---|---|---|
1 | 1 | 1 | 1 |
1 | 1 | 1 | 1 |
1 | 1 | 1 | 1 |
变成:
1 | 0 | 0 | 0 |
---|---|---|---|
1 | 1 | 0 | 0 |
1 | 1 | 1 | 0 |
1 | 1 | 1 | 1 |
而后通过:
paddings = tf.ones_like(masks)*(-2**32+1)
outputs = tf.where(tf.equal(masks, 0), paddings, outputs) # (h*N, T_q, T_k)
将未来数据的权重设置为无穷小,以达到在训练过程中不关注未来数据的作用。也就是生成第一个词时关注第0个token,生成第二个词时关注第0及第1个token,如上表格所示。
而在vanilla attention中设置causality=False关注源数据的所有token。
来源:CSDN
作者:二里庄
链接:https://blog.csdn.net/weixin_37735081/article/details/104264898