🎯代码全部放在GitHub
预训练BERT以及相关代码下载地址:
链接: https://pan.baidu.com/s/1zd6wN7elGgp1NyuzYKpvGQ 提取码: tmp5
🍥我们知道，BERT模型的输入由三部分组成：token embedding、segment embedding 以及 position embedding。
- 词向量的后续处理：
  - 先生成 Segment Embeddings 和 Position Embeddings；
  - 再将三者相加，即 Input = Token Embeddings + Segment Embeddings + Position Embeddings。
BERT源码分析
transformers库
Self-Attention与Transformer
🍤模型创建
class BERTClass(torch.nn.Module):
    """BERT encoder with a BiLSTM + MLP classification head producing 14 logits.

    The [CLS] (position 0) vectors of the last two encoder layers are
    concatenated and passed as a length-1 sequence through a bidirectional
    LSTM, then through Linear -> ReLU -> Dropout -> Linear.

    Submodule attribute names (l1, bilstm1, l2, a1, l3, l4) are part of the
    saved state_dict keys, so they are intentionally left unchanged.
    """

    def __init__(self):
        super().__init__()
        # output_hidden_states=True makes BertModel also return the hidden
        # states of every layer; forward() reads the last two of them.
        self.config = BertConfig.from_pretrained(
            '../emb/bert-mini/bert_config.json', output_hidden_states=True
        )
        self.l1 = BertModel.from_pretrained(
            '../emb/bert-mini/pytorch_model.bin', config=self.config
        )
        hidden_size = self.config.hidden_size
        # The LSTM input is two concatenated [CLS] vectors, hence
        # 2 * hidden_size (512 for bert-mini, whose hidden size is 256).
        self.bilstm1 = torch.nn.LSTM(2 * hidden_size, 64, 1, bidirectional=True)
        # A bidirectional LSTM with 64 units per direction emits 128 features.
        self.l2 = torch.nn.Linear(128, 64)
        self.a1 = torch.nn.ReLU()
        self.l3 = torch.nn.Dropout(0.3)
        self.l4 = torch.nn.Linear(64, 14)

    def forward(self, ids, mask, token_type_ids):
        """Return class logits of shape (batch, 14).

        Args:
            ids: token ids; presumably (batch, seq_len) — TODO confirm.
            mask: attention mask passed to BERT as ``attention_mask``.
            token_type_ids: segment ids passed to BERT.
        """
        outputs = self.l1(ids, attention_mask=mask, token_type_ids=token_type_ids)
        # Index positionally rather than tuple-unpacking. Older transformers
        # return a plain tuple; newer ones return a ModelOutput whose
        # iteration yields key names. Both support outputs[i].
        # Index 2 holds the per-layer hidden states (enabled by
        # output_hidden_states=True in the config).
        hidden_states = outputs[2]
        # Take the [CLS] vector (position 0) of the last and second-to-last
        # layers, each (batch, hidden). unsqueeze(0) adds the seq_len=1 axis
        # that the non-batch_first LSTM expects: (1, batch, hidden).
        h12 = hidden_states[-1][:, 0].unsqueeze(0)
        h11 = hidden_states[-2][:, 0].unsqueeze(0)
        concat_hidden = torch.cat((h12, h11), dim=2)  # (1, batch, 2 * hidden)
        x, _ = self.bilstm1(concat_hidden)  # (1, batch, 128)
        x = self.l2(x.squeeze(0))  # drop the length-1 seq axis -> (batch, 64)
        x = self.a1(x)
        x = self.l3(x)
        return self.l4(x)
# Instantiate the classifier and place its parameters on `device`;
# nn.Module.to() moves the module in place and returns the same object.
net = BERTClass().to(device)
🍣训练模型
def train(epoch, train_iter, test_iter, criterion, num_epochs, optimizer, device):
    """Train the module-level ``net`` and checkpoint the best validation model.

    Args:
        epoch: unused; kept only for backward compatibility. The loop below
            counts epochs itself from 0 to ``num_epochs - 1``.
        train_iter: iterable of batches. Each batch is indexed with the keys
            'ids', 'mask', 'token_type_ids' and 'targets'.
        test_iter: validation data, forwarded to ``evaluate_accuracy``.
        criterion: loss called as ``criterion(logits, long_targets)``. It must
            return a scalar, because ``loss.backward()`` is called without a
            gradient argument.
        num_epochs: number of epochs to run.
        optimizer: optimizer over ``net``'s parameters.
        device: device the model and every batch are moved to.

    Side effects:
        - Prints per-epoch progress.
        - Steps a StepLR scheduler once per epoch.
        - Writes the best-so-far weights to ``model/best.pth``, creating the
          directory if needed.
    """
    import os

    print('training on', device)
    net.to(device)
    best_test_acc = 0
    # Learning-rate schedule: halve the LR every 5 epochs.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)
    # Alternative schedule, cosine annealing:
    # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=5, eta_min=2e-06)
    for epoch in range(num_epochs):
        # Set training mode once per epoch. Nothing inside the batch loop
        # changes the mode, but evaluate_accuracy presumably switches to
        # eval mode, so it must be re-enabled each epoch.
        net.train()
        # Keep the running sums on-device to avoid a host sync every batch.
        train_l_sum = torch.tensor([0.0], dtype=torch.float32, device=device)
        train_acc_sum = torch.tensor([0.0], dtype=torch.float32, device=device)
        n, batch_count, start = 0, 0, time.time()
        for data in tqdm(train_iter):
            optimizer.zero_grad()
            ids = data['ids'].to(device, dtype=torch.long)
            mask = data['mask'].to(device, dtype=torch.long)
            token_type_ids = data['token_type_ids'].to(device, dtype=torch.long)
            # Class indices: convert straight to long. The old float->long
            # round-trip added nothing.
            targets = data['targets'].to(device, dtype=torch.long)
            y_hat = net(ids, mask, token_type_ids)
            loss = criterion(y_hat, targets)
            loss.backward()
            optimizer.step()
            with torch.no_grad():
                train_l_sum += loss.detach().float()
                train_acc_sum += (torch.argmax(y_hat, dim=1) == targets).sum().float()
            n += targets.shape[0]
            batch_count += 1
        valid_acc = evaluate_accuracy(test_iter, net, device)
        # train_l_sum adds one scalar loss per batch, so average over batches.
        # Dividing by the sample count n (as before) under-reports the loss by
        # roughly the batch size when the criterion averages within a batch.
        # max(..., 1) guards an empty train_iter.
        train_loss = (train_l_sum / max(batch_count, 1)).item()
        train_acc = (train_acc_sum / max(n, 1)).item()
        print('epoch %d, loss %.4f, train acc %.3f, valid acc %.3f, '
              'time %.1f sec'
              % (epoch + 1, train_loss, train_acc, valid_acc,
                 time.time() - start))
        if valid_acc > best_test_acc:
            print('find best! save at model/best.pth')
            best_test_acc = valid_acc
            # torch.save does not create missing parent directories.
            os.makedirs('model', exist_ok=True)
            torch.save(net.state_dict(), 'model/best.pth')
        scheduler.step()  # advance the LR schedule once per epoch
来源:oschina
链接:https://my.oschina.net/u/4412764/blog/4526987