Cognitive Graph for Multi-Hop Reading Comprehension at Scale(ACL2019) 阅读笔记与源码解析

与世无争的帅哥 提交于 2019-12-04 01:35:31

论文地址为:Cognitive Graph for Multi-Hop Reading Comprehension at Scale

github地址:CogQA

背景

假设你手边有一个维基百科的搜索引擎,可以用来获取实体对应的文本段落,那么如何来回答下面这个复杂的问题呢? 

“谁是某部在2003年取景于洛杉矶Quality cafe的电影的导演?”

很自然地,我们将会从例如Quality cafe这样的“相关实体”入手,通过维基百科查询相关介绍,并在其中讲到好莱坞电影的时候迅速定位到“Old School”“Gone in 60 Seconds”这两部电影,通过继续查询两部电影相关的介绍,我们找到他们的导演。最后一步是判断到底是哪位导演,这需要我们自己分析句子的语意和限定词,在了解到电影是2003年之后,我们可以做出最后判断——Todd Phillips是我们想要的答案。

事实上,“快速将注意力定位到相关实体”和“分析句子语意进行推断”是两种不同的思维过程。

在认知学里,著名的“双过程理论(dual process theory)”认为,人的认知分为两个系统,系统一(System 1)是基于直觉的、无知觉的思考系统,其运作依赖于经验和关联;而系统二(System 2)则是人类特有的逻辑推理能力,此系统利用工作记忆(working memory)中的知识进行慢速但是可靠的逻辑推理,系统二是显式的,需要意识控制的,是人类高级智能的体现。

论文详情

因此,本文提出一种新颖的迭代框架:算法使用两个系统来维护一张认知图谱(Cognitive Graph):

  • 系统一在文本中抽取与问题相关的实体名称并扩展节点和汇总语义向量,
  • 系统二利用图神经网络在认知图谱上进行推理计算。

正如之前提到的,人类的系统一是无知觉(unconscious),CogQA中的系统一也是流行的NLP黑盒模型,例如BERT。

在文章的实现中,系统一的输入分为三部分:

  1. 问题本身
  2. 从前面段落中找到的“线索(clues)”
  3. 关于某个实体x的维基百科文档

系统一的目标是抽取文档中的“下一跳实体名称(hop span)”和“答案候选(ans span)”。

这些抽取的到的实体和答案候选将作为节点添加到认知图谱中。此外,系统一还将计算当前实体 x 的语意向量,这将在系统二中用作关系推理的初始值。

模型架构图如下:

 

 

 源码解析(主要是model.py文件)分为七大模块

1. 导入相应的库代码,主要是bert模块,有些库model.py没有用到,这块就不做相应解释了。(utils是文章作者写的模块)

 1 from pytorch_pretrained_bert.modeling import (
 2     BertPreTrainedModel as PreTrainedBertModel,
 3     BertModel,
 4     BertLayerNorm,
 5     gelu,
 6     BertEncoder,
 7     BertPooler,
 8 )
 9 import torch
10 from torch import nn
11 import re
12 import pdb
13 from pytorch_pretrained_bert.tokenization import (
14     whitespace_tokenize,
15     BasicTokenizer,
16     BertTokenizer,
17 )
18 from utils import (
19     fuzzy_find,
20     find_start_end_after_tokenized,
21     find_start_end_before_tokenized,
22     bundle_part_to_batch,
23 )

2. MLP模块

该模块较为简单,就是简单的多层感知机,如果大于两层,会加入相应的dropout 和 LayerNorm,并采用了bert所特有的gelu激活函数。

 1 class MLP(nn.Module):
 2     def __init__(self, input_sizes, dropout_prob=0.2, bias=False):
 3         super(MLP, self).__init__()
 4         self.layers = nn.ModuleList()
 5         for i in range(1, len(input_sizes)):
 6             self.layers.append(nn.Linear(input_sizes[i-1], input_sizes[i], bias=bias))
 7         self.norm_layers = nn.ModuleList()
 8         if len(input_sizes) > 2:
 9             for i in range(1, len(input_sizes) - 1):
10                 self.norm_layers.append(nn.LayerNorm(input_sizes[i]))
11         self.drop_out = nn.Dropout(p=dropout_prob)
12 
13     def forward(self, x):
14         for i, layer in enumerate(self.layers):
15             x = layer(self.drop_out(x))
16             if i < len(self.layers) - 1:
17                 x = gelu(x)
18                 if len(self.norm_layers):
19                     x = self.norm_layers[i](x)
20         return x

3. GCN模块

这里采用最基础的GCN,没有使用任何GCN库,速度可能较慢,但是考虑到主要时间限制在bert模型,所以这里的时间效率下降可忽略。

 1 class GCN(nn.Module):
 2     def init_weights(self, module):
 3         if isinstance(module, (nn.Linear, nn.Embedding)):
 4             module.weight.data.normal_(mean=0.0, std=0.05)
 5 
 6     def __init__(self, input_size):
 7         super(GCN, self).__init__()
 8         self.diffusion = nn.Linear(input_size, input_size, bias=False)  # diffusion线性变换 
 9         self.retained = nn.Linear(input_size, input_size, bias=False)   # retaine线性变换
10         self.predict = MLP(input_sizes=(input_size, input_size, 1))
11         self.apply(self.init_weights)   #  参数矩阵赋予初始化权重(正态分布)
12 
13     def forward(self, A, x):
14         layer1_diffusion = A.t().mm(gelu(self.diffusion(x)))   # t() 转置
15         # A为邻接矩阵(n, n) *  (n, input_size)   ==> (n, input_size)
16         x = gelu(self.retained(x) + layer1_diffusion)   # (n, input_size)
17         layer2_diffusion = A.t().mm(gelu(self.diffusion(x)))   # (n, input_size)
18         x = gelu(self.retained(x) +layer2_diffusion)    # (n, input_size)
19         return self.predict(x).sqeeze(-1)   # (n, )

4. bert embedding模块 (具体见注释)

 1 class BertEmbeddingsPlus(nn.Module):
 2     """ 构建word embeddings, position embeddings, token_type embeddings.
 3     """
 4     def __init__(self, config, max_sentence_type=30):
 5         super(BertEmbeddingsPlus, self).__init__()
 6         self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
 7         self.position_embeddings = nn.Embedding(
 8             config.max_position_embeddings, config.hidden_size
 9         )   # shape (位置embedding的种类,隐层大小)
10         self.token_type_embeddings = nn.Embedding(
11             config.type_vocab_size, config.hidden_size
12         )   # (2, hidden_size)  A/B segment
13         self.sentence_type_embeddings = nn.Embedding(
14             max_sentence_type, config.hidden_size
15         )   # 句子类型embedding  (30, hidden_size)
16         self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)  # bert LN层
17         self.dropout = nn.Dropout(config.hidden_dropout_prob)
18 
19     def forward(self, input_ids, token_type_ids=None):
20         """
21         :param input_ids: (n, seq_length)   n 就是 batch_size
22         :param token_type_ids: (n, seq_length)
23         :return:
24         """
25         seq_length = input_ids.size(1)  # 文本序列长度
26         position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
27         # [5] => [0,1,2,3,4]
28         position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
29         # shape变化:(seq_length) => (1, seq_length)  => (n, seq_length)
30         if token_type_ids is None:
31             token_type_ids = torch.zeros_like(input_ids)
32 
33         word_embeddings = self.word_embeddings(input_ids)
34         position_embeddings = self.position_embeddings(position_ids)
35         token_type_embeddings = self.token_type_embeddings((token_type_ids > 0).long())
36         #  token_type_embeddings, 分为 A/B,segment bert输入模式
37         sentences_type_embeddings = self.sentence_type_embeddings(token_type_ids)
38         #  这才是对token_type进行embedding
39 
40         embeddings = (word_embeddings + position_embeddings
41                       + token_type_embeddings + sentences_type_embeddings)  
42         # 四个embedding相加,充分考虑各种信息
43         embeddings = self.LayerNorm(embeddings)
44         embeddings = self.dropout(embeddings)
45         return embeddings

5 bert模型编码模块

输入input_ids,输出bert的最后的编码结果,和指定哪一层的编码结果(下层偏语法,上层偏语义)
 1 class BertModelPlus(BertModel):
 2     def __init__(self, config):
 3         super(BertModelPlus, self).__init__()
 4         self.embeddings = BertEmbeddingsPlus(config)
 5         self.encoder = BertEncoder(config)   # bert 编码器
 6         self.pooler = BertPooler(config)     # bert 池化器
 7         self.apply(self.init_bert_weights)   # BertModel 的初始权重参数
 8 
 9     def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_hidden=-4):
10         if attention_mask is None:
11             attention_mask = torch.ones_like(input_ids)  # (n, seq_length), n is batch_size
12         if token_type_ids is None:
13             token_type_ids = torch.zeros_like(input_ids)
14 
15         extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
16         # (n, seq_length) => (n, 1, 1, seq_length)
17 
18         extended_attention_mask = extended_attention_mask.to(
19             dtype=next(self.parameters()).dtype
20         )   # fp16   转换数值类型
21         extended_attention_mask = (1.0 - extended_attention_mask) * (-10000.0)
22         # 将attention_mask为0的变为-10000, 1变为0, 方便softmax求注意力后,
23         # mask为0(也就是padding)的值完全消除(=0)
24         # 例如[1, 1, 1, 0, 0]  ==> (0, 0, 0, -10000, -10000)
25 
26         embedding_output = self.embeddings(input_ids, token_type_ids)
27         encoded_layers = self.encoder(
28             embedding_output, extended_attention_mask, output_all_encoded_layers=True)
29         # sequence_output = encoded_layers[-1]
30         # pooled_output = self.pooler(sequence_output)
31         encoded_layers, hidden_layers = (
32             encoded_layers[-1], encoded_layers[output_hidden]  # -4 倒数第四层
33         )
34         return encoded_layers, hidden_layers   # shape (batch_size, hidden_size), (batch_size, hidden_size)

6. 多跳阅读理解模块

也就是论文中提到的系统一,即“下一跳实体名称”和“答案候选”的抽取,是通过预测每个位置是否是span开始或者结束的概率来确定,与BERT原文中的做法相同;

其中几个值得注意的细节,比如之所以将“下一跳实体名称”和“答案候选”分开,是因为前者更多关注语意相关性而后者则需要匹配疑问词;

第0个位置的输出被用来产生一个阈值,判断段落内是否包含有意义的“下一跳实体名称”或者“答案候选”。

该模块就是常规的抽取式阅读理解,解码出span的代码部分较为复杂~

注:这里没有用到GCN,GCN是用来作为系统二进行推理的模块。

  1 class BertForMultiHopQuestionAnswering(PreTrainedBertModel):
  2     def __init__(self, config):
  3         super(BertForMultiHopQuestionAnswering, self).__init__()
  4         self.bert = BertModelPlus(config)
  5         self.qa_outputs = nn.Linear(config.hidden_size, 4)
  6         self.apply(self.init_bert_weights)   # PreTrainedBertModel 初始化权重
  7 
  8     def forward(self, input_ids,
  9                 token_type_ids=None,
 10                 attention_mask=None,
 11                 sep_positions=None,
 12                 hop_start_weights=None,
 13                 hop_end_weights=None,
 14                 ans_start_weights=None,
 15                 ans_end_weights=None,
 16                 B_starts=None,
 17                 allow_limit=(0, 0),
 18                 ):
 19         """
 20         从系统1抽取span  (分为两个系统,具体看原文)
 21         :param input_ids:  LongTensor
 22             (batch_size, max_len)
 23         :param token_type_ids:  LongTensor
 24             The A/B Segmentation in BERTs. (batch, maxlen)
 25         :param attention_mask:  LongTensor
 26             指示该位置是token还是padding      (batch_size, maxlen)
 27         :param sep_positions:    LongTensor
 28             [SEP]的具体位置 主要用来发现支持段落的句子  (batch_size, max_seps)
 29         :param hop_start_weights:  Tensor(默认为FloatTensor)
 30             hop开始位置的标注情况
 31         :param hop_end_weights:    Tensor
 32             hop结束位置的标注情况  (ground truth)
 33         :param ans_start_weights:   Tensor
 34             答案标注开始位置的可能性(概率)
 35         :param ans_end_weights:     Tensor
 36             答案标注结束位置的可能性(概率)
 37         :param B_starts:
 38             句子B的开始位置
 39         :param allow_limit:
 40             An Offset for negative threshold (负阈值的偏移量)
 41         :return:
 42         """
 43         batch_size = input_ids.size()[0]
 44         device = input_ids.get_device() if input_ids.is_cuda else torch.device('cpu')
 45         sequence_output, hidden_output = self.bert(input_ids, token_type_ids, attention_mask)
 46         # 上面两者的shape都为: (batch_size, max_len, hidden_size)
 47         semantics = hidden_output[:, 0]   # shape: (batch_size, hidden_size)
 48 
 49         if sep_positions is None:
 50             return semantics   # 仅仅语义信息
 51         else:
 52             max_sep = sep_positions.size()[-1]    # max_seps
 53         if max_sep == 0:
 54             empty = torch.zeros(batch_size, 0, dtype=torch.long, device=device)  # mistake
 55             return (
 56                 empty,
 57                 empty,
 58                 semantics,
 59                 empty,
 60             )  # Only semantics, used in eval, the same ``empty'' variable is a mistake in general cases but simple
 61 
 62         # 预测span
 63         logits = self.qa_outputs(sequence_output)
 64         hop_start_logits, hop_end_logits, ans_start_logits, ans_end_logits = logits.split(
 65             split_size=1, dim=-1   # 前面的1代表单个分块的形状大小
 66         )  # 每个的形状为 (batch_size, max_len, 1)
 67         hop_start_logits = hop_start_logits.squeeze(-1)
 68         hop_end_logits = hop_end_logits.squeeze(-1)
 69         ans_start_logits = ans_start_logits.squeeze(-1)
 70         ans_end_logits = ans_end_logits.squeeze(-1)  # Shape: [batch_size, max_len]
 71 
 72         if hop_start_weights is not None:   # train mode (因为提供了标签信息:hop_start_weights等)
 73             lgsf = nn.LogSoftmax(dim=1)
 74             # 如果句子中没有目标span,start_weights = end_weights = 0(tensor)
 75             # 以下四个求二元交叉熵loss
 76             hop_start_loss = -torch.sum(hop_start_weights * lgsf(hop_start_logits), dim=-1)
 77             hop_end_loss = -torch.sum(hop_end_weights * lgsf(hop_end_logits), dim=-1)
 78             ans_start_loss = -torch.sum(ans_start_weights * lgsf(ans_start_logits), dim=-1)
 79             ans_end_loss = -torch.sum(ans_end_weights * lgsf(ans_end_logits), dim=-1)
 80 
 81             hop_loss = torch.mean((hop_start_loss + hop_end_loss)) / 2
 82             ans_loss = torch.mean((ans_start_loss + ans_end_loss)) / 2
 83 
 84         else:
 85             K_hop, K_ans = 3, 1
 86             hop_preds = torch.zeros(batch_size, K_hop, 3, dtype=torch.long, device=device)
 87             # (batch_size, 3, 3)
 88             ans_preds = torch.zeros(batch_size, K_ans, 3, dtype=torch.long, device=device)
 89             # (batch_size, 1, 3)
 90 
 91             ans_start_gap = torch.zeros(batch_size, device=device)
 92             for u, (start_logits, end_logits, preds, K, allow) in enumerate(
 93                 (
 94                     (
 95                         hop_start_logits,  # (batch_size, max_len)
 96                         hop_end_logits,
 97                         hop_preds,   # (batch_size, 3, 3)
 98                         K_hop,       # 3
 99                         allow_limit[0],
100                     ),
101                     (
102                         ans_start_logits,
103                         ans_end_logits,
104                         ans_preds,   # (batch_size, 1, 3)
105                         K_ans,       # 1
106                         allow_limit[1],
107                     ),
108                 )
109             ):
110                 for i in range(batch_size):
111                     # 对于batch_size里的每个样本,即每个文本
112                     if sep_positions[i, 0] > 0:
113                         values, indices = start_logits[i, B_starts[i]:].topk(K)
114                         # B是文档,QA所对应的paragraph
115                         # 取出前K大的概率值以及对应的位置index
116                         for k, index in enumerate(indices):  # 3个 或 1个(answer)
117                             if values[k] <= start_logits[i, 0] - allow:  # not golden
118                                 # 小tip: start_logits[i, 0] 代表一个置信度或叫阈值
119                                 # 来判断段落内是否有有意义的“下一跳实体名称”或者“答案候选”。
120                                 if u == 1:  # For ans spans
121                                     ans_start_gap[i] = start_logits[i, 0] - values[k]
122                                 break
123                             start = index + B_starts[i]   # 输入文本中span所在的开始位置
124                             # find ending   找到span的结束位置
125                             for j, ending in enumerate(sep_positions[i]):
126                                 if ending > start or ending <= 0:
127                                     break  # 找到ending所对应的支撑句子sep位置
128                             if ending <= start:
129                                 break
130                             ending = min(ending, start + 10)
131                             end = torch.argmax(end_logits[i, start:ending]) + start
132                             # 得到end span在文本中的结束位置
133                             preds[i, k, 0] = start
134                             preds[i, k, 1] = end
135                             preds[i, k, 2] = j
136         return ((hop_loss, ans_loss, semantics)
137                 if hop_start_weights is not None
138                 else (hop_preds, ans_preds, semantics, ans_start_gap))

7. 认知图网络模块

 1 class CognitiveGCN(nn.Module):
 2     """
 3     在认知图谱上进行推理计算,使用GCN实现隐式推理计算——每一步迭代,前续节点将变换过的信息传递到下一跳节点,
 4     并更新目前的隐层表示。 在认知图谱扩展过程中,如果某被访问节点出现新的父节点(环状结构或汇集状结构),
 5     表明此点获得新的线索信息(clues),需要重新扩展计算。最终算法流程借助前沿点(frontier nodes)队列形式实现。
 6     """
 7     def __init__(self, hidden_size):
 8         super(CognitiveGCN, self).__init__()
 9         self.gcn = GCN(hidden_size)
10         self.both_net = MLP((hidden_size, hidden_size, 1))
11         self.select_net = MLP((hidden_size, hidden_size, 1))
12 
13     def forward(self, bundle, model, device):
14         batch = bundle_part_to_batch(bundle)
15         batch = tuple(t.to(device) for t in batch)
16         hop_loss, ans_loss, semantics = model(
17             *batch
18         )  # Shape of semantics: [num_para, hidden_size]
19         num_additional_nodes = len(bundle.additional_nodes)
20         if num_additional_nodes > 0:
21             max_length_additional = max([len(x) for x in bundle.additional_nodes])
22             # 取出最大长度——max_len
23             ids = torch.zeros(
24                 (num_additional_nodes, max_length_additional),
25                 dtype=torch.long,
26                 device=device,
27             )
28             segment_ids = torch.zeros(
29                 (num_additional_nodes, max_length_additional),
30                 dtype=torch.long,
31                 device=device,
32             )
33             input_mask = torch.zeros(
34                 (num_additional_nodes, max_length_additional),
35                 dtype=torch.long,
36                 device=device,
37             )
38             # 得到对应的ids, segment_ids, input_mask
39             for i in range(num_additional_nodes):
40                 length = len(bundle.additional_nodes[i])    # 对于邻接结点
41                 ids[i, :length] = torch.tensor(
42                     bundle.additional_nodes[i], dtype=torch.long
43                 )
44                 input_mask[i, :length] = 1    # mask为1 padding段相应变为0
45             additional_semantics = model(ids, segment_ids, input_mask)
46 
47             semantics = torch.cat((semantics, additional_semantics), dim=0)  # 二者相拼接
48 
49         assert semantics.size()[0] == bundle.adj.size()[0]  # 等于邻接矩阵的结点数
50 
51         if bundle.question_type == 0:  # Wh-
52             pred = self.gcn(bundle.adj.to(device), semantics)
53             ce = torch.nn.CrossEntropyLoss()
54             final_loss = ce(
55                 pred.unsqueeze(0),
56                 torch.tensor([bundle.answer_id], dtype=torch.long, device=device),
57             )
58         else:
59             x, y, ans = bundle.answer_id
60             ans = torch.tensor(ans, dtype=torch.float, device=device)
61             diff_sem = semantics[x] - semantics[y]
62             classifier = self.both_net if bundle.question_type == 1 else self.select_net
63             final_loss = 0.2 * torch.nn.functional.binary_cross_entropy_with_logits(
64                 classifier(diff_sem).squeeze(-1), ans.to(device)
65             )
66         return hop_loss, ans_loss, final_loss

具体详情待进一步补充。

 

 

 

标签
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!