torch

mask (transformer)

☆樱花仙子☆ 提交于 2020-09-29 16:22:30
padding mask 一个batch中不同长度的句子需要添加padding变成统一的长度,因此需要使用padding mask功能对padding进行清除操作。 seq.shape: (batch, seqlen) pad_idx : padding值 return.shape: (batch, 1, seqlen) def get_pad_mask(seq, pad_idx): return (seq != pad_idx).unsqueeze(-2) sequence mask 在decoder时,为了防止当前时刻看到未来时刻的信息,需要将未来时刻的信息进行掩码操作。 以下函数会返回一个对角线以下为True的矩阵。 # seq.shape: (batch, seqlen) # return.shape: (1, seqlen, seqlen) def get_subsequent_mask(seq): ''' For masking out the subsequent info. ''' sz_b, len_s = seq.size() subsequent_mask = (1 - torch.triu( torch.ones((1, len_s, len_s), device=seq.device), diagonal=1)).bool() return

PyTorch:数据读取1

 ̄綄美尐妖づ 提交于 2020-09-23 12:43:14
-柚子皮- 什么是Datasets? 在输入流水线中,准备数据的代码是这么写的 data = datasets.CIFAR10("./data/", transform=transform, train=True, download=True) datasets.CIFAR10 就是一个 Datasets 子类, data 是这个类的一个实例。 为什么要定义Datasets? PyTorch 提供了一个工具函数 torch.utils.data.DataLoader 。通过这个类,我们可以让数据变成mini-batch,且在准备 mini-batch 的时候可以多线程并行处理,这样可以加快准备数据的速度。 Datasets 就是构建这个类的实例的参数之一。 DataLoader的使用参考[]。 -柚子皮- 自定义Datasets 框架 import torch.utils.data as data class CustomDataset(data.Dataset): # 继承data.Dataset """Custom data.Dataset compatible with data.DataLoader.""" def __init__(self, filename, data_info, oth_params): """Reads source and target