Notes on the Paper "Neural Snowball for Few-Shot Relation Learning"

Posted by 狂风中的少年 on 2020-01-31 22:22:41

These notes draw in part on the AI科技评论 article "用于少次关系学习的神经网络雪球机制" (on the neural snowball mechanism for few-shot relation learning); see Reference.

1. Contributions of the Paper

Relation extraction with a growing set of relations

Relation extraction is an important research topic in natural language processing that studies how to extract structured relational facts from text. For example, from the sentence "Bill Gates is the founder of Microsoft" we can extract the relation triple (Bill Gates, founder, Microsoft) and use it in downstream tasks such as knowledge graph completion.
There is a large body of work on relation extraction, but most of it targets predefined relation types: given a human-defined set of relations, extraction only considers relations within that set. In practice, however, relations grow in an open-ended way; as new domains and new knowledge appear, the set of relation types keeps expanding. We therefore need relation extraction models that can cope with this growth.
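
As a tiny illustration (mine, not the paper's) of what such a structured fact looks like as a data structure:

from typing import NamedTuple

class RelationFact(NamedTuple):
    head: str      # head entity
    relation: str  # relation type
    tail: str      # tail entity

# the triple extracted from "Bill Gates is the founder of Microsoft"
fact = RelationFact(head="Bill Gates", relation="founder", tail="Microsoft")
print(fact)  # RelationFact(head='Bill Gates', relation='founder', tail='Microsoft')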

Knowledge metric: an RSN (Relational Siamese Network) is trained as a distance metric on large-scale data of existing relations and then transferred to new relation types. Using a few instances of each new relation as starting seeds, the model mines useful information from large-scale unlabeled data; the more useful information it mines, the better the resulting relation extraction model.

2. Related Work

2.1 Three different kinds of data

Current research covers several relation extraction settings, which differ in the relation types they target and the data they use:

  • Supervised Relation Extraction: targets a predefined set of relations and uses large-scale supervised data.
  • Semi-Supervised Relation Extraction: also targets a predefined set of relations, but aims to match fully supervised performance with relatively little labeled data, helped by large amounts of unlabeled data.
  • Few-Shot Relation Extraction: targets new (unseen) relation types; the model is pre-trained on large-scale data of existing relations and then quickly transferred to the few available instances of the new relation.
  • Bootstrapping Relation Extraction: also an open-relation setting; given only a few seed instances of a new relation, it iteratively mines more information from large-scale data to obtain a stronger extraction model.

As the analysis above shows, these methods involve three kinds of data:

  • large-scale supervised data of existing relations (large-scale existing relations data)
  • a few labeled instances of the new relation (the new relation few-shot instances)
  • large-scale unlabeled data, i.e. unseen/unlabeled instances drawn from a corpus (large-scale unlabeled corpora data)

We would like to make full use of all three kinds of data, so the authors propose a new relation extraction method: Neural Snowball.

2.2 Neural Snowball

The Neural Snowball model over the three kinds of data is shown below.
[Figure: Neural Snowball and the three kinds of data]

As shown in the figure, Neural Snowball trains a distance metric on large-scale data of existing relations and transfers it to a new relation type; with a few instances of the new relation as starting seeds, it mines useful information from large-scale unlabeled data, and the more useful information it mines, the better the resulting relation extraction model.

3. Methodology

Neural Snowball is made up of the following components.
[Figure: the Neural Snowball framework]

  • Seed set: for a new relation type, a few instances are given as the seed set $S_r$. After every iteration, the seed set is expanded with selected unlabeled instances (from the large-scale unlabeled data above), and the enlarged $S_r$ enters the next iteration.
  • Candidate set C1: instances in C1 are mined by distant supervision. If the new relation is founder, then for the correctly labeled seed instance "Bill founder Microsoft", distant supervision simply collects every sentence containing the entity pair Bill and Microsoft, so it easily pulls in sentences that do not express founder at all (e.g. "Bill mentioned Microsoft"). Since such sentences are unwanted, C1 is only a candidate set:

  • RSN: the filter that judges whether the instances mined by distant supervision are actually useful.
  • Relation Classifier $g(x)$: the binary classifier the model ultimately wants. The instances filtered out of C1, together with the labeled instances, are used to train this relation classifier in a supervised way; the classifier is then used to score further candidates, yielding candidate set C2.
  • C2: since the data coming out of C1 is still not perfectly clean (i.e. the binary Relation Classifier is not very strong yet), C2 is again only a candidate set and needs one more filter. Sentences that survive this second filtering (e.g. a sentence mentioning Steve Jobs and Apple) can be trusted reasonably well and are added to the seed set for the next iteration.

The overall workflow of Neural Snowball is as follows:

  • Input: a new relation type and a few labeled instances (the starting seeds).
  • Goal: train a binary classifier for this relation type. A binary classifier is used because it scales better: as relation types grow, multiple binary classifiers can simply be used side by side.
  • Training: starting from the seeds, iteratively mine useful information from the unlabeled data.

Each iteration consists of two phases:

(1) use distant supervision to obtain candidate sentences;
(2) use the newly trained relation classifier to obtain candidate sentences.

Distant supervision means: if existing data tells us that entities h and t hold relation r, we collect every sentence that contains both h and t and assume it really expresses r. After the first phase has produced new training data, Neural Snowball trains a new relation classifier; this classifier then mines from the unlabeled data the instances it believes belong to relation r, and these new instances in turn help train an even better classifier.
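
Putting the two phases together, here is a minimal sketch of one snowball iteration as I read it; the names corpus, rsn, train_classifier and the attributes head/tail are my own placeholders, not the repo's API:

def snowball_iteration(seeds, corpus, rsn, train_classifier,
                       k1=5, k2=5, t_rsn=0.9, t_cls=0.9):
    # Phase 1: candidate set C1 = unlabeled sentences sharing an entity pair
    # with the seed set (distant supervision), filtered by the RSN
    seed_pairs = {(s.head, s.tail) for s in seeds}
    c1 = [s for s in corpus.all() if (s.head, s.tail) in seed_pairs and s not in seeds]
    c1.sort(key=lambda s: rsn(s, seeds), reverse=True)          # "method B": sort by RSN score
    seeds = seeds + [s for s in c1[:k1] if rsn(s, seeds) > t_rsn]

    # train / fine-tune the binary relation classifier g(x) on the enlarged seed set
    g = train_classifier(seeds)

    # Phase 2: candidate set C2 = random unlabeled sentences, kept only if both
    # the classifier and the RSN are confident
    c2 = sorted(corpus.sample_random(), key=lambda s: rsn(s, seeds), reverse=True)
    seeds = seeds + [s for s in c2[:k2] if rsn(s, seeds) > t_rsn and g(s) > t_cls]
    return seeds, g

Repeating this loop grows the seed set and, with it, the classifier.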

4. Neural Modules

Neural Snowball has two key components: the RSN and the Relation Classifier.

4.1 RSN $s(x,y)$

  • Input: two instances, e.g. the two sentences obtained by distant supervision above, "Bill founder Microsoft" and "Bill mentioned Microsoft" (which clearly do not express the same relation).
  • Output: 0 or 1.
Structure of the RSN

The RSN consists of two encoders $f_s$ and a distance function; its structure is shown below. It takes two sentences as input and outputs whether they express the same relation. The RSN is pre-trained on large-scale data of existing relations and then used inside Neural Snowball: every candidate selected from the unlabeled data is compared with the starting seeds by the RSN, and only the high-confidence instances are kept.

[Figure: structure of the RSN]

The encoders in the RSN take instances as input and output their representation vectors. The two encoders share parameters (parameter sharing): if a CNN is used as the encoder, the convolution kernels are shared, which keeps the number of parameters down.

The distance function in the RSN computes the similarity:
$s(x,y)=\sigma\big(\mathbf{w}_s^T (f_s(x)-f_s(y))^2+b_s\big)$
It can be seen as a weighted L2 distance, where $\mathbf{w}_s$ and $b_s$ are learned; a higher score means the two instances are more likely to express the same relation.
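
A minimal PyTorch sketch of this similarity head, assuming some sentence encoder that maps a batch of inputs to hidden_size-dimensional vectors; class and variable names here are my own illustration, not the paper's code:

import torch
from torch import nn

class RSNHead(nn.Module):
    """s(x, y) = sigmoid(w_s^T (f_s(x) - f_s(y))^2 + b_s), with a shared encoder f_s."""
    def __init__(self, encoder: nn.Module, hidden_size: int = 230):
        super().__init__()
        self.encoder = encoder               # f_s, shared by both inputs (weight sharing)
        self.fc = nn.Linear(hidden_size, 1)  # holds w_s and b_s

    def forward(self, x, y):
        dis = (self.encoder(x) - self.encoder(y)) ** 2     # element-wise squared difference
        return torch.sigmoid(self.fc(dis)).squeeze(-1)     # probability of "same relation"

# toy usage with a linear stand-in for the CNN sentence encoder
toy_encoder = nn.Linear(50, 230)
rsn = RSNHead(toy_encoder)
x, y = torch.randn(4, 50), torch.randn(4, 50)
print(rsn(x, y).shape)  # torch.Size([4])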

4.2 Relation Classifier $g(x)$

The relation classifier (RC) has a neural encoder $f$ that maps an input instance $x$ to a real-valued vector; a linear layer followed by a sigmoid then gives the probability that the instance expresses the relation:
$g(x)=\sigma\big(\mathbf{w}^T f(x)+b\big)$
If several relations need to be classified, N binary RC classifiers can be used; because new relation types (raw data) keep arriving, the authors do not use a single N-way classifier.
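
And a matching sketch of the binary relation classifier under the same assumptions (toy encoder, illustrative names only):

import torch
from torch import nn

class RelationClassifier(nn.Module):
    """g(x) = sigmoid(w^T f(x) + b): one binary classifier per relation."""
    def __init__(self, encoder: nn.Module, hidden_size: int = 230):
        super().__init__()
        self.encoder = encoder               # f, kept fixed after pre-training
        self.fc = nn.Linear(hidden_size, 1)  # w and b, the only parameters fine-tuned later

    def forward(self, x):
        return torch.sigmoid(self.fc(self.encoder(x))).squeeze(-1)

# one such classifier per relation; scores above a threshold count as that relation
toy_encoder = nn.Linear(50, 230)
g_founder = RelationClassifier(toy_encoder)
probs = g_founder(torch.randn(8, 50))  # shape (8,): probability that each sentence expresses "founder"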

4.3 Pre-training and Fine-tuning

Pre-training

Pre-training trains the networks of the RSN and the RC; these parameters are then kept fixed during the later iterations.
They are trained in a supervised way on the existing labeled dataset $S_N$.

For the RSN, instance pairs with the same or with different relation types are sampled from $S_N$ (yes, $S_N$ is the large-scale data of existing relations described above), and the RSN is trained with a cross-entropy loss.
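
A sketch of that pair-sampling step, assuming an RSN like the RSNHead sketch above and a labeled dataset given as a tensor of inputs plus a tensor of relation ids (again my own simplification, not the repo's training loop):

import random
import torch
from torch import nn

def pretrain_rsn(rsn, instances, labels, steps=1000, batch_size=32, lr=1e-3):
    """Sample same-relation and different-relation pairs from S_N and train with BCE."""
    opt = torch.optim.Adam(rsn.parameters(), lr=lr)
    bce = nn.BCELoss()
    by_rel = {}
    for i, r in enumerate(labels.tolist()):
        by_rel.setdefault(r, []).append(i)
    rels = list(by_rel)
    for _ in range(steps):
        xs, ys, tgt = [], [], []
        for _ in range(batch_size):
            if random.random() < 0.5:                       # positive pair: same relation
                r = random.choice(rels)
                i, j = random.choice(by_rel[r]), random.choice(by_rel[r])
                tgt.append(1.0)
            else:                                           # negative pair: different relations
                r1, r2 = random.sample(rels, 2)
                i, j = random.choice(by_rel[r1]), random.choice(by_rel[r2])
                tgt.append(0.0)
            xs.append(instances[i])
            ys.append(instances[j])
        score = rsn(torch.stack(xs), torch.stack(ys))
        loss = bce(score, torch.tensor(tgt))
        opt.zero_grad()
        loss.backward()
        opt.step()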

For the RC, a minibatch of positive instances is sampled from $S_r$ (the seed set) and a minibatch of negative instances from $S_N$, and the linear-layer parameters $\mathbf{w}$ and $b$ of the RC (presumably a fully-connected layer) are optimized with the loss:
$\mathcal{L}_{\mathcal{S}_b,\mathcal{T}_b}(g_{w,b})=\sum_{x\in\mathcal{S}_b}\log g_{w,b}(x)+\mu\sum_{x\in\mathcal{T}_b}\log\big(1-g_{w,b}(x)\big)$
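
A sketch of one fine-tuning step built directly from this formula; g is a classifier like the RelationClassifier sketch above, pos_batch is sampled from $S_r$ and neg_batch from $S_N$ (helper names are mine):

import torch

def finetune_step(g, pos_batch, neg_batch, optimizer, mu=0.2, eps=1e-8):
    p_pos = g(pos_batch)                 # g_{w,b}(x) for x in S_b (positives from the seed set)
    p_neg = g(neg_batch)                 # g_{w,b}(x) for x in T_b (negatives from S_N)
    # minimize the negative of the objective above, i.e. maximize it
    loss = -(torch.log(p_pos + eps).sum() + mu * torch.log(1 - p_neg + eps).sum())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()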

Fine-tuning

Since pre-training already trains the RSN and the RC encoder and those parameters stay fixed afterwards, fine-tuning only optimizes the small linear layer of the relation classifier (the new_W and new_bias created in _train_finetune_init below); this happens both for the baseline model and inside every snowball iteration, so I will not expand it further here.

5. A Look at the Code

RSN training: encode is the RSN's encoding function; forward_infer scores candidates with the threshold-based method from the paper (method A), while forward_infer_sort is the sort-based method B used in the phase 1/phase 2 selection of the Snowball class. The training pass itself is implemented in the forward method. See the comments for details.

import numpy as np
import torch
from torch import nn

class Siamese(nn.Module):

    def __init__(self, sentence_encoder, hidden_size=230, drop_rate=0.5, pre_rep=None, euc=True):
        nn.Module.__init__(self)
        self.sentence_encoder = sentence_encoder # Should be different from main sentence encoder !!!
        self.hidden_size = hidden_size
        # self.fc1 = nn.Linear(hidden_size * 2, hidden_size * 2)
        # self.fc2 = nn.Linear(hidden_size * 2, 1)
        self.fc = nn.Linear(hidden_size, 1)
        self.cost = nn.BCELoss(reduction="none")
        self.drop = nn.Dropout(drop_rate)
        self._accuracy = 0.0
        self.pre_rep = pre_rep
        self.euc = euc

    def forward(self, data, num_size, num_class, threshold=0.5):
        # view: reshape the encoder output to (num_class, num_size, hidden_size)
        x = self.sentence_encoder(data).contiguous().view(num_class, num_size, -1)
        # Build pairs for the siamese loss (// is integer division):
        #   x1 / x2 split each class's instances in half, so (x1, x2) are same-relation pairs;
        #   y1 / y2 split the classes themselves in half, so (y1, y2) are cross-relation pairs.
        x1 = x[:, :num_size//2].contiguous().view(-1, self.hidden_size)
        x2 = x[:, num_size//2:].contiguous().view(-1, self.hidden_size)
        y1 = x[:num_class//2,:].contiguous().view(-1, self.hidden_size)
        y2 = x[num_class//2:,:].contiguous().view(-1, self.hidden_size)
        # y1 = x[0].contiguous().unsqueeze(0).expand(x.size(0) - 1, -1, -1).contiguous().view(-1, self.hidden_size)
        # y2 = x[1:].contiguous().view(-1, self.hidden_size)

        label = torch.zeros((x1.size(0) + y1.size(0))).long().cuda()
        label[:x1.size(0)] = 1  # same-relation pairs (x1, x2) are labeled 1; cross-relation pairs keep label 0
        z1 = torch.cat([x1, y1], 0)
        z2 = torch.cat([x2, y2], 0)

        if self.euc:
            dis = torch.pow(z1 - z2, 2)
            dis = self.drop(dis)
            score = torch.sigmoid(self.fc(dis).squeeze())
        else:
            z = z1 * z2
            z = self.drop(z)
            z = self.fc(z).squeeze()
            # z = torch.cat([z1, z2], -1)
            # z = F.relu(self.fc1(z))
            # z = self.fc2(z).squeeze()
            score = torch.sigmoid(z)

        self._loss = self.cost(score, label.float()).mean()
        pred = torch.zeros((score.size(0))).long().cuda()
        pred[score > threshold] = 1
        self._accuracy = torch.mean((pred == label).type(torch.FloatTensor))
        pred = pred.cpu().detach().numpy()
        label = label.cpu().detach().numpy()
        self._prec = float(np.logical_and(pred == 1, label == 1).sum()) / float((pred == 1).sum() + 1)
        self._recall = float(np.logical_and(pred == 1, label == 1).sum()) / float((label == 1).sum() + 1)

    def encode(self, dataset, batch_size=0): 
        if self.pre_rep is not None:
            return self.pre_rep[dataset['id'].view(-1)] 

        if batch_size == 0:
            x = self.sentence_encoder(dataset)
        else:
            total_length = dataset['word'].size(0)
            max_iter = total_length // batch_size
            if total_length % batch_size != 0:
                max_iter += 1
            x = []
            for it in range(max_iter):
                scope = list(range(batch_size * it, min(batch_size * (it + 1), total_length)))
                with torch.no_grad():
                    _ = {'word': dataset['word'][scope], 'mask': dataset['mask'][scope]}
                    if 'pos1' in dataset:
                        _['pos1'] = dataset['pos1'][scope]
                        _['pos2'] = dataset['pos2'][scope]
                    _x = self.sentence_encoder(_)
                x.append(_x.detach())
            x = torch.cat(x, 0) #concatenate
        return x

    # method A: selection with a score threshold
    def forward_infer(self, x, y, threshold=0.5, batch_size=0):
        x = self.encode(x, batch_size=batch_size)
        support_size = x.size(0)
        y = self.encode(y, batch_size=batch_size)
        # unsqueeze adds a singleton dimension so that x (support) and y (candidates)
        # broadcast into all pairwise combinations
        x = x.unsqueeze(1)  # (support_size, 1, hidden_size)
        y = y.unsqueeze(0)  # (1, candidate_size, hidden_size)

        if self.euc:
            dis = torch.pow(x - y, 2)   # L2 distance
            score = torch.sigmoid(self.fc(dis).squeeze(-1)).mean(0)    # nn.Linear provides the weight and bias; average scores over the support set
        else:
            z = x * y
            z = self.fc(z).squeeze(-1)
            score = torch.sigmoid(z).mean(0)

        pred = torch.zeros((score.size(0))).long().cuda()
        pred[score > threshold] = 1
        pred = pred.view(support_size, -1).sum(0)
        pred[pred < 1] = 0
        pred[pred > 0] = 1
        return pred

    # method B: return candidates sorted by score
    def forward_infer_sort(self, x, y, batch_size=0):
        x = self.encode(x, batch_size=batch_size)
        support_size = x.size(0)
        y = self.encode(y, batch_size=batch_size)
        x = x.unsqueeze(1)
        y = y.unsqueeze(0)

        if self.euc:
            dis = torch.pow(x - y, 2)
            score = torch.sigmoid(self.fc(dis).squeeze(-1)).mean(0)
        else:
            z = x * y
            z = self.fc(z).squeeze(-1)
            score = torch.sigmoid(z).mean(0)

        pred = []
        for i in range(score.size(0)):
            pred.append((score[i], i))
        pred.sort(key=lambda x: x[0], reverse=True)
        return pred
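
As a usage sketch, this is how the Snowball code further down consumes the output of forward_infer_sort; the wrapper function and its arguments are my own illustration:

def select_candidates(siamese, support, candidate, sort_num, sort_threshold, batch_size=32):
    """Keep at most sort_num candidates whose siamese score clears the threshold.
    support / candidate are batched dataset dicts as produced by the repo's data loaders."""
    pick_or_not = siamese.forward_infer_sort(support, candidate, batch_size=batch_size)
    # pick_or_not is a list of (score, index) pairs sorted by descending score
    return [idx for score, idx in pick_or_not[:sort_num] if score > sort_threshold]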

The Snowball class: a lot of it is data handling; the core is _forward_train, which implements the Section 3 procedure from the input through phase 1 and phase 2. The original author also commented the negative sampling used when training the relation classifier in some detail.

import copy
import random
import sys

import numpy as np
import sklearn.metrics
import torch
from torch import nn, optim
from torch.autograd import Variable

import nrekit

class Snowball(nrekit.framework.Model):
    
    def __init__(self, sentence_encoder, base_class, siamese_model, hidden_size=230, drop_rate=0.5, weight_table=None, pre_rep=None, neg_loader=None, args=None):
        nrekit.framework.Model.__init__(self, sentence_encoder)
        self.hidden_size = hidden_size
        self.base_class = base_class
        self.fc = nn.Linear(hidden_size, base_class)
        self.drop = nn.Dropout(drop_rate)
        self.siamese_model = siamese_model
        # self.cost = nn.BCEWithLogitsLoss()
        self.cost = nn.BCELoss(reduction="none")
        # self.cost = nn.CrossEntropyLoss()
        self.weight_table = weight_table
        
        self.args = args

        self.pre_rep = pre_rep
        self.neg_loader = neg_loader

    # def __loss__(self, logits, label):
    #     onehot_label = torch.zeros(logits.size()).cuda()
    #     onehot_label.scatter_(1, label.view(-1, 1), 1)
    #     return self.cost(logits, onehot_label)

    # def __loss__(self, logits, label):
    #     return self.cost(logits, label)

    def forward_base(self, data):
        batch_size = data['word'].size(0)
        x = self.sentence_encoder(data) # (batch_size, hidden_size)
        x = self.drop(x)
        x = self.fc(x) # (batch_size, base_class)

        x = torch.sigmoid(x)
        if self.weight_table is None:
            weight = 1.0
        else:
            weight = self.weight_table[data['label']].unsqueeze(1).expand(-1, self.base_class).contiguous().view(-1)
        label = torch.zeros((batch_size, self.base_class)).cuda()
        label.scatter_(1, data['label'].view(-1, 1), 1) # (batch_size, base_class)
        loss_array = self.__loss__(x, label)
        self._loss = ((label.view(-1) + 1.0 / self.base_class) * weight * loss_array).mean() * self.base_class
        # self._loss = self.__loss__(x, data['label'])
        
        _, pred = x.max(-1)
        self._accuracy = self.__accuracy__(pred, data['label'])
        self._pred = pred
    
    def forward_baseline(self, support_pos, query, threshold=0.5):
        '''
        baseline model
        support_pos: positive support set
        support_neg: negative support set
        query: query set
        threshold: ins whose prob > threshold are predicted as positive
        '''
        
        # train
        self._train_finetune_init()
        # support_rep = self.encode(support, self.args.infer_batch_size)
        support_pos_rep = self.encode(support_pos, self.args.infer_batch_size)
        # self._train_finetune(support_rep, support['label'])
        self._train_finetune(support_pos_rep)

        
        # test
        query_prob = self._infer(query, batch_size=self.args.infer_batch_size).cpu().detach().numpy()
        label = query['label'].cpu().detach().numpy()
        self._baseline_accuracy = float(np.logical_or(np.logical_and(query_prob > threshold, label == 1), np.logical_and(query_prob < threshold, label == 0)).sum()) / float(query_prob.shape[0])
        if (query_prob > threshold).sum() == 0:
            self._baseline_prec = 0
        else:        
            self._baseline_prec = float(np.logical_and(query_prob > threshold, label == 1).sum()) / float((query_prob > threshold).sum())
        self._baseline_recall = float(np.logical_and(query_prob > threshold, label == 1).sum()) / float((label == 1).sum())
        if self._baseline_prec + self._baseline_recall == 0:
            self._baseline_f1 = 0
        else:
            self._baseline_f1 = float(2.0 * self._baseline_prec * self._baseline_recall) / float(self._baseline_prec + self._baseline_recall)
        self._baseline_auc = sklearn.metrics.roc_auc_score(label, query_prob)
        if self.args.print_debug:
            print('')
            sys.stdout.write('[BASELINE EVAL] acc: {0:2.2f}%, prec: {1:2.2f}%, rec: {2:2.2f}%, f1: {3:1.3f}, auc: {4:1.3f}'.format( \
                self._baseline_accuracy * 100, self._baseline_prec * 100, self._baseline_recall * 100, self._baseline_f1, self._baseline_auc))
            print('')

    def __dist__(self, x, y, dim):
        return (torch.pow(x - y, 2)).sum(dim)

    def __batch_dist__(self, S, Q):
        return self.__dist__(S.unsqueeze(1), Q.unsqueeze(2), 3)

    def forward_few_shot_baseline(self, support, query, label, B, N, K, Q):
        support_rep = self.encode(support, self.args.infer_batch_size)
        query_rep = self.encode(query, self.args.infer_batch_size)
        support_rep = support_rep.view(B, N, K, -1)
        query_rep = query_rep.view(B, N * Q, -1)
        
        NQ = N * Q
         
        # Prototypical Networks 
        proto = torch.mean(support_rep, 2) # Calculate prototype for each class
        logits = -self.__batch_dist__(proto, query_rep)
        _, pred = torch.max(logits.view(-1, N), 1)

        self._accuracy = self.__accuracy__(pred.view(-1), label.view(-1))

        return logits, pred

#    def forward_few_shot(self, support, query, label, B, N, K, Q):
#        for b in range(B):
#            for n in range(N):
#                _forward_train(self, support_pos, None, query, distant, threshold=0.5):
#
#        '''
#        support_rep = self.encode(support, self.args.infer_batch_size)
#        query_rep = self.encode(query, self.args.infer_batch_size)
#        support_rep.view(B, N, K, -1)
#        query_rep.view(B, N * Q, -1)
#        '''
#        
#        proto = []
#        for b in range(B):
#            for N in range(N)
#        
#        NQ = N * Q
#         
#        # Prototypical Networks 
#        proto = torch.mean(support_rep, 2) # Calculate prototype for each class
#        logits = -self.__batch_dist__(proto, query)
#        _, pred = torch.max(logits.view(-1, N), 1)
#
#        self._accuracy = self.__accuracy__(pred.view(-1), label.view(-1))
#
#        return logits, pred

    def _train_finetune_init(self):
        # init variables and optimizer
        self.new_W = Variable(self.fc.weight.mean(0) / 1e3, requires_grad=True)
        self.new_bias = Variable(torch.zeros((1)), requires_grad=True)
        self.optimizer = optim.Adam([self.new_W, self.new_bias], self.args.finetune_lr, weight_decay=self.args.finetune_wd)
        self.new_W = self.new_W.cuda()
        self.new_bias = self.new_bias.cuda()

    # fine-tuning of the relation classifier (only new_W and new_bias are updated)
    def _train_finetune(self, data_repre, learning_rate=None, weight_decay=1e-5):
        '''
        train finetune classifier with given data
        data_repre: sentence representation (encoder's output)
        label: label
        '''
        
        self.train()

        optimizer = self.optimizer
        if learning_rate is not None:
            optimizer = optim.Adam([self.new_W, self.new_bias], learning_rate, weight_decay=weight_decay)

        # hyperparameters
        max_epoch = self.args.finetune_epoch
        batch_size = self.args.finetune_batch_size
        
        # dropout
        data_repre = self.drop(data_repre) 
        
        # train
        if self.args.print_debug:
            print('')
        for epoch in range(max_epoch):
            max_iter = data_repre.size(0) // batch_size
            if data_repre.size(0) % batch_size != 0:
                max_iter += 1
            order = list(range(data_repre.size(0)))
            random.shuffle(order)
            for i in range(max_iter):            
                x = data_repre[order[i * batch_size : min((i + 1) * batch_size, data_repre.size(0))]]
                # batch_label = label[order[i * batch_size : min((i + 1) * batch_size, data_repre.size(0))]]
                
                # neg sampling
                # ---------------------
                batch_label = torch.ones((x.size(0))).long().cuda()
                neg_size = int(x.size(0) * 1)
                neg = self.neg_loader.next_batch(neg_size)
                neg = self.encode(neg, self.args.infer_batch_size)
                x = torch.cat([x, neg], 0)
                batch_label = torch.cat([batch_label, torch.zeros((neg_size)).long().cuda()], 0)
                # ---------------------

                # Relation Classifier
                x = torch.matmul(x, self.new_W) + self.new_bias # (batch_size, 1)
                x = torch.sigmoid(x)

                # iter_loss = self.__loss__(x, batch_label.float()).mean()
                weight = torch.ones(batch_label.size(0)).float().cuda()
                weight[batch_label == 0] = self.args.finetune_weight #1 / float(max_epoch)
                iter_loss = (self.__loss__(x, batch_label.float()) * weight).mean()

                optimizer.zero_grad()
                iter_loss.backward(retain_graph=True)
                optimizer.step()
                if self.args.print_debug:
                    sys.stdout.write('[snowball finetune] epoch {0:4} iter {1:4} | loss: {2:2.6f}'.format(epoch, i, iter_loss) + '\r')
                    sys.stdout.flush()
        self.eval()

    def _add_ins_to_data(self, dataset_dst, dataset_src, ins_id, label=None):
        '''
        add one instance from dataset_src to dataset_dst (list)
        dataset_dst: destination dataset
        dataset_src: source dataset
        ins_id: id of the instance
        '''
        dataset_dst['word'].append(dataset_src['word'][ins_id])
        if 'pos1' in dataset_src:
            dataset_dst['pos1'].append(dataset_src['pos1'][ins_id])
            dataset_dst['pos2'].append(dataset_src['pos2'][ins_id])
        dataset_dst['mask'].append(dataset_src['mask'][ins_id])
        if 'id' in dataset_dst and 'id' in dataset_src:
            dataset_dst['id'].append(dataset_src['id'][ins_id])
        if 'entpair' in dataset_dst and 'entpair' in dataset_src:
            dataset_dst['entpair'].append(dataset_src['entpair'][ins_id])
        if 'label' in dataset_dst and label is not None:
            dataset_dst['label'].append(label)

    def _add_ins_to_vdata(self, dataset_dst, dataset_src, ins_id, label=None):
        '''
        add one instance from dataset_src to dataset_dst (variable)
        dataset_dst: destination dataset
        dataset_src: source dataset
        ins_id: id of the instance
        '''
        dataset_dst['word'] = torch.cat([dataset_dst['word'], dataset_src['word'][ins_id].unsqueeze(0)], 0)
        if 'pos1' in dataset_src:
            dataset_dst['pos1'] = torch.cat([dataset_dst['pos1'], dataset_src['pos1'][ins_id].unsqueeze(0)], 0)
            dataset_dst['pos2'] = torch.cat([dataset_dst['pos2'], dataset_src['pos2'][ins_id].unsqueeze(0)], 0)
        dataset_dst['mask'] = torch.cat([dataset_dst['mask'], dataset_src['mask'][ins_id].unsqueeze(0)], 0)
        if 'id' in dataset_dst and 'id' in dataset_src:
            dataset_dst['id'] = torch.cat([dataset_dst['id'], dataset_src['id'][ins_id].unsqueeze(0)], 0)
        if 'entpair' in dataset_dst and 'entpair' in dataset_src:
            dataset_dst['entpair'].append(dataset_src['entpair'][ins_id])
        if 'label' in dataset_dst and label is not None:
            dataset_dst['label'] = torch.cat([dataset_dst['label'], torch.ones((1)).long().cuda()], 0)

    def _dataset_stack_and_cuda(self, dataset):
        '''
        stack the dataset to torch.Tensor and use cuda mode
        dataset: target dataset
        '''
        if (len(dataset['word']) == 0):
            return
        dataset['word'] = torch.stack(dataset['word'], 0).cuda()
        if 'pos1' in dataset:
            dataset['pos1'] = torch.stack(dataset['pos1'], 0).cuda()
            dataset['pos2'] = torch.stack(dataset['pos2'], 0).cuda()
        dataset['mask'] = torch.stack(dataset['mask'], 0).cuda()
        dataset['id'] = torch.stack(dataset['id'], 0).cuda()

    def encode(self, dataset, batch_size=0):
        if self.pre_rep is not None:
            return self.pre_rep[dataset['id'].view(-1)]

        if batch_size == 0:
            x = self.sentence_encoder(dataset)
        else:
            total_length = dataset['word'].size(0)
            max_iter = total_length // batch_size
            if total_length % batch_size != 0:
                max_iter += 1
            x = []
            for it in range(max_iter):
                scope = list(range(batch_size * it, min(batch_size * (it + 1), total_length)))
                with torch.no_grad():
                    _ = {'word': dataset['word'][scope], 'mask': dataset['mask'][scope]}
                    if 'pos1' in dataset:
                        _['pos1'] = dataset['pos1'][scope]
                        _['pos2'] = dataset['pos2'][scope]
                    _x = self.sentence_encoder(_)
                x.append(_x.detach())
            x = torch.cat(x, 0)
        return x

    def _infer(self, dataset, batch_size=0):
        '''
        get prob output of the finetune network with the input dataset
        dataset: input dataset
        return: prob output of the finetune network
        '''
        x = self.encode(dataset, batch_size=batch_size) 
        x = torch.matmul(x, self.new_W) + self.new_bias # (batch_size, 1)
        x = torch.sigmoid(x)
        return x.view(-1)

    def _forward_train(self, support_pos, query, distant, threshold=0.5):
        '''
        snowball process (train)
        support_pos: support set (positive, raw data)
        support_neg: support set (negative, raw data)
        query: query set
        distant: distant data loader
        threshold: ins with prob > threshold will be classified as positive
        threshold_for_phase1: distant ins with prob > th_for_phase1 will be added to extended support set at phase1
        threshold_for_phase2: distant ins with prob > th_for_phase2 will be added to extended support set at phase2
        '''

        # hyperparameters
        snowball_max_iter = self.args.snowball_max_iter
        sys.stdout.flush()
        candidate_num_class = 20
        candidate_num_ins_per_class = 100
        
        sort_num1 = self.args.phase1_add_num
        sort_num2 = self.args.phase2_add_num
        sort_threshold1 = self.args.phase1_siamese_th
        sort_threshold2 = self.args.phase2_siamese_th
        sort_ori_threshold = self.args.phase2_cl_th

        # get neg representations with sentence encoder
        # support_neg_rep = self.encode(support_neg, batch_size=self.args.infer_batch_size)
        
        # init
        self._train_finetune_init()
        # support_rep = self.encode(support, self.args.infer_batch_size)

        # encode the positive raw support data into representations
        support_pos_rep = self.encode(support_pos, self.args.infer_batch_size)
        # self._train_finetune(support_rep, support['label'])
        self._train_finetune(support_pos_rep)

        self._metric = []

        # copy
        original_support_pos = copy.deepcopy(support_pos)

        # snowball
        exist_id = {}
        if self.args.print_debug:
            print('\n-------------------------------------------------------')
        for snowball_iter in range(snowball_max_iter):
            if self.args.print_debug:
                print('###### snowball iter ' + str(snowball_iter))
            # phase 1: expand positive support set from distant dataset (with same entity pairs)

            ## get all entpairs and their ins in positive support set   ins is instance
            old_support_pos_label = support_pos['label'] + 0
            entpair_support = {}
            entpair_distant = {}
            for i in range(len(support_pos['id'])): # only positive support
                entpair = support_pos['entpair'][i] # entity pair
                exist_id[support_pos['id'][i]] = 1
                if entpair not in entpair_support:
                    if 'pos1' in support_pos:
                        entpair_support[entpair] = {'word': [], 'pos1': [], 'pos2': [], 'mask': [], 'id': []}
                    else:
                        entpair_support[entpair] = {'word': [], 'mask': [], 'id': []}
                self._add_ins_to_data(entpair_support[entpair], support_pos, i)
            
            ## pick all ins with the same entpairs in distant data and choose with siamese network
            self._phase1_add_num = 0 # total number of snowball instances
            self._phase1_total = 0
            for entpair in entpair_support:
                raw = distant.get_same_entpair_ins(entpair) # ins with the same entpair
                if raw is None:
                    continue
                if 'pos1' in support_pos:   # store per-entpair distant instances in a dict
                    entpair_distant[entpair] = {'word': [], 'pos1': [], 'pos2': [], 'mask': [], 'id': [], 'entpair': []}
                else:
                    entpair_distant[entpair] = {'word': [], 'mask': [], 'id': [], 'entpair': []}
                for i in range(raw['word'].size(0)):
                    if raw['id'][i] not in exist_id: # don't pick sentences already in the support set
                        self._add_ins_to_data(entpair_distant[entpair], raw, i)
                self._dataset_stack_and_cuda(entpair_support[entpair])
                self._dataset_stack_and_cuda(entpair_distant[entpair])
                if len(entpair_support[entpair]['word']) == 0 or len(entpair_distant[entpair]['word']) == 0:
                    continue

                # compare support instances with distant instances sharing the same entity pair;
                # the siamese (RSN) scores decide which candidates from C1 are kept
                pick_or_not = self.siamese_model.forward_infer_sort(entpair_support[entpair], entpair_distant[entpair], batch_size=self.args.infer_batch_size)
                
                # pick_or_not = self.siamese_model.forward_infer_sort(original_support_pos, entpair_distant[entpair], threshold=threshold_for_phase1)
                # pick_or_not = self._infer(entpair_distant[entpair]) > threshold
      
                # -- method B: use sort --
                for i in range(min(len(pick_or_not), sort_num1)):
                    if pick_or_not[i][0] > sort_threshold1:
                        iid = pick_or_not[i][1]
                        self._add_ins_to_vdata(support_pos, entpair_distant[entpair], iid, label=1)
                        exist_id[entpair_distant[entpair]['id'][iid]] = 1
                        self._phase1_add_num += 1
                self._phase1_total += entpair_distant[entpair]['word'].size(0)
            '''
            if 'pos1' in support_pos:
                candidate = {'word': [], 'pos1': [], 'pos2': [], 'mask': [], 'id': [], 'entpair': []}
            else:
                candidate = {'word': [], 'mask': [], 'id': [], 'entpair': []}

            self._phase1_add_num = 0 # total number of snowball instances
            self._phase1_total = 0
            for entpair in entpair_support:
                raw = distant.get_same_entpair_ins(entpair) # ins with the same entpair
                if raw is None:
                    continue
                for i in range(raw['word'].size(0)):
                    if raw['id'][i] not in exist_id: # don't pick sentences already in the support set
                        self._add_ins_to_data(candidate, raw, i)

            if len(candidate['word']) > 0:
                self._dataset_stack_and_cuda(candidate)
                pick_or_not = self.siamese_model.forward_infer_sort(support_pos, candidate, batch_size=self.args.infer_batch_size)
                    
                for i in range(min(len(pick_or_not), sort_num1)):
                    if pick_or_not[i][0] > sort_threshold1:
                        iid = pick_or_not[i][1]
                        self._add_ins_to_vdata(support_pos, candidate, iid, label=1)
                        exist_id[candidate['id'][iid]] = 1
                        self._phase1_add_num += 1
                self._phase1_total += candidate['word'].size(0)
            '''
            ## build new support set
            
            # print('---')
            # for i in range(len(support_pos['entpair'])):
            #     print(support_pos['entpair'][i])
            # print('---')
            # print('---')
            # for i in range(support_pos['id'].size(0)):
            #     print(support_pos['id'][i])
            # print('---')

            support_pos_rep = self.encode(support_pos, batch_size=self.args.infer_batch_size)
            # support_rep = torch.cat([support_pos_rep, support_neg_rep], 0)
            # support_label = torch.cat([support_pos['label'], support_neg['label']], 0)
            
            ## finetune
            # print("Fine-tune Init")
            self._train_finetune_init()
            self._train_finetune(support_pos_rep)
            if self.args.eval:
                self._forward_eval_binary(query, threshold)
            # self._metric.append(np.array([self._f1, self._prec, self._recall]))
            if self.args.print_debug:
                print('\nphase1 add {} ins / {}'.format(self._phase1_add_num, self._phase1_total))

            # phase 2: use the new classifier to pick more extended support ins
            self._phase2_add_num = 0
            candidate = distant.get_random_candidate(self.pos_class, candidate_num_class, candidate_num_ins_per_class)

            ## -- method 1: directly use the classifier --
            candidate_prob = self._infer(candidate, batch_size=self.args.infer_batch_size)
            ## -- method 2: use siamese network --

            pick_or_not = self.siamese_model.forward_infer_sort(support_pos, candidate, batch_size=self.args.infer_batch_size)

            ## -- method A: use threshold --
            '''
            self._phase2_total = candidate_prob.size(0)
            for i in range(candidate_prob.size(0)):
                # if (candidate_prob[i] > threshold_for_phase2) and not (candidate['id'][i] in exist_id):
                if (pick_or_not[i]) and (candidate_prob[i] > threshold_for_phase2) and not (candidate['id'][i] in exist_id):
                    exist_id[candidate['id'][i]] = 1 
                    self._phase2_add_num += 1
                    self._add_ins_to_vdata(support_pos, candidate, i, label=1)
            '''

            ## -- method B: use sort --
            self._phase2_total = candidate['word'].size(0)
            for i in range(min(len(candidate_prob), sort_num2)):
                iid = pick_or_not[i][1]
                if (pick_or_not[i][0] > sort_threshold2) and (candidate_prob[iid] > sort_ori_threshold) and not (candidate['id'][iid] in exist_id):
                    exist_id[candidate['id'][iid]] = 1 
                    self._phase2_add_num += 1
                    self._add_ins_to_vdata(support_pos, candidate, iid, label=1)

            ## build new support set
            support_pos_rep = self.encode(support_pos, self.args.infer_batch_size)
            # support_rep = torch.cat([support_pos_rep, support_neg_rep], 0)
            # support_label = torch.cat([support_pos['label'], support_neg['label']], 0)

            ## finetune
            # print("Fine-tune Init")
            self._train_finetune_init()
            self._train_finetune(support_pos_rep)
            if self.args.eval:
                self._forward_eval_binary(query, threshold)
                self._metric.append(np.array([self._f1, self._prec, self._recall]))
                if self.args.print_debug:
                    print('\nphase2 add {} ins / {}'.format(self._phase2_add_num, self._phase2_total))

        self._forward_eval_binary(query, threshold)
        if self.args.print_debug:
            print('\nphase2 add {} ins / {}'.format(self._phase2_add_num, self._phase2_total))

        return support_pos_rep

    def _forward_eval_binary(self, query, threshold=0.5):
        '''
        snowball process (eval)
        query: query set (raw data)
        threshold: ins with prob > threshold will be classified as positive
        return (accuracy at threshold, precision at threshold, recall at threshold, f1 at threshold, auc), 
        '''
        query_prob = self._infer(query, batch_size=self.args.infer_batch_size).cpu().detach().numpy()
        label = query['label'].cpu().detach().numpy()
        accuracy = float(np.logical_or(np.logical_and(query_prob > threshold, label == 1), np.logical_and(query_prob < threshold, label == 0)).sum()) / float(query_prob.shape[0])
        if (query_prob > threshold).sum() == 0:
            precision = 0
        else:
            precision = float(np.logical_and(query_prob > threshold, label == 1).sum()) / float((query_prob > threshold).sum())
        recall = float(np.logical_and(query_prob > threshold, label == 1).sum()) / float((label == 1).sum())
        if precision + recall == 0:
            f1 = 0
        else:
            f1 = float(2.0 * precision * recall) / float(precision + recall)
        auc = sklearn.metrics.roc_auc_score(label, query_prob)
        if self.args.print_debug:
            print('')
            sys.stdout.write('[EVAL] acc: {0:2.2f}%, prec: {1:2.2f}%, rec: {2:2.2f}%, f1: {3:1.3f}, auc: {4:1.3f}'.format(\
                    accuracy * 100, precision * 100, recall * 100, f1, auc) + '\r')
            sys.stdout.flush()
        self._accuracy = accuracy
        self._prec = precision
        self._recall = recall
        self._f1 = f1
        return (accuracy, precision, recall, f1, auc)

    def forward(self, support_pos, query, distant, pos_class, threshold=0.5, threshold_for_snowball=0.5):
        '''
        snowball process (train + eval)
        support_pos: support set (positive, raw data)
        support_neg: support set (negative, raw data)
        query: query set (raw data)
        distant: distant data loader
        pos_class: positive relation (name)
        threshold: ins with prob > threshold will be classified as positive
        threshold_for_snowball: distant ins with prob > th_for_snowball will be added to extended support set
        '''
        self.pos_class = pos_class 

        self._forward_train(support_pos, query, distant, threshold=threshold)

    def init_10shot(self, Ws, bs):
        self.Ws = torch.stack(Ws, 0).transpose(0, 1) # (230, 16)
        self.bs = torch.stack(bs, 0).transpose(0, 1) # (1, 16)

    def eval_10shot(self, query):
        x = self.sentence_encoder(query)
        x = torch.matmul(x, self.Ws) + self.bs # (batch_size, 16); use the stacked biases from init_10shot
        x = torch.sigmoid(x)
        _, pred = x.max(-1) # (batch_size)
        return self.__accuracy__(pred, query['label'])
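
The scalability point from Section 3 (the Goal bullet) is that several per-relation binary classifiers can be stacked, which is what init_10shot / eval_10shot do for 16 relations. A small self-contained sketch of that combination idea, with made-up shapes and names:

import torch

def combine_binary_classifiers(ws, bs, reps):
    """ws: list of (hidden,) weight vectors, bs: list of scalar biases,
    reps: (batch, hidden) encoded sentences -> predicted relation index per sentence."""
    W = torch.stack(ws, 0).transpose(0, 1)   # (hidden, num_relations)
    b = torch.stack(bs, 0).view(1, -1)       # (1, num_relations)
    probs = torch.sigmoid(reps @ W + b)      # each column is one binary classifier's g(x)
    return probs.argmax(-1)                  # pick the most confident relation

reps = torch.randn(5, 230)
ws = [torch.randn(230) for _ in range(16)]
bs = [torch.randn(()) for _ in range(16)]
print(combine_binary_classifiers(ws, bs, reps))  # tensor of 5 relation indices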

Reference

AAAI 2020 | 清华大学: 用于少次关系学习的神经网络雪球机制 (Tsinghua University: Neural Snowball for Few-Shot Relation Learning), AI科技评论
