mask (transformer)

☆樱花仙子☆ 提交于 2020-09-29 16:22:30

padding mask

一个batch中不同长度的句子需要添加padding变成统一的长度,因此需要使用padding mask功能对padding进行清除操作。

seq.shape: (batch, seqlen)

pad_idx : padding值

return.shape: (batch, 1, seqlen) 

def get_pad_mask(seq, pad_idx):
    return (seq != pad_idx).unsqueeze(-2)

sequence mask

在decoder时,为了防止当前时刻看到未来时刻的信息,需要将未来时刻的信息进行掩码操作。

以下函数会返回一个对角线以下为True的矩阵。

# seq.shape: (batch, seqlen)
# return.shape: (1, seqlen, seqlen)

def get_subsequent_mask(seq):
    ''' For masking out the subsequent info. '''
    sz_b, len_s = seq.size()
    subsequent_mask = (1 - torch.triu(
        torch.ones((1, len_s, len_s), device=seq.device), diagonal=1)).bool()
    return subsequent_mask

inputs: 

data = torch.randn(3, 5)

return :

tensor([[[ True, False, False, False, False],
         [ True,  True, False, False, False],
         [ True,  True,  True, False, False],
         [ True,  True,  True,  True, False],
         [ True,  True,  True,  True,  True]]])

标签
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!