These notes draw in part on the AI科技评论 article on Neural Snowball, a neural snowball mechanism for few-shot relation learning.
1. Contributions of the Paper
Relation extraction with a growing set of relations
Relation Extraction is an important research topic in natural language processing: it studies how to extract structured relational facts from text. For example, from the sentence "Bill Gates is the founder of Microsoft" we can extract the relational triple (Bill Gates, founder, Microsoft) and use it in downstream tasks such as knowledge graph completion.
There is a large body of related work on relation extraction, but most of it targets predefined relation types: given a manually defined relation set, only the relations inside that set are considered during extraction. In practice, however, we face open-ended relation growth; as new domains and new knowledge appear, the number of relation types keeps increasing. We therefore need relation extraction models that can cope with this growth.
Knowledge metric: a Relational Siamese Network (RSN) is trained on large-scale data of existing relations to learn a distance metric, which is then transferred to new relation types. Using a few instances of the new relation as starting seeds, the model mines useful information from large-scale unlabelled data; the more useful information it mines, the better the resulting relation extraction model.
2. Related Work
2.1 Three different kinds of data
Current research covers several relation extraction settings, which differ in the relation types they target and the data they use:
- Supervised Relation Extraction: targets a predefined set of relations and relies on large-scale supervised data.
- Semi-Supervised Relation Extraction: also targets a predefined set of relations, but aims to match fully supervised performance with relatively little supervised data, aided by a large amount of unsupervised data.
- Few-Shot Relation Extraction: targets new (unseen) relation types; a model is pre-trained on large-scale data of existing relations and then quickly transferred to a handful of instances of the new relation, achieving few-shot learning.
- Bootstrapping Relation Extraction: also addresses the open relation setting; given only a few seed instances of a new relation, it iteratively mines more information from large-scale data to obtain an increasingly capable relation extraction model.
From the analysis above, these methods involve three kinds of data:
- large-scale supervised data for existing relation types (large-scale existing-relation data)
- a small number of labelled instances of the new relation (few-shot instances of the new relation)
- large-scale unlabelled data, i.e. unseen/unlabelled sentences drawn from a corpus (large-scale unlabelled corpora)
We would like to exploit all three kinds of data, which is why the authors propose a new relation extraction method: Neural Snowball.
2.2 Neural Snowball
The Neural Snowball model over these three kinds of data is shown below.
As the figure shows, Neural Snowball trains a distance metric on large-scale data of existing relations and transfers it to new relation types. Using a few instances of the new relation as starting seeds, it mines useful information from the large-scale unlabelled data; the more useful information it mines, the better the relation extraction model it obtains.
3. Methodology
Neural Snowball consists of the following components:
- Seed set: for a new relation type, a few instances are given as the seed set. After each iteration, the seed set is expanded with selected unlabelled instances (drawn from the large-scale unsupervised data mentioned above), and the enlarged seed set takes part in the next iteration.
- Candidate set C1: the instances in C1 are mined via distant supervision. As the figure below shows, if the new relation is founder, then for the correctly labelled seed instance "Bill founder Microsoft", distant supervision only retrieves sentences that contain the entity pair Bill and Microsoft, so it easily pulls in sentences that do not actually express the founder relation (e.g. "Bill mentioned Microsoft"). Such sentences are unwanted, which is why C1 is only a candidate set.
- RSN: a filter that judges whether the information mined by distant supervision is actually useful.
- Relation Classifier: the binary classifier that the model ultimately wants to obtain. The instances that survive the filtering of C1, together with the labelled instances, are used to train this relation classifier with supervision. The classifier is then applied to new candidate instances drawn from the unlabelled data, producing candidate set C2.
- C2: because the data obtained from C1 is not perfectly clean (in other words, the binary Relation Classifier is still not very strong), C2 is again only a candidate set, so one more filter is needed. After this second round of filtering we can be reasonably confident in the sentences selected by the binary classifier (e.g. "Steven Jobs ... Apple"), and these sentences are added to the seed set for the next iteration.
The overall Neural Snowball procedure is as follows:
- Input: a new relation type, together with a few labelled instances (the starting seeds).
- Goal: train a binary classifier for this relation type. A binary classifier is used because it is more extensible: as relation types grow, multiple binary classifiers can simply be used side by side.
- Training: starting from the seeds, iteratively mine useful information from the unlabelled data.
Each iteration consists of two phases:
(1) use distant supervision to collect candidate sentences;
(2) use the newly trained relation classifier to collect candidate sentences.
Distant supervision means that if existing data tells us that entities h and t hold relation r, we retrieve every sentence containing h and t and assume that it really expresses r. After new training data has been obtained in the first phase, Neural Snowball trains a new relation classifier; this classifier then mines, from the unlabelled data, the instances it believes to belong to relation r, and the newly mined data helps train an even better classifier. A simplified sketch of one iteration is given below.
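To make the loop concrete, here is a minimal sketch of one snowball iteration. It is not the authors' implementation: it only mirrors the structure of _forward_train in Section 5, and the arguments (distant, rsn, train_classifier) and the two thresholds are hypothetical stand-ins for the repo's distant-data loader, siamese network and fine-tuning routine.

def snowball_iteration(seeds, distant, rsn, train_classifier,
                       rsn_threshold=0.5, cl_threshold=0.9):
    # seeds: labelled positive instances of the new relation (each with an 'entpair' field)
    # distant: exposes get_same_entpair_ins / get_random_candidate, as in the repo's data loader
    # rsn: callable rsn(a, b) -> similarity score in [0, 1]
    # train_classifier: fits a binary classifier on the seeds, returns classifier(x) -> probability

    # Phase 1: distant supervision (same entity pair) + RSN filtering
    for ins in list(seeds):
        candidates_c1 = distant.get_same_entpair_ins(ins['entpair']) or []
        for cand in candidates_c1:
            if max(rsn(cand, s) for s in seeds) > rsn_threshold:
                seeds.append(cand)
    classifier = train_classifier(seeds)          # re-train the binary relation classifier

    # Phase 2: score random unlabelled candidates with the new classifier, filter again with the RSN
    candidates_c2 = distant.get_random_candidate()   # (arguments simplified relative to the repo)
    for cand in candidates_c2:
        if classifier(cand) > cl_threshold and max(rsn(cand, s) for s in seeds) > rsn_threshold:
            seeds.append(cand)
    classifier = train_classifier(seeds)          # fine-tune again on the enlarged seed set
    return seeds, classifier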
4. Neural Modules
Neural Snowball has two key components: the RSN and the Relation Classifier.
4.1 RSN
- The input is two instances, e.g. the two sentences obtained via distant supervision in the figure above: "Bill founder Microsoft" and "Bill mentioned Microsoft" (clearly these two instances do not express the same relation).
- The output is 0 or 1, i.e. whether the two instances express the same relation.
Structure of RSN
The RSN consists of two encoders and a distance function; its structure is shown in the figure below. It takes two sentences as input and outputs whether they express the same relation. The RSN is pre-trained on large-scale data of existing relations and then used inside Neural Snowball: every candidate selected from the unlabelled data is compared against the starting seeds with the RSN, and only instances with high confidence are kept.
The encoders in the RSN take instances as input and output their representation vectors. The two encoders share parameters (weight sharing); with a CNN encoder, for example, the convolution kernels are shared, which keeps the number of parameters small. (Weight sharing: see the notes on weight sharing in CNNs.)
The distance function in the RSN computes the similarity between the two representations:
It can be viewed as a weighted L2 distance; w and b are learned parameters, and a higher output score means the two instances are more likely to express the same relation.
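The formula itself was an image in the original post. Based on the Siamese code in Section 5 (torch.pow(z1 - z2, 2) followed by a linear layer and a sigmoid), it can be reconstructed as follows, where f is the shared encoder and the square is taken element-wise; this is a reconstruction from the code, not a copy of the paper's equation:

g(x_1, x_2) = \sigma\left( \mathbf{w}^{\top} \left( f(x_1) - f(x_2) \right)^{2} + b \right)

A score close to 1 means the RSN believes the two instances express the same relation.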
4.2 Relation Classifier
The RC contains a neural encoder that maps an input instance to a real-valued vector, plus a linear layer (with a sigmoid) that outputs the probability that the instance expresses the target relation.
If several relations have to be handled, N binary RC classifiers can be used side by side. Because new relation types (and their raw data) keep arriving, the authors do not use an N-way classifier.
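Matching the fine-tuning code in Section 5 (torch.matmul(x, self.new_W) + self.new_bias followed by a sigmoid), the relation classifier can be written as the following reconstructed expression, where f is the sentence encoder and w, b are the linear-layer parameters:

p(r \mid x) = \sigma\left( \mathbf{w}^{\top} f(x) + b \right)

An instance x is predicted as positive for relation r when this probability exceeds a chosen threshold (0.5 by default in the evaluation code).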
4.3 Pre-training and Fine-tuning
Pre-training
Pre-training trains the RSN and RC network models; in the later snowball iterations these parameters are no longer changed.
Both are trained with supervision on the existing labelled dataset.
For the RSN, instance pairs with the same or with different relation types (both kinds are needed) are sampled from the existing-relation data, and the RSN is trained with a cross-entropy loss; the batch layout used for this pairing is illustrated below.
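The Siamese.forward method in Section 5 expects the pre-training batch to be arranged as num_class relation classes with num_size instances each; instances from the same class form positive pairs and instances from different classes form negative pairs. The snippet below is a small illustration of that layout, not repo code; the random tensor stands in for encoder outputs:

import torch

num_class, num_size, hidden = 4, 6, 230
x = torch.randn(num_class, num_size, hidden)      # encoded instances, grouped by relation

# positive pairs: the two halves of each class are paired with each other
x1 = x[:, :num_size // 2].reshape(-1, hidden)
x2 = x[:, num_size // 2:].reshape(-1, hidden)
# negative pairs: instances from the first half of the classes vs. the second half
y1 = x[:num_class // 2].reshape(-1, hidden)
y2 = x[num_class // 2:].reshape(-1, hidden)

label = torch.zeros(x1.size(0) + y1.size(0))
label[:x1.size(0)] = 1                            # 1 = same relation, 0 = different relation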
For the RC, a minibatch of positive instances is sampled from the seed set and a minibatch of negative instances is sampled from a negative-data loader, and the linear-layer parameters w and b of the RC (presumably a fully connected layer) are optimized. The loss expression is:
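The loss formula was an image in the original post; the version below is a reconstruction from the _train_finetune code in Section 5, which computes a binary cross-entropy in which the negative examples are re-weighted by args.finetune_weight (written \mu here):

\mathcal{L}(\mathbf{w}, b) = -\frac{1}{|B|} \left[ \sum_{x \in B^{+}} \log p(x) + \mu \sum_{x \in B^{-}} \log\left( 1 - p(x) \right) \right], \qquad p(x) = \sigma\left( \mathbf{w}^{\top} f(x) + b \right)

where B^+ is the sampled positive minibatch and B^- the sampled negative minibatch.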
Fine-tuning
Because pre-training already trains the RSN and the RC networks and those parameters stay fixed in later iterations, what actually gets fine-tuned is only the relation classifier's new linear layer (new_W and new_bias in the code); the same fine-tuning routine is also used for the baseline model, so we will not expand on it separately here.
5. A Look at the Code
Training the RSN: in the Siamese class below, encode is the encoding function, forward computes the training scores and the loss, forward_infer scores candidates with the thresholded method (method A), and forward_infer_sort is the sort-based variant (method B) used in phase 1 of the Snowball class. See the comments for details.
import numpy as np
import torch
from torch import nn

class Siamese(nn.Module):
def __init__(self, sentence_encoder, hidden_size=230, drop_rate=0.5, pre_rep=None, euc=True):
nn.Module.__init__(self)
self.sentence_encoder = sentence_encoder # Should be different from main sentence encoder !!!
self.hidden_size = hidden_size
# self.fc1 = nn.Linear(hidden_size * 2, hidden_size * 2)
# self.fc2 = nn.Linear(hidden_size * 2, 1)
self.fc = nn.Linear(hidden_size, 1)
self.cost = nn.BCELoss(reduction="none")
self.drop = nn.Dropout(drop_rate)
self._accuracy = 0.0
self.pre_rep = pre_rep
self.euc = euc
def forward(self, data, num_size, num_class, threshold=0.5):
# view: reshape the encoded batch to (num_class, num_size, hidden_size)
x = self.sentence_encoder(data).contiguous().view(num_class, num_size, -1)
# split each class's instances into two halves: (x1, x2) pair instances from the same class
# (positive pairs), while (y1, y2) pair instances from the first half of the classes with
# instances from the second half (negative pairs); // is integer division to avoid fractions
x1 = x[:, :num_size//2].contiguous().view(-1, self.hidden_size)
x2 = x[:, num_size//2:].contiguous().view(-1, self.hidden_size)
y1 = x[:num_class//2,:].contiguous().view(-1, self.hidden_size)
y2 = x[num_class//2:,:].contiguous().view(-1, self.hidden_size)
# y1 = x[0].contiguous().unsqueeze(0).expand(x.size(0) - 1, -1, -1).contiguous().view(-1, self.hidden_size)
# y2 = x[1:].contiguous().view(-1, self.hidden_size)
label = torch.zeros((x1.size(0) + y1.size(0))).long().cuda()
label[:x1.size(0)] = 1 # pairs coming from (x1, x2), i.e. same relation, are labelled 1
z1 = torch.cat([x1, y1], 0)
z2 = torch.cat([x2, y2], 0)
if self.euc:
dis = torch.pow(z1 - z2, 2)
dis = self.drop(dis)
score = torch.sigmoid(self.fc(dis).squeeze())
else:
z = z1 * z2
z = self.drop(z)
z = self.fc(z).squeeze()
# z = torch.cat([z1, z2], -1)
# z = F.relu(self.fc1(z))
# z = self.fc2(z).squeeze()
score = torch.sigmoid(z)
self._loss = self.cost(score, label.float()).mean()
pred = torch.zeros((score.size(0))).long().cuda()
pred[score > threshold] = 1
self._accuracy = torch.mean((pred == label).type(torch.FloatTensor))
pred = pred.cpu().detach().numpy()
label = label.cpu().detach().numpy()
self._prec = float(np.logical_and(pred == 1, label == 1).sum()) / float((pred == 1).sum() + 1)
self._recall = float(np.logical_and(pred == 1, label == 1).sum()) / float((label == 1).sum() + 1)
def encode(self, dataset, batch_size=0):
if self.pre_rep is not None:
return self.pre_rep[dataset['id'].view(-1)]
if batch_size == 0:
x = self.sentence_encoder(dataset)
else:
total_length = dataset['word'].size(0)
max_iter = total_length // batch_size
if total_length % batch_size != 0:
max_iter += 1
x = []
for it in range(max_iter):
scope = list(range(batch_size * it, min(batch_size * (it + 1), total_length)))
with torch.no_grad():
_ = {'word': dataset['word'][scope], 'mask': dataset['mask'][scope]}
if 'pos1' in dataset:
_['pos1'] = dataset['pos1'][scope]
_['pos2'] = dataset['pos2'][scope]
_x = self.sentence_encoder(_)
x.append(_x.detach())
x = torch.cat(x, 0) #concatenate
return x
# method A: thresholded inference
def forward_infer(self, x, y, threshold=0.5, batch_size=0):
x = self.encode(x, batch_size=batch_size)
support_size = x.size(0)
y = self.encode(y, batch_size=batch_size)
# unsqueeze(N) inserts a dimension of size 1 at position N so x and y broadcast pairwise
x = x.unsqueeze(1) # (support_size, 1, hidden_size)
y = y.unsqueeze(0) # (1, query_size, hidden_size)
if self.euc:
dis = torch.pow(x - y, 2) # L2 distance
score = torch.sigmoid(self.fc(dis).squeeze(-1)).mean(0) # nn.Linear holds the learned weight and bias; average the scores over the support set
else:
z = x * y
z = self.fc(z).squeeze(-1)
score = torch.sigmoid(z).mean(0)
pred = torch.zeros((score.size(0))).long().cuda()
pred[score > threshold] = 1
pred = pred.view(support_size, -1).sum(0)
pred[pred < 1] = 0
pred[pred > 0] = 1
return pred
# method B: sort-based inference, returns (score, index) pairs sorted by score
def forward_infer_sort(self, x, y, batch_size=0):
x = self.encode(x, batch_size=batch_size)
support_size = x.size(0)
y = self.encode(y, batch_size=batch_size)
x = x.unsqueeze(1)
y = y.unsqueeze(0)
if self.euc:
dis = torch.pow(x - y, 2)
score = torch.sigmoid(self.fc(dis).squeeze(-1)).mean(0)
else:
z = x * y
z = self.fc(z).squeeze(-1)
score = torch.sigmoid(z).mean(0)
pred = []
for i in range(score.size(0)):
pred.append((score[i], i))
pred.sort(key=lambda x: x[0], reverse=True)
return pred
The Snowball class: much of the code below handles data loading; the key part starts at _forward_train, which implements the Section 3 flow of input, phase 1 and phase 2. The original authors also commented the relation classifier's negative-sampling step in detail.
import copy
import random
import sys

import numpy as np
import sklearn.metrics
import torch
from torch import nn, optim
from torch.autograd import Variable

import nrekit.framework

class Snowball(nrekit.framework.Model):
def __init__(self, sentence_encoder, base_class, siamese_model, hidden_size=230, drop_rate=0.5, weight_table=None, pre_rep=None, neg_loader=None, args=None):
nrekit.framework.Model.__init__(self, sentence_encoder)
self.hidden_size = hidden_size
self.base_class = base_class
self.fc = nn.Linear(hidden_size, base_class)
self.drop = nn.Dropout(drop_rate)
self.siamese_model = siamese_model
# self.cost = nn.BCEWithLogitsLoss()
self.cost = nn.BCELoss(reduction="none")
# self.cost = nn.CrossEntropyLoss()
self.weight_table = weight_table
self.args = args
self.pre_rep = pre_rep
self.neg_loader = neg_loader
# def __loss__(self, logits, label):
# onehot_label = torch.zeros(logits.size()).cuda()
# onehot_label.scatter_(1, label.view(-1, 1), 1)
# return self.cost(logits, onehot_label)
# def __loss__(self, logits, label):
# return self.cost(logits, label)
def forward_base(self, data):
batch_size = data['word'].size(0)
x = self.sentence_encoder(data) # (batch_size, hidden_size)
x = self.drop(x)
x = self.fc(x) # (batch_size, base_class)
x = torch.sigmoid(x)
if self.weight_table is None:
weight = 1.0
else:
weight = self.weight_table[data['label']].unsqueeze(1).expand(-1, self.base_class).contiguous().view(-1)
label = torch.zeros((batch_size, self.base_class)).cuda()
label.scatter_(1, data['label'].view(-1, 1), 1) # (batch_size, base_class)
loss_array = self.__loss__(x, label)
self._loss = ((label.view(-1) + 1.0 / self.base_class) * weight * loss_array).mean() * self.base_class
# self._loss = self.__loss__(x, data['label'])
_, pred = x.max(-1)
self._accuracy = self.__accuracy__(pred, data['label'])
self._pred = pred
def forward_baseline(self, support_pos, query, threshold=0.5):
'''
baseline model
support_pos: positive support set
support_neg: negative support set
query: query set
threshold: ins whose prob > threshold are predicted as positive
'''
# train
self._train_finetune_init()
# support_rep = self.encode(support, self.args.infer_batch_size)
support_pos_rep = self.encode(support_pos, self.args.infer_batch_size)
# self._train_finetune(support_rep, support['label'])
self._train_finetune(support_pos_rep)
# test
query_prob = self._infer(query, batch_size=self.args.infer_batch_size).cpu().detach().numpy()
label = query['label'].cpu().detach().numpy()
self._baseline_accuracy = float(np.logical_or(np.logical_and(query_prob > threshold, label == 1), np.logical_and(query_prob < threshold, label == 0)).sum()) / float(query_prob.shape[0])
if (query_prob > threshold).sum() == 0:
self._baseline_prec = 0
else:
self._baseline_prec = float(np.logical_and(query_prob > threshold, label == 1).sum()) / float((query_prob > threshold).sum())
self._baseline_recall = float(np.logical_and(query_prob > threshold, label == 1).sum()) / float((label == 1).sum())
if self._baseline_prec + self._baseline_recall == 0:
self._baseline_f1 = 0
else:
self._baseline_f1 = float(2.0 * self._baseline_prec * self._baseline_recall) / float(self._baseline_prec + self._baseline_recall)
self._baseline_auc = sklearn.metrics.roc_auc_score(label, query_prob)
if self.args.print_debug:
print('')
sys.stdout.write('[BASELINE EVAL] acc: {0:2.2f}%, prec: {1:2.2f}%, rec: {2:2.2f}%, f1: {3:1.3f}, auc: {4:1.3f}'.format( \
self._baseline_accuracy * 100, self._baseline_prec * 100, self._baseline_recall * 100, self._baseline_f1, self._baseline_auc))
print('')
def __dist__(self, x, y, dim):
return (torch.pow(x - y, 2)).sum(dim)
def __batch_dist__(self, S, Q):
return self.__dist__(S.unsqueeze(1), Q.unsqueeze(2), 3)
def forward_few_shot_baseline(self, support, query, label, B, N, K, Q):
support_rep = self.encode(support, self.args.infer_batch_size)
query_rep = self.encode(query, self.args.infer_batch_size)
support_rep = support_rep.view(B, N, K, -1)
query_rep = query_rep.view(B, N * Q, -1)
NQ = N * Q
# Prototypical Networks
proto = torch.mean(support_rep, 2) # Calculate prototype for each class
logits = -self.__batch_dist__(proto, query_rep)
_, pred = torch.max(logits.view(-1, N), 1)
self._accuracy = self.__accuracy__(pred.view(-1), label.view(-1))
return logits, pred
# def forward_few_shot(self, support, query, label, B, N, K, Q):
# for b in range(B):
# for n in range(N):
# _forward_train(self, support_pos, None, query, distant, threshold=0.5):
#
# '''
# support_rep = self.encode(support, self.args.infer_batch_size)
# query_rep = self.encode(query, self.args.infer_batch_size)
# support_rep.view(B, N, K, -1)
# query_rep.view(B, N * Q, -1)
# '''
#
# proto = []
# for b in range(B):
# for N in range(N)
#
# NQ = N * Q
#
# # Prototypical Networks
# proto = torch.mean(support_rep, 2) # Calculate prototype for each class
# logits = -self.__batch_dist__(proto, query)
# _, pred = torch.max(logits.view(-1, N), 1)
#
# self._accuracy = self.__accuracy__(pred.view(-1), label.view(-1))
#
# return logits, pred
def _train_finetune_init(self):
# init variables and optimizer
self.new_W = Variable(self.fc.weight.mean(0) / 1e3, requires_grad=True)
self.new_bias = Variable(torch.zeros((1)), requires_grad=True)
self.optimizer = optim.Adam([self.new_W, self.new_bias], self.args.finetune_lr, weight_decay=self.args.finetune_wd)
self.new_W = self.new_W.cuda()
self.new_bias = self.new_bias.cuda()
# fine-tuning (training) of the relation classifier
def _train_finetune(self, data_repre, learning_rate=None, weight_decay=1e-5):
'''
train finetune classifier with given data
data_repre: sentence representation (encoder's output)
label: label
'''
self.train()
optimizer = self.optimizer
if learning_rate is not None:
optimizer = optim.Adam([self.new_W, self.new_bias], learning_rate, weight_decay=weight_decay)
# hyperparameters
max_epoch = self.args.finetune_epoch
batch_size = self.args.finetune_batch_size
# dropout
data_repre = self.drop(data_repre)
# train
if self.args.print_debug:
print('')
for epoch in range(max_epoch):
max_iter = data_repre.size(0) // batch_size
if data_repre.size(0) % batch_size != 0:
max_iter += 1
order = list(range(data_repre.size(0)))
random.shuffle(order)
for i in range(max_iter):
x = data_repre[order[i * batch_size : min((i + 1) * batch_size, data_repre.size(0))]]
# batch_label = label[order[i * batch_size : min((i + 1) * batch_size, data_repre.size(0))]]
# neg sampling
# ---------------------
batch_label = torch.ones((x.size(0))).long().cuda()
neg_size = int(x.size(0) * 1)
neg = self.neg_loader.next_batch(neg_size)
neg = self.encode(neg, self.args.infer_batch_size)
x = torch.cat([x, neg], 0)
batch_label = torch.cat([batch_label, torch.zeros((neg_size)).long().cuda()], 0)
# ---------------------
# Relation Classifier
x = torch.matmul(x, self.new_W) + self.new_bias # (batch_size, 1)
x = torch.sigmoid(x)
# iter_loss = self.__loss__(x, batch_label.float()).mean()
weight = torch.ones(batch_label.size(0)).float().cuda()
weight[batch_label == 0] = self.args.finetune_weight #1 / float(max_epoch)
iter_loss = (self.__loss__(x, batch_label.float()) * weight).mean()
optimizer.zero_grad()
iter_loss.backward(retain_graph=True)
optimizer.step()
if self.args.print_debug:
sys.stdout.write('[snowball finetune] epoch {0:4} iter {1:4} | loss: {2:2.6f}'.format(epoch, i, iter_loss) + '\r')
sys.stdout.flush()
self.eval()
def _add_ins_to_data(self, dataset_dst, dataset_src, ins_id, label=None):
'''
add one instance from dataset_src to dataset_dst (list)
dataset_dst: destination dataset
dataset_src: source dataset
ins_id: id of the instance
'''
dataset_dst['word'].append(dataset_src['word'][ins_id])
if 'pos1' in dataset_src:
dataset_dst['pos1'].append(dataset_src['pos1'][ins_id])
dataset_dst['pos2'].append(dataset_src['pos2'][ins_id])
dataset_dst['mask'].append(dataset_src['mask'][ins_id])
if 'id' in dataset_dst and 'id' in dataset_src:
dataset_dst['id'].append(dataset_src['id'][ins_id])
if 'entpair' in dataset_dst and 'entpair' in dataset_src:
dataset_dst['entpair'].append(dataset_src['entpair'][ins_id])
if 'label' in dataset_dst and label is not None:
dataset_dst['label'].append(label)
def _add_ins_to_vdata(self, dataset_dst, dataset_src, ins_id, label=None):
'''
add one instance from dataset_src to dataset_dst (variable)
dataset_dst: destination dataset
dataset_src: source dataset
ins_id: id of the instance
'''
dataset_dst['word'] = torch.cat([dataset_dst['word'], dataset_src['word'][ins_id].unsqueeze(0)], 0)
if 'pos1' in dataset_src:
dataset_dst['pos1'] = torch.cat([dataset_dst['pos1'], dataset_src['pos1'][ins_id].unsqueeze(0)], 0)
dataset_dst['pos2'] = torch.cat([dataset_dst['pos2'], dataset_src['pos2'][ins_id].unsqueeze(0)], 0)
dataset_dst['mask'] = torch.cat([dataset_dst['mask'], dataset_src['mask'][ins_id].unsqueeze(0)], 0)
if 'id' in dataset_dst and 'id' in dataset_src:
dataset_dst['id'] = torch.cat([dataset_dst['id'], dataset_src['id'][ins_id].unsqueeze(0)], 0)
if 'entpair' in dataset_dst and 'entpair' in dataset_src:
dataset_dst['entpair'].append(dataset_src['entpair'][ins_id])
if 'label' in dataset_dst and label is not None:
dataset_dst['label'] = torch.cat([dataset_dst['label'], torch.ones((1)).long().cuda()], 0)
def _dataset_stack_and_cuda(self, dataset):
'''
stack the dataset to torch.Tensor and use cuda mode
dataset: target dataset
'''
if (len(dataset['word']) == 0):
return
dataset['word'] = torch.stack(dataset['word'], 0).cuda()
if 'pos1' in dataset:
dataset['pos1'] = torch.stack(dataset['pos1'], 0).cuda()
dataset['pos2'] = torch.stack(dataset['pos2'], 0).cuda()
dataset['mask'] = torch.stack(dataset['mask'], 0).cuda()
dataset['id'] = torch.stack(dataset['id'], 0).cuda()
def encode(self, dataset, batch_size=0):
if self.pre_rep is not None:
return self.pre_rep[dataset['id'].view(-1)]
if batch_size == 0:
x = self.sentence_encoder(dataset)
else:
total_length = dataset['word'].size(0)
max_iter = total_length // batch_size
if total_length % batch_size != 0:
max_iter += 1
x = []
for it in range(max_iter):
scope = list(range(batch_size * it, min(batch_size * (it + 1), total_length)))
with torch.no_grad():
_ = {'word': dataset['word'][scope], 'mask': dataset['mask'][scope]}
if 'pos1' in dataset:
_['pos1'] = dataset['pos1'][scope]
_['pos2'] = dataset['pos2'][scope]
_x = self.sentence_encoder(_)
x.append(_x.detach())
x = torch.cat(x, 0)
return x
def _infer(self, dataset, batch_size=0):
'''
get prob output of the finetune network with the input dataset
dataset: input dataset
return: prob output of the finetune network
'''
x = self.encode(dataset, batch_size=batch_size)
x = torch.matmul(x, self.new_W) + self.new_bias # (batch_size, 1)
x = torch.sigmoid(x)
return x.view(-1)
def _forward_train(self, support_pos, query, distant, threshold=0.5):
'''
snowball process (train)
support_pos: support set (positive, raw data)
support_neg: support set (negative, raw data)
query: query set
distant: distant data loader
threshold: ins with prob > threshold will be classified as positive
threshold_for_phase1: distant ins with prob > th_for_phase1 will be added to extended support set at phase1
threshold_for_phase2: distant ins with prob > th_for_phase2 will be added to extended support set at phase2
'''
# hyperparameters
snowball_max_iter = self.args.snowball_max_iter
sys.stdout.flush()
candidate_num_class = 20
candidate_num_ins_per_class = 100
sort_num1 = self.args.phase1_add_num
sort_num2 = self.args.phase2_add_num
sort_threshold1 = self.args.phase1_siamese_th
sort_threshold2 = self.args.phase2_siamese_th
sort_ori_threshold = self.args.phase2_cl_th
# get neg representations with sentence encoder
# support_neg_rep = self.encode(support_neg, batch_size=self.args.infer_batch_size)
# init
self._train_finetune_init()
# support_rep = self.encode(support, self.args.infer_batch_size)
# encode the positive raw data into representations
support_pos_rep = self.encode(support_pos, self.args.infer_batch_size)
# self._train_finetune(support_rep, support['label'])
self._train_finetune(support_pos_rep)
self._metric = []
# copy
original_support_pos = copy.deepcopy(support_pos)
# snowball
exist_id = {}
if self.args.print_debug:
print('\n-------------------------------------------------------')
for snowball_iter in range(snowball_max_iter):
if self.args.print_debug:
print('###### snowball iter ' + str(snowball_iter))
# phase 1: expand positive support set from distant dataset (with same entity pairs)
## get all entity pairs and their instances ("ins") in the positive support set
old_support_pos_label = support_pos['label'] + 0
entpair_support = {}
entpair_distant = {}
for i in range(len(support_pos['id'])): # only positive support
entpair = support_pos['entpair'][i] # entity pair
exist_id[support_pos['id'][i]] = 1
if entpair not in entpair_support:
if 'pos1' in support_pos:
entpair_support[entpair] = {'word': [], 'pos1': [], 'pos2': [], 'mask': [], 'id': []}
else:
entpair_support[entpair] = {'word': [], 'mask': [], 'id': []}
self._add_ins_to_data(entpair_support[entpair], support_pos, i)
## pick all ins with the same entpairs in distant data and choose with siamese network
self._phase1_add_num = 0 # total number of snowball instances
self._phase1_total = 0
for entpair in entpair_support:
raw = distant.get_same_entpair_ins(entpair) # ins with the same entpair
if raw is None:
continue
if 'pos1' in support_pos: # store per-entpair data as a dict
entpair_distant[entpair] = {'word': [], 'pos1': [], 'pos2': [], 'mask': [], 'id': [], 'entpair': []}
else:
entpair_distant[entpair] = {'word': [], 'mask': [], 'id': [], 'entpair': []}
for i in range(raw['word'].size(0)):
if raw['id'][i] not in exist_id: # don't pick sentences already in the support set
self._add_ins_to_data(entpair_distant[entpair], raw, i)
self._dataset_stack_and_cuda(entpair_support[entpair])
self._dataset_stack_and_cuda(entpair_distant[entpair])
if len(entpair_support[entpair]['word']) == 0 or len(entpair_distant[entpair]['word']) == 0:
continue
# compare instances in entpair_support and entpair_distant with the siamese network (RSN) to decide which C1 candidates to keep
pick_or_not = self.siamese_model.forward_infer_sort(entpair_support[entpair], entpair_distant[entpair], batch_size=self.args.infer_batch_size)
# pick_or_not = self.siamese_model.forward_infer_sort(original_support_pos, entpair_distant[entpair], threshold=threshold_for_phase1)
# pick_or_not = self._infer(entpair_distant[entpair]) > threshold
# -- method B: use sort --
for i in range(min(len(pick_or_not), sort_num1)):
if pick_or_not[i][0] > sort_threshold1:
iid = pick_or_not[i][1]
self._add_ins_to_vdata(support_pos, entpair_distant[entpair], iid, label=1)
exist_id[entpair_distant[entpair]['id'][iid]] = 1
self._phase1_add_num += 1
self._phase1_total += entpair_distant[entpair]['word'].size(0)
'''
if 'pos1' in support_pos:
candidate = {'word': [], 'pos1': [], 'pos2': [], 'mask': [], 'id': [], 'entpair': []}
else:
candidate = {'word': [], 'mask': [], 'id': [], 'entpair': []}
self._phase1_add_num = 0 # total number of snowball instances
self._phase1_total = 0
for entpair in entpair_support:
raw = distant.get_same_entpair_ins(entpair) # ins with the same entpair
if raw is None:
continue
for i in range(raw['word'].size(0)):
if raw['id'][i] not in exist_id: # don't pick sentences already in the support set
self._add_ins_to_data(candidate, raw, i)
if len(candidate['word']) > 0:
self._dataset_stack_and_cuda(candidate)
pick_or_not = self.siamese_model.forward_infer_sort(support_pos, candidate, batch_size=self.args.infer_batch_size)
for i in range(min(len(pick_or_not), sort_num1)):
if pick_or_not[i][0] > sort_threshold1:
iid = pick_or_not[i][1]
self._add_ins_to_vdata(support_pos, candidate, iid, label=1)
exist_id[candidate['id'][iid]] = 1
self._phase1_add_num += 1
self._phase1_total += candidate['word'].size(0)
'''
## build new support set
# print('---')
# for i in range(len(support_pos['entpair'])):
# print(support_pos['entpair'][i])
# print('---')
# print('---')
# for i in range(support_pos['id'].size(0)):
# print(support_pos['id'][i])
# print('---')
support_pos_rep = self.encode(support_pos, batch_size=self.args.infer_batch_size)
# support_rep = torch.cat([support_pos_rep, support_neg_rep], 0)
# support_label = torch.cat([support_pos['label'], support_neg['label']], 0)
## finetune
# print("Fine-tune Init")
self._train_finetune_init()
self._train_finetune(support_pos_rep)
if self.args.eval:
self._forward_eval_binary(query, threshold)
# self._metric.append(np.array([self._f1, self._prec, self._recall]))
if self.args.print_debug:
print('\nphase1 add {} ins / {}'.format(self._phase1_add_num, self._phase1_total))
# phase 2: use the new classifier to pick more extended support ins
self._phase2_add_num = 0
candidate = distant.get_random_candidate(self.pos_class, candidate_num_class, candidate_num_ins_per_class)
## -- method 1: directly use the classifier --
candidate_prob = self._infer(candidate, batch_size=self.args.infer_batch_size)
## -- method 2: use siamese network --
pick_or_not = self.siamese_model.forward_infer_sort(support_pos, candidate, batch_size=self.args.infer_batch_size)
## -- method A: use threshold --
'''
self._phase2_total = candidate_prob.size(0)
for i in range(candidate_prob.size(0)):
# if (candidate_prob[i] > threshold_for_phase2) and not (candidate['id'][i] in exist_id):
if (pick_or_not[i]) and (candidate_prob[i] > threshold_for_phase2) and not (candidate['id'][i] in exist_id):
exist_id[candidate['id'][i]] = 1
self._phase2_add_num += 1
self._add_ins_to_vdata(support_pos, candidate, i, label=1)
'''
## -- method B: use sort --
self._phase2_total = candidate['word'].size(0)
for i in range(min(len(candidate_prob), sort_num2)):
iid = pick_or_not[i][1]
if (pick_or_not[i][0] > sort_threshold2) and (candidate_prob[iid] > sort_ori_threshold) and not (candidate['id'][iid] in exist_id):
exist_id[candidate['id'][iid]] = 1
self._phase2_add_num += 1
self._add_ins_to_vdata(support_pos, candidate, iid, label=1)
## build new support set
support_pos_rep = self.encode(support_pos, self.args.infer_batch_size)
# support_rep = torch.cat([support_pos_rep, support_neg_rep], 0)
# support_label = torch.cat([support_pos['label'], support_neg['label']], 0)
## finetune
# print("Fine-tune Init")
self._train_finetune_init()
self._train_finetune(support_pos_rep)
if self.args.eval:
self._forward_eval_binary(query, threshold)
self._metric.append(np.array([self._f1, self._prec, self._recall]))
if self.args.print_debug:
print('\nphase2 add {} ins / {}'.format(self._phase2_add_num, self._phase2_total))
self._forward_eval_binary(query, threshold)
if self.args.print_debug:
print('\nphase2 add {} ins / {}'.format(self._phase2_add_num, self._phase2_total))
return support_pos_rep
def _forward_eval_binary(self, query, threshold=0.5):
'''
snowball process (eval)
query: query set (raw data)
threshold: ins with prob > threshold will be classified as positive
return (accuracy at threshold, precision at threshold, recall at threshold, f1 at threshold, auc),
'''
query_prob = self._infer(query, batch_size=self.args.infer_batch_size).cpu().detach().numpy()
label = query['label'].cpu().detach().numpy()
accuracy = float(np.logical_or(np.logical_and(query_prob > threshold, label == 1), np.logical_and(query_prob < threshold, label == 0)).sum()) / float(query_prob.shape[0])
if (query_prob > threshold).sum() == 0:
precision = 0
else:
precision = float(np.logical_and(query_prob > threshold, label == 1).sum()) / float((query_prob > threshold).sum())
recall = float(np.logical_and(query_prob > threshold, label == 1).sum()) / float((label == 1).sum())
if precision + recall == 0:
f1 = 0
else:
f1 = float(2.0 * precision * recall) / float(precision + recall)
auc = sklearn.metrics.roc_auc_score(label, query_prob)
if self.args.print_debug:
print('')
sys.stdout.write('[EVAL] acc: {0:2.2f}%, prec: {1:2.2f}%, rec: {2:2.2f}%, f1: {3:1.3f}, auc: {4:1.3f}'.format(\
accuracy * 100, precision * 100, recall * 100, f1, auc) + '\r')
sys.stdout.flush()
self._accuracy = accuracy
self._prec = precision
self._recall = recall
self._f1 = f1
return (accuracy, precision, recall, f1, auc)
def forward(self, support_pos, query, distant, pos_class, threshold=0.5, threshold_for_snowball=0.5):
'''
snowball process (train + eval)
support_pos: support set (positive, raw data)
support_neg: support set (negative, raw data)
query: query set (raw data)
distant: distant data loader
pos_class: positive relation (name)
threshold: ins with prob > threshold will be classified as positive
threshold_for_snowball: distant ins with prob > th_for_snowball will be added to extended support set
'''
self.pos_class = pos_class
self._forward_train(support_pos, query, distant, threshold=threshold)
def init_10shot(self, Ws, bs):
self.Ws = torch.stack(Ws, 0).transpose(0, 1) # (230, 16)
self.bs = torch.stack(bs, 0).transpose(0, 1) # (1, 16)
def eval_10shot(self, query):
x = self.sentence_encoder(query)
x = torch.matmul(x, self.Ws) + self.bs # (batch_size, 16)
x = torch.sigmoid(x)
_, pred = x.max(-1) # (batch_size)
return self.__accuracy__(pred, query['label'])
Reference
Source: CSDN
Author: NERV_Dyson
Link: https://blog.csdn.net/qq_38382642/article/details/104124148