News Text Classification Journey: BERT

Posted by 女生的网名这么多〃 on 2020-10-01 14:59:32


🎯 All the code is on GitHub.
Download link for the pre-trained BERT and related code:
链接: https://pan.baidu.com/s/1zd6wN7elGgp1NyuzYKpvGQ 提取码: tmp5

🍥 We know that the input to the BERT model has three parts: token embeddings, segment embeddings, and position embeddings.

  • Post-processing of the word embeddings:
    • first generate the Segment Embeddings and Position Embeddings,
    • then sum all three: Input = Token Embeddings + Segment Embeddings + Position Embeddings (a sketch follows this list)
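To make the addition concrete, here is a minimal sketch of the three embedding tables being summed (the vocab size, batch size, and sequence length are illustrative; 256 is the bert-mini hidden size used later in this post):

import torch

hidden = 256                                  # bert-mini hidden size
tok_emb = torch.nn.Embedding(30522, hidden)   # token embeddings (vocab size is illustrative)
seg_emb = torch.nn.Embedding(2, hidden)       # segment embeddings (sentence A / sentence B)
pos_emb = torch.nn.Embedding(512, hidden)     # position embeddings (max position 512)

ids = torch.randint(0, 30522, (8, 200))                   # [batch, seq_len]
segments = torch.zeros(8, 200, dtype=torch.long)          # single-sentence input: all segment 0
positions = torch.arange(200).unsqueeze(0).expand(8, -1)  # positions 0..seq_len-1 for every sample

# Input = Token Embeddings + Segment Embeddings + Position Embeddings
inputs = tok_emb(ids) + seg_emb(segments) + pos_emb(positions)
print(inputs.shape)  # torch.Size([8, 200, 256])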

Related reading:
BERT source code analysis
The transformers library
Self-Attention and Transformer
🍤 Building the model

import torch
from transformers import BertConfig, BertModel

class BERTClass(torch.nn.Module):
    def __init__(self):
        super(BERTClass, self).__init__()
        # bert-mini: hidden size 256, loaded with all hidden states exposed
        self.config = BertConfig.from_pretrained('../emb/bert-mini/bert_config.json', output_hidden_states=True)
        self.l1 = BertModel.from_pretrained('../emb/bert-mini/pytorch_model.bin', config=self.config)
        # BiLSTM over the concatenated [CLS] vectors of the last two layers (2 x 256 = 512)
        self.bilstm1 = torch.nn.LSTM(512, 64, 1, bidirectional=True)
        self.l2 = torch.nn.Linear(128, 64)  # BiLSTM output: 2 directions x 64 hidden
        self.a1 = torch.nn.ReLU()
        self.l3 = torch.nn.Dropout(0.3)
        self.l4 = torch.nn.Linear(64, 14)   # 14 news categories

    def forward(self, ids, mask, token_type_ids):
        # return_dict=False keeps the tuple output on transformers v4+
        sequence_output, pooler_output, hidden_states = self.l1(
            ids, attention_mask=mask, token_type_ids=token_type_ids, return_dict=False)
        # sequence_output: [bs, 200, 256], pooler_output: [bs, 256]
        bs = len(sequence_output)
        h12 = hidden_states[-1][:, 0].view(1, bs, 256)  # [CLS] vector of the last layer
        h11 = hidden_states[-2][:, 0].view(1, bs, 256)  # [CLS] vector of the second-to-last layer
        concat_hidden = torch.cat((h12, h11), 2)        # [1, bs, 512]
        x, _ = self.bilstm1(concat_hidden)              # [1, bs, 128]
        x = self.l2(x.view(bs, 128))
        x = self.a1(x)
        x = self.l3(x)
        output = self.l4(x)
        return output

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = BERTClass()
net.to(device)
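A quick sanity check with dummy inputs confirms the output shape (the batch size, sequence length, and vocab size below are arbitrary, not part of the original pipeline):

dummy_ids = torch.randint(0, 30522, (4, 200), device=device)        # fake token ids
dummy_mask = torch.ones(4, 200, dtype=torch.long, device=device)    # attend to every position
dummy_types = torch.zeros(4, 200, dtype=torch.long, device=device)  # single-segment input
with torch.no_grad():
    logits = net(dummy_ids, dummy_mask, dummy_types)
print(logits.shape)  # torch.Size([4, 14]): one logit per news category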

🍣 Training the model
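The training loop below reads each batch as a dict with keys ids, mask, token_type_ids, and targets. The dataset producing that format is not shown in this excerpt; here is a minimal sketch, assuming a BERT tokenizer and a max length of 200 (the NewsDataset name and its arguments are hypothetical):

import torch
from torch.utils.data import Dataset

class NewsDataset(Dataset):
    # hypothetical helper: yields the batch dict format consumed by train()
    def __init__(self, texts, labels, tokenizer, max_len=200):
        self.texts, self.labels = texts, labels
        self.tokenizer, self.max_len = tokenizer, max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        enc = self.tokenizer.encode_plus(
            self.texts[idx], max_length=self.max_len, padding='max_length',
            truncation=True, return_token_type_ids=True)
        return {
            'ids': torch.tensor(enc['input_ids'], dtype=torch.long),
            'mask': torch.tensor(enc['attention_mask'], dtype=torch.long),
            'token_type_ids': torch.tensor(enc['token_type_ids'], dtype=torch.long),
            'targets': torch.tensor(self.labels[idx], dtype=torch.long),
        }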

import time
from tqdm import tqdm

def train(net, train_iter, test_iter, criterion, num_epochs, optimizer, device):
    print('training on', device)
    net.to(device)
    best_test_acc = 0
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)  # step-decay learning-rate schedule
#     scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=5, eta_min=2e-06)  # cosine annealing
    for epoch in range(num_epochs):
        train_l_sum = torch.tensor([0.0], dtype=torch.float32, device=device)
        train_acc_sum = torch.tensor([0.0], dtype=torch.float32, device=device)
        n, batch_count, start = 0, 0, time.time()
        for data in tqdm(train_iter):
            net.train()
            optimizer.zero_grad()
            ids = data['ids'].to(device, dtype=torch.long)
            mask = data['mask'].to(device, dtype=torch.long)
            token_type_ids = data['token_type_ids'].to(device, dtype=torch.long)
            targets = data['targets'].to(device, dtype=torch.long)
            y_hat = net(ids, mask, token_type_ids)
            loss = criterion(y_hat, targets)
            loss.backward()
            optimizer.step()

            with torch.no_grad():
                train_l_sum += loss.float()
                train_acc_sum += (torch.sum(torch.argmax(y_hat, dim=1) == targets)).float()
                n += targets.shape[0]
                batch_count += 1
        valid_acc = evaluate_accuracy(test_iter, net, device)
        train_acc = train_acc_sum / n
        print('epoch %d, loss %.4f, train acc %.3f, valid acc %.3f, '
              'time %.1f sec'
              % (epoch + 1, train_l_sum / batch_count, train_acc, valid_acc,
                 time.time() - start))
        if valid_acc > best_test_acc:
            print('find best! save at model/best.pth')
            best_test_acc = valid_acc
            torch.save(net.state_dict(), 'model/best.pth')
        scheduler.step()  # update the learning rate
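The loop calls evaluate_accuracy, which is defined elsewhere in the original notebook. Below is a minimal version consistent with the batch format above, plus a sketch of the training driver (the Adam learning rate and epoch count are assumptions, and train_iter/test_iter are DataLoaders built from the dataset sketch earlier):

def evaluate_accuracy(data_iter, net, device):
    # validation accuracy; assumes the same batch dict format as train()
    net.eval()
    acc_sum, n = 0.0, 0
    with torch.no_grad():
        for data in data_iter:
            ids = data['ids'].to(device, dtype=torch.long)
            mask = data['mask'].to(device, dtype=torch.long)
            token_type_ids = data['token_type_ids'].to(device, dtype=torch.long)
            targets = data['targets'].to(device, dtype=torch.long)
            y_hat = net(ids, mask, token_type_ids)
            acc_sum += (torch.argmax(y_hat, dim=1) == targets).float().sum().item()
            n += targets.shape[0]
    return acc_sum / n

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=2e-5)  # assumed learning rate
train(net, train_iter, test_iter, criterion, num_epochs=10, optimizer=optimizer, device=device)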