在工业界DSSM(Deep Structured Semantic Models)已经演化成一种语义匹配框架,不仅用于文本的匹配,也用于推荐系统的User-Item的匹配,本文描述与实现DSSM在文本匹配上的应用,主要内容如下:
- DSSM原理
- 数据预处理
- 模型实现
- 模型训练
- 模型预测
DSSM原理
DSSM的主要结构如下:
主要分为表示层和匹配层,表示层可使用全连接、RNN、Transformer等等网络得到query和doc的向量,匹配层一般使用cosine相似度来计算query和1个正样本doc和N个负样本doc的相似度。这里就不讲解原始论文里DSSM的原理了,只讲其在文本上是如何使用的,有兴趣的朋友可以参考原论文《Learning Deep Structured Semantic Models for Web Search using Clickthrough Data》。
- 表示层
原始论文中主要针对英文,为了降维做了word hashing,而中文常用汉字只有1万左右。我们将query和doc的字直接传入embedding,然后接一层双向的GRU,假设每个字的embedding表示为,则GRU表示为:
然而,并不是每个字都是我们所需要的,所以在GRU后面再接一层Attention,其表示为:
即通过一个线性层对GRU的输出进行变换,然后通过softmax公式计算出每个字的重要程度,最后与GRU的输出求和得到句子的表示。
- 匹配层
得到query和doc的向量表示之后(一般为64或128维向量),计算他们之间的cosine相似度:
并通过softmax将query与正样本doc的相似度计算转换为后验概率,计算如下:
其中为softmax平滑因子,可设置为固定值(如20,50等)也可以用一个参数去学习。D+为query对应的正样本,D'为随机采样的query对应的N个负样本。
训练时通过极大似然估计最小化损失函数:
数据预处理
这里我们采用某新闻语料,训练样本格式为每行:标题\t正文,作为一条样本。
- 将句子转换成ID序列
def sent2id(sent, vocab, max_size=30):
sent = [vocab[c] for c in sent if c in vocab]
sent = sent[:max_size] + [0]*(max_size - len(sent))
return sent
其中vocab为词表,其格式为{"<UNK>":0, "<PAD>": 1, "的":2, ...}
- 转换为TFRecord格式
为了加快数据读取速度,将样本转换成TFRecord格式:
def create_int_feature(values):
return tf.train.Feature(int64_list=tf.train.Int64List(value=values))
def convert_tfrecord(in_file, out_file, vocab_path, query_size=50, doc_size=200):
vocab = json.load(codecs.open(vocab_path, "r", "utf-8"))
writer = tf.io.TFRecordWriter(out_file)
icount = 0
with codecs.open(in_file, "r", "utf-8") as fr:
for line in tqdm(fr):
icount += 1
line = line.strip().split("\t")
query = sent2id(line[0], vocab, query_size)
doc = sent2id(line[1], vocab, doc_size)
feed_dict = {"query_char": create_int_feature(query),
"doc_char": create_int_feature(doc),
"label": create_int_feature([1])}
example = tf.train.Example(features=tf.train.Features(feature=feed_dict))
serialized = example.SerializeToString()
writer.write(serialized)
print(icount)
writer.close()
模型实现
- Embedding层
query和doc都是文本,所以复用Embedding层的权重:
def word_embedding(inputs, reuse=None, vocab_size=10000, embedding_size=128, scope_name="char_embedding"):
with tf.variable_scope(scope_name, reuse=reuse):
embedding_matrix = tf.Variable(tf.truncated_normal((vocab_size, embedding_size)))
embedding = tf.nn.embedding_lookup(embedding_matrix, inputs, name=scope_name + "_layer")
embedding = tf.nn.tanh(embedding)
return embedding
- GRU层
这里采用双向的GRU:
def compute_seq_length(sequences):
used = tf.sign(tf.reduce_max(tf.abs(sequences), reduction_indices=2))
seq_len = tf.reduce_sum(used, reduction_indices=1)
return tf.cast(seq_len, tf.int32)
def rnn_encoder(inputs, reuse, scope_name):
with tf.variable_scope(scope_name, reuse=reuse):
GRU_cell_fw = tf.contrib.rnn.GRUCell(FLAGS.rnn_hidden_size)
GRU_cell_bw = tf.contrib.rnn.GRUCell(FLAGS.rnn_hidden_size)
((fw_outputs, bw_outputs), (_, _)) = tf.nn.bidirectional_dynamic_rnn(cell_fw=GRU_cell_fw,
cell_bw=GRU_cell_bw,
inputs=inputs,
sequence_length=compute_seq_length(inputs),
dtype=tf.float32)
outputs = tf.concat((fw_outputs, bw_outputs), axis=2)
return outputs
- Attention层
def attention_layer(inputs, reuse, scope_name, outname):
with tf.variable_scope(scope_name, reuse=reuse):
u_context = tf.Variable(tf.truncated_normal([FLAGS.rnn_hidden_size * 2]), name=scope_name+ '_u_context')
h = tf.contrib.layers.fully_connected(inputs, FLAGS.rnn_hidden_size * 2, activation_fn=tf.nn.tanh)
alpha = tf.nn.softmax(tf.reduce_sum(tf.multiply(h, u_context), axis=2, keepdims=True), axis=1)
attn_output = tf.reduce_sum(tf.multiply(inputs, alpha), axis=1, name=outname)
return attn_output
- 句子表示
将GRU层再传入Attention层即可得到句子的表示:
def sentence_embedding(inputs, reuse=None, max_sentence_length=50, scope_name="char_sent"):
with tf.variable_scope(scope_name, reuse=reuse):
embedding = tf.reshape(inputs, [-1, max_sentence_length, FLAGS.char_embedding_size])
word_encoder = rnn_encoder(embedding, reuse, scope_name=scope_name + "_encoder_layer")
sent_encoder = attention_layer(word_encoder, reuse=reuse, scope_name=scope_name+"_attention_layer", outname=scope_name+"_vec")
sent_encoder = tf.nn.tanh(sent_encoder)
return sent_encoder
- query的表示
def build_query_model(features, mode):
# 输入shape: [batch_size, sentence_size]
char_input = tf.reshape(features["query_char"], [-1, FLAGS.query_max_char_length])
char_embed = word_embedding(char_input, None, FLAGS.char_vocab_size, FLAGS.char_embedding_size, "char_embedding")
sent_encoder = sentence_embedding(char_embed,
None,
FLAGS.query_max_char_length,
"char_sent")
sent_encoder = tf.layers.dense(sent_encoder, units=FLAGS.last_hidden_size, activation=tf.nn.tanh, name="query_encoder")
sent_encoder = tf.nn.l2_normalize(sent_encoder)
return sent_encoder
- doc的表示
def build_doc_model(features, mode):
# 输入shape: [batch_size, sentence_size]
char_input = tf.reshape(features["doc_char"], [-1, FLAGS.doc_max_char_length])
char_embed = word_embedding(char_input, True, FLAGS.char_vocab_size, FLAGS.char_embedding_size, "char_embedding")
sent_encoder = sentence_embedding(char_embed,
True if mode==tf.estimator.ModeKeys.TRAIN else tf.AUTO_REUSE,
FLAGS.doc_max_char_length,
"char_sent")
sent_encoder = tf.layers.dense(sent_encoder, units=FLAGS.last_hidden_size, activation=tf.nn.tanh, name="doc_encoder")
sent_encoder = tf.nn.l2_normalize(sent_encoder)
return sent_encoder
模型训练
- 计算cosine相似度
with tf.name_scope("fd-rotate"):
tmp = tf.tile(doc_encoder, [1, 1])
doc_encoder_fd = doc_encoder
for i in range(FLAGS.NEG):
rand = random.randint(1, FLAGS.batch_size + i) % FLAGS.batch_size
s1 = tf.slice(tmp, [rand, 0], [FLAGS.batch_size - rand, -1])
s2 = tf.slice(tmp, [0, 0], [rand, -1])
doc_encoder_fd = tf.concat([doc_encoder_fd, s1, s2], axis=0)
query_norm = tf.tile(tf.sqrt(tf.reduce_sum(tf.square(query_encoder), axis=1, keepdims=True)), [FLAGS.NEG + 1, 1])
doc_norm = tf.sqrt(tf.reduce_sum(tf.square(doc_encoder_fd), axis=1, keepdims=True))
query_encoder_fd = tf.tile(query_encoder, [FLAGS.NEG + 1, 1])
prod = tf.reduce_sum(tf.multiply(query_encoder_fd, doc_encoder_fd), axis=1, keepdims=True)
norm_prod = tf.multiply(query_norm, doc_norm)
cos_sim_raw = tf.truediv(prod, norm_prod)
cos_sim = tf.transpose(tf.reshape(tf.transpose(cos_sim_raw), [FLAGS.NEG + 1, -1])) * 20
计算cosine相似度时难点在于需要在每个batch里随机采样作为负样本并分别与query计算cosine。
- 损失函数
with tf.name_scope("loss"):
prob = tf.nn.softmax(cos_sim)
hit_prob = tf.slice(prob, [0, 0], [-1, 1])
loss = -tf.reduce_mean(tf.log(hit_prob))
correct_prediction = tf.cast(tf.equal(tf.argmax(prob, 1), 0), tf.float32)
accuracy = tf.reduce_mean(correct_prediction)
- 传入estimator训练
classifier = tf.estimator.Estimator(model_fn=model.model_fn,
config=tf.estimator.RunConfig(model_dir=FLAGS.model_dir,
save_checkpoints_steps=FLAGS.save_checkpoints_steps,
keep_checkpoint_max=3),
params={}
)
def train_eval_model():
train_spec = tf.estimator.TrainSpec(input_fn=lambda: input_utils.train_input_fn(FLAGS.train_data, FLAGS.batch_size),
max_steps=FLAGS.train_steps)
eval_spec = tf.estimator.EvalSpec(input_fn=lambda: input_utils.eval_input_fn(FLAGS.eval_data, FLAGS.batch_size),
start_delay_secs=60,
throttle_secs = 30,
steps=1000)
tf.estimator.train_and_evaluate(classifier, train_spec, eval_spec)
模型预测
- 模型导出
模型训练完我们会将模型导出为pb格式,方便预测使用,导出模型代码如下:
def export_model(feed_dict, export_dir):
feature_map = dict()
for key, value in feed_dict.items():
feature_map[key] = tf.placeholder(dtype=tf.int64, shape=[None, value], name=key)
serving_input_recevier_fn = tf.estimator.export.build_raw_serving_input_receiver_fn(feature_map)
export_dir = classifier.export_saved_model(export_dir, serving_input_recevier_fn)
- 模型预测
加载pb格式的模型:
model = predictor.from_saved_model(model_path)
抽取向量:
def get_vector(sentence, model):
feed_dict = {"query_char": [sent2id(sentence)]}
vector = model(feed_dict)
return vector["query_vector"][0]
是的,预测就是这么简单!!
我们看看效果,下面的例子是输入一个句子然后在10000个句子中找出最相似的10个句子:
输入:祝考研的女士们先生们都顺利考进自己理想的学校
输出:
0.890815 祝考研的女士们先生们都顺利考进自己理想的学校!实在考不上就滚tm的,当代...
0.758741 硕士研究生招生考试22日开考
0.701588 加油高考!祝你们顺利考上心仪的大学!
0.660756 中考,你准备好了吗?
0.654576 这些考研复试面试小技巧收好,导师的心就抓住了!
0.63505 高考生作弊被抓飞踹监考老师:你知道我爸是谁?
0.626651 高考倒计时30天,祝所有今年参加高考的小伙伴们心想事成,高考必胜
0.590912 各位同学请注意,第一季期末考试现在开始~请认真阅读仔细答题
0.585147 航班延误艺考生妈妈痛哭 浙传:可提供证明安排考试
0.575564 当女儿带男同学回家写作业的时候,爸爸都在想什么
是不是看起来还像那么回事?
完整代码已开源,https://github.com/cdj0311/dssm 阅读到的朋友记得点个star。
来源:oschina
链接:https://my.oschina.net/u/4383327/blog/4450015