mask (transformer)
padding mask 一个batch中不同长度的句子需要添加padding变成统一的长度,因此需要使用padding mask功能对padding进行清除操作。 seq.shape: (batch, seqlen) pad_idx : padding值 return.shape: (batch, 1, seqlen) def get_pad_mask(seq, pad_idx): return (seq != pad_idx).unsqueeze(-2) sequence mask 在decoder时,为了防止当前时刻看到未来时刻的信息,需要将未来时刻的信息进行掩码操作。 以下函数会返回一个对角线以下为True的矩阵。 # seq.shape: (batch, seqlen) # return.shape: (1, seqlen, seqlen) def get_subsequent_mask(seq): ''' For masking out the subsequent info. ''' sz_b, len_s = seq.size() subsequent_mask = (1 - torch.triu( torch.ones((1, len_s, len_s), device=seq.device), diagonal=1)).bool() return