论文地址为:Cognitive Graph for Multi-Hop Reading Comprehension at Scale
github地址:CogQA
背景
假设你手边有一个维基百科的搜索引擎,可以用来获取实体对应的文本段落,那么如何来回答下面这个复杂的问题呢?
“谁是某部在2003年取景于洛杉矶Quality cafe的电影的导演?”
很自然地,我们将会从例如Quality cafe这样的“相关实体”入手,通过维基百科查询相关介绍,并在其中讲到好莱坞电影的时候迅速定位到“Old School”“Gone in 60 Seconds”这两部电影,通过继续查询两部电影相关的介绍,我们找到他们的导演。最后一步是判断到底是哪位导演,这需要我们自己分析句子的语意和限定词,在了解到电影是2003年之后,我们可以做出最后判断——Todd Phillips是我们想要的答案。
事实上,“快速将注意力定位到相关实体”和“分析句子语意进行推断”是两种不同的思维过程。
在认知学里,著名的“双过程理论(dual process theory)”认为,人的认知分为两个系统,系统一(System 1)是基于直觉的、无知觉的思考系统,其运作依赖于经验和关联;而系统二(System 2)则是人类特有的逻辑推理能力,此系统利用工作记忆(working memory)中的知识进行慢速但是可靠的逻辑推理,系统二是显式的,需要意识控制的,是人类高级智能的体现。
论文详情
因此,本文提出一种新颖的迭代框架:算法使用两个系统来维护一张认知图谱(Cognitive Graph):
- 系统一在文本中抽取与问题相关的实体名称并扩展节点和汇总语义向量,
- 系统二利用图神经网络在认知图谱上进行推理计算。
正如之前提到的,人类的系统一是无知觉(unconscious),CogQA中的系统一也是流行的NLP黑盒模型,例如BERT。
在文章的实现中,系统一的输入分为三部分:
- 问题本身
- 从前面段落中找到的“线索(clues)”
- 关于某个实体x的维基百科文档
系统一的目标是抽取文档中的“下一跳实体名称(hop span)”和“答案候选(ans span)”。
这些抽取的到的实体和答案候选将作为节点添加到认知图谱中。此外,系统一还将计算当前实体 x 的语意向量,这将在系统二中用作关系推理的初始值。
模型架构图如下:
源码解析(主要是model.py文件)分为七大模块
1. 导入相应的库代码,主要是bert模块,有些库model.py没有用到,这块就不做相应解释了。(utils是文章作者写的模块)
1 from pytorch_pretrained_bert.modeling import ( 2 BertPreTrainedModel as PreTrainedBertModel, 3 BertModel, 4 BertLayerNorm, 5 gelu, 6 BertEncoder, 7 BertPooler, 8 ) 9 import torch 10 from torch import nn 11 import re 12 import pdb 13 from pytorch_pretrained_bert.tokenization import ( 14 whitespace_tokenize, 15 BasicTokenizer, 16 BertTokenizer, 17 ) 18 from utils import ( 19 fuzzy_find, 20 find_start_end_after_tokenized, 21 find_start_end_before_tokenized, 22 bundle_part_to_batch, 23 )
2. MLP模块
该模块较为简单,就是简单的多层感知机,如果大于两层,会加入相应的dropout 和 LayerNorm,并采用了bert所特有的gelu激活函数。
1 class MLP(nn.Module): 2 def __init__(self, input_sizes, dropout_prob=0.2, bias=False): 3 super(MLP, self).__init__() 4 self.layers = nn.ModuleList() 5 for i in range(1, len(input_sizes)): 6 self.layers.append(nn.Linear(input_sizes[i-1], input_sizes[i], bias=bias)) 7 self.norm_layers = nn.ModuleList() 8 if len(input_sizes) > 2: 9 for i in range(1, len(input_sizes) - 1): 10 self.norm_layers.append(nn.LayerNorm(input_sizes[i])) 11 self.drop_out = nn.Dropout(p=dropout_prob) 12 13 def forward(self, x): 14 for i, layer in enumerate(self.layers): 15 x = layer(self.drop_out(x)) 16 if i < len(self.layers) - 1: 17 x = gelu(x) 18 if len(self.norm_layers): 19 x = self.norm_layers[i](x) 20 return x
3. GCN模块
这里采用最基础的GCN,没有使用任何GCN库,速度可能较慢,但是考虑到主要时间限制在bert模型,所以这里的时间效率下降可忽略。
1 class GCN(nn.Module): 2 def init_weights(self, module): 3 if isinstance(module, (nn.Linear, nn.Embedding)): 4 module.weight.data.normal_(mean=0.0, std=0.05) 5 6 def __init__(self, input_size): 7 super(GCN, self).__init__() 8 self.diffusion = nn.Linear(input_size, input_size, bias=False) # diffusion线性变换 9 self.retained = nn.Linear(input_size, input_size, bias=False) # retaine线性变换 10 self.predict = MLP(input_sizes=(input_size, input_size, 1)) 11 self.apply(self.init_weights) # 参数矩阵赋予初始化权重(正态分布) 12 13 def forward(self, A, x): 14 layer1_diffusion = A.t().mm(gelu(self.diffusion(x))) # t() 转置 15 # A为邻接矩阵(n, n) * (n, input_size) ==> (n, input_size) 16 x = gelu(self.retained(x) + layer1_diffusion) # (n, input_size) 17 layer2_diffusion = A.t().mm(gelu(self.diffusion(x))) # (n, input_size) 18 x = gelu(self.retained(x) +layer2_diffusion) # (n, input_size) 19 return self.predict(x).sqeeze(-1) # (n, )
4. bert embedding模块 (具体见注释)
1 class BertEmbeddingsPlus(nn.Module): 2 """ 构建word embeddings, position embeddings, token_type embeddings. 3 """ 4 def __init__(self, config, max_sentence_type=30): 5 super(BertEmbeddingsPlus, self).__init__() 6 self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size) 7 self.position_embeddings = nn.Embedding( 8 config.max_position_embeddings, config.hidden_size 9 ) # shape (位置embedding的种类,隐层大小) 10 self.token_type_embeddings = nn.Embedding( 11 config.type_vocab_size, config.hidden_size 12 ) # (2, hidden_size) A/B segment 13 self.sentence_type_embeddings = nn.Embedding( 14 max_sentence_type, config.hidden_size 15 ) # 句子类型embedding (30, hidden_size) 16 self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) # bert LN层 17 self.dropout = nn.Dropout(config.hidden_dropout_prob) 18 19 def forward(self, input_ids, token_type_ids=None): 20 """ 21 :param input_ids: (n, seq_length) n 就是 batch_size 22 :param token_type_ids: (n, seq_length) 23 :return: 24 """ 25 seq_length = input_ids.size(1) # 文本序列长度 26 position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) 27 # [5] => [0,1,2,3,4] 28 position_ids = position_ids.unsqueeze(0).expand_as(input_ids) 29 # shape变化:(seq_length) => (1, seq_length) => (n, seq_length) 30 if token_type_ids is None: 31 token_type_ids = torch.zeros_like(input_ids) 32 33 word_embeddings = self.word_embeddings(input_ids) 34 position_embeddings = self.position_embeddings(position_ids) 35 token_type_embeddings = self.token_type_embeddings((token_type_ids > 0).long()) 36 # token_type_embeddings, 分为 A/B,segment bert输入模式 37 sentences_type_embeddings = self.sentence_type_embeddings(token_type_ids) 38 # 这才是对token_type进行embedding 39 40 embeddings = (word_embeddings + position_embeddings 41 + token_type_embeddings + sentences_type_embeddings) 42 # 四个embedding相加,充分考虑各种信息 43 embeddings = self.LayerNorm(embeddings) 44 embeddings = self.dropout(embeddings) 45 return embeddings
5 bert模型编码模块
输入input_ids,输出bert的最后的编码结果,和指定哪一层的编码结果(下层偏语法,上层偏语义)
1 class BertModelPlus(BertModel): 2 def __init__(self, config): 3 super(BertModelPlus, self).__init__() 4 self.embeddings = BertEmbeddingsPlus(config) 5 self.encoder = BertEncoder(config) # bert 编码器 6 self.pooler = BertPooler(config) # bert 池化器 7 self.apply(self.init_bert_weights) # BertModel 的初始权重参数 8 9 def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_hidden=-4): 10 if attention_mask is None: 11 attention_mask = torch.ones_like(input_ids) # (n, seq_length), n is batch_size 12 if token_type_ids is None: 13 token_type_ids = torch.zeros_like(input_ids) 14 15 extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) 16 # (n, seq_length) => (n, 1, 1, seq_length) 17 18 extended_attention_mask = extended_attention_mask.to( 19 dtype=next(self.parameters()).dtype 20 ) # fp16 转换数值类型 21 extended_attention_mask = (1.0 - extended_attention_mask) * (-10000.0) 22 # 将attention_mask为0的变为-10000, 1变为0, 方便softmax求注意力后, 23 # mask为0(也就是padding)的值完全消除(=0) 24 # 例如[1, 1, 1, 0, 0] ==> (0, 0, 0, -10000, -10000) 25 26 embedding_output = self.embeddings(input_ids, token_type_ids) 27 encoded_layers = self.encoder( 28 embedding_output, extended_attention_mask, output_all_encoded_layers=True) 29 # sequence_output = encoded_layers[-1] 30 # pooled_output = self.pooler(sequence_output) 31 encoded_layers, hidden_layers = ( 32 encoded_layers[-1], encoded_layers[output_hidden] # -4 倒数第四层 33 ) 34 return encoded_layers, hidden_layers # shape (batch_size, hidden_size), (batch_size, hidden_size)
6. 多跳阅读理解模块
也就是论文中提到的系统一,即“下一跳实体名称”和“答案候选”的抽取,是通过预测每个位置是否是span开始或者结束的概率来确定,与BERT原文中的做法相同;
其中几个值得注意的细节,比如之所以将“下一跳实体名称”和“答案候选”分开,是因为前者更多关注语意相关性而后者则需要匹配疑问词;
第0个位置的输出被用来产生一个阈值,判断段落内是否包含有意义的“下一跳实体名称”或者“答案候选”。
该模块就是常规的抽取式阅读理解,解码出span的代码部分较为复杂~
注:这里没有用到GCN,GCN是用来作为系统二进行推理的模块。
1 class BertForMultiHopQuestionAnswering(PreTrainedBertModel): 2 def __init__(self, config): 3 super(BertForMultiHopQuestionAnswering, self).__init__() 4 self.bert = BertModelPlus(config) 5 self.qa_outputs = nn.Linear(config.hidden_size, 4) 6 self.apply(self.init_bert_weights) # PreTrainedBertModel 初始化权重 7 8 def forward(self, input_ids, 9 token_type_ids=None, 10 attention_mask=None, 11 sep_positions=None, 12 hop_start_weights=None, 13 hop_end_weights=None, 14 ans_start_weights=None, 15 ans_end_weights=None, 16 B_starts=None, 17 allow_limit=(0, 0), 18 ): 19 """ 20 从系统1抽取span (分为两个系统,具体看原文) 21 :param input_ids: LongTensor 22 (batch_size, max_len) 23 :param token_type_ids: LongTensor 24 The A/B Segmentation in BERTs. (batch, maxlen) 25 :param attention_mask: LongTensor 26 指示该位置是token还是padding (batch_size, maxlen) 27 :param sep_positions: LongTensor 28 [SEP]的具体位置 主要用来发现支持段落的句子 (batch_size, max_seps) 29 :param hop_start_weights: Tensor(默认为FloatTensor) 30 hop开始位置的标注情况 31 :param hop_end_weights: Tensor 32 hop结束位置的标注情况 (ground truth) 33 :param ans_start_weights: Tensor 34 答案标注开始位置的可能性(概率) 35 :param ans_end_weights: Tensor 36 答案标注结束位置的可能性(概率) 37 :param B_starts: 38 句子B的开始位置 39 :param allow_limit: 40 An Offset for negative threshold (负阈值的偏移量) 41 :return: 42 """ 43 batch_size = input_ids.size()[0] 44 device = input_ids.get_device() if input_ids.is_cuda else torch.device('cpu') 45 sequence_output, hidden_output = self.bert(input_ids, token_type_ids, attention_mask) 46 # 上面两者的shape都为: (batch_size, max_len, hidden_size) 47 semantics = hidden_output[:, 0] # shape: (batch_size, hidden_size) 48 49 if sep_positions is None: 50 return semantics # 仅仅语义信息 51 else: 52 max_sep = sep_positions.size()[-1] # max_seps 53 if max_sep == 0: 54 empty = torch.zeros(batch_size, 0, dtype=torch.long, device=device) # mistake 55 return ( 56 empty, 57 empty, 58 semantics, 59 empty, 60 ) # Only semantics, used in eval, the same ``empty'' variable is a mistake in general cases but simple 61 62 # 预测span 63 logits = self.qa_outputs(sequence_output) 64 hop_start_logits, hop_end_logits, ans_start_logits, ans_end_logits = logits.split( 65 split_size=1, dim=-1 # 前面的1代表单个分块的形状大小 66 ) # 每个的形状为 (batch_size, max_len, 1) 67 hop_start_logits = hop_start_logits.squeeze(-1) 68 hop_end_logits = hop_end_logits.squeeze(-1) 69 ans_start_logits = ans_start_logits.squeeze(-1) 70 ans_end_logits = ans_end_logits.squeeze(-1) # Shape: [batch_size, max_len] 71 72 if hop_start_weights is not None: # train mode (因为提供了标签信息:hop_start_weights等) 73 lgsf = nn.LogSoftmax(dim=1) 74 # 如果句子中没有目标span,start_weights = end_weights = 0(tensor) 75 # 以下四个求二元交叉熵loss 76 hop_start_loss = -torch.sum(hop_start_weights * lgsf(hop_start_logits), dim=-1) 77 hop_end_loss = -torch.sum(hop_end_weights * lgsf(hop_end_logits), dim=-1) 78 ans_start_loss = -torch.sum(ans_start_weights * lgsf(ans_start_logits), dim=-1) 79 ans_end_loss = -torch.sum(ans_end_weights * lgsf(ans_end_logits), dim=-1) 80 81 hop_loss = torch.mean((hop_start_loss + hop_end_loss)) / 2 82 ans_loss = torch.mean((ans_start_loss + ans_end_loss)) / 2 83 84 else: 85 K_hop, K_ans = 3, 1 86 hop_preds = torch.zeros(batch_size, K_hop, 3, dtype=torch.long, device=device) 87 # (batch_size, 3, 3) 88 ans_preds = torch.zeros(batch_size, K_ans, 3, dtype=torch.long, device=device) 89 # (batch_size, 1, 3) 90 91 ans_start_gap = torch.zeros(batch_size, device=device) 92 for u, (start_logits, end_logits, preds, K, allow) in enumerate( 93 ( 94 ( 95 hop_start_logits, # (batch_size, max_len) 96 hop_end_logits, 97 hop_preds, # (batch_size, 3, 3) 98 K_hop, # 3 99 allow_limit[0], 100 ), 101 ( 102 ans_start_logits, 103 ans_end_logits, 104 ans_preds, # (batch_size, 1, 3) 105 K_ans, # 1 106 allow_limit[1], 107 ), 108 ) 109 ): 110 for i in range(batch_size): 111 # 对于batch_size里的每个样本,即每个文本 112 if sep_positions[i, 0] > 0: 113 values, indices = start_logits[i, B_starts[i]:].topk(K) 114 # B是文档,QA所对应的paragraph 115 # 取出前K大的概率值以及对应的位置index 116 for k, index in enumerate(indices): # 3个 或 1个(answer) 117 if values[k] <= start_logits[i, 0] - allow: # not golden 118 # 小tip: start_logits[i, 0] 代表一个置信度或叫阈值 119 # 来判断段落内是否有有意义的“下一跳实体名称”或者“答案候选”。 120 if u == 1: # For ans spans 121 ans_start_gap[i] = start_logits[i, 0] - values[k] 122 break 123 start = index + B_starts[i] # 输入文本中span所在的开始位置 124 # find ending 找到span的结束位置 125 for j, ending in enumerate(sep_positions[i]): 126 if ending > start or ending <= 0: 127 break # 找到ending所对应的支撑句子sep位置 128 if ending <= start: 129 break 130 ending = min(ending, start + 10) 131 end = torch.argmax(end_logits[i, start:ending]) + start 132 # 得到end span在文本中的结束位置 133 preds[i, k, 0] = start 134 preds[i, k, 1] = end 135 preds[i, k, 2] = j 136 return ((hop_loss, ans_loss, semantics) 137 if hop_start_weights is not None 138 else (hop_preds, ans_preds, semantics, ans_start_gap))
7. 认知图网络模块
1 class CognitiveGCN(nn.Module): 2 """ 3 在认知图谱上进行推理计算,使用GCN实现隐式推理计算——每一步迭代,前续节点将变换过的信息传递到下一跳节点, 4 并更新目前的隐层表示。 在认知图谱扩展过程中,如果某被访问节点出现新的父节点(环状结构或汇集状结构), 5 表明此点获得新的线索信息(clues),需要重新扩展计算。最终算法流程借助前沿点(frontier nodes)队列形式实现。 6 """ 7 def __init__(self, hidden_size): 8 super(CognitiveGCN, self).__init__() 9 self.gcn = GCN(hidden_size) 10 self.both_net = MLP((hidden_size, hidden_size, 1)) 11 self.select_net = MLP((hidden_size, hidden_size, 1)) 12 13 def forward(self, bundle, model, device): 14 batch = bundle_part_to_batch(bundle) 15 batch = tuple(t.to(device) for t in batch) 16 hop_loss, ans_loss, semantics = model( 17 *batch 18 ) # Shape of semantics: [num_para, hidden_size] 19 num_additional_nodes = len(bundle.additional_nodes) 20 if num_additional_nodes > 0: 21 max_length_additional = max([len(x) for x in bundle.additional_nodes]) 22 # 取出最大长度——max_len 23 ids = torch.zeros( 24 (num_additional_nodes, max_length_additional), 25 dtype=torch.long, 26 device=device, 27 ) 28 segment_ids = torch.zeros( 29 (num_additional_nodes, max_length_additional), 30 dtype=torch.long, 31 device=device, 32 ) 33 input_mask = torch.zeros( 34 (num_additional_nodes, max_length_additional), 35 dtype=torch.long, 36 device=device, 37 ) 38 # 得到对应的ids, segment_ids, input_mask 39 for i in range(num_additional_nodes): 40 length = len(bundle.additional_nodes[i]) # 对于邻接结点 41 ids[i, :length] = torch.tensor( 42 bundle.additional_nodes[i], dtype=torch.long 43 ) 44 input_mask[i, :length] = 1 # mask为1 padding段相应变为0 45 additional_semantics = model(ids, segment_ids, input_mask) 46 47 semantics = torch.cat((semantics, additional_semantics), dim=0) # 二者相拼接 48 49 assert semantics.size()[0] == bundle.adj.size()[0] # 等于邻接矩阵的结点数 50 51 if bundle.question_type == 0: # Wh- 52 pred = self.gcn(bundle.adj.to(device), semantics) 53 ce = torch.nn.CrossEntropyLoss() 54 final_loss = ce( 55 pred.unsqueeze(0), 56 torch.tensor([bundle.answer_id], dtype=torch.long, device=device), 57 ) 58 else: 59 x, y, ans = bundle.answer_id 60 ans = torch.tensor(ans, dtype=torch.float, device=device) 61 diff_sem = semantics[x] - semantics[y] 62 classifier = self.both_net if bundle.question_type == 1 else self.select_net 63 final_loss = 0.2 * torch.nn.functional.binary_cross_entropy_with_logits( 64 classifier(diff_sem).squeeze(-1), ans.to(device) 65 ) 66 return hop_loss, ans_loss, final_loss
具体详情待进一步补充。