Getting alignment/attention during translation in OpenNMT-py

You can get the attention matrices. Note that attention is not the same as alignment, which is a term from statistical (not neural) machine translation; attention weights are soft, but they can be turned into hard alignments, as sketched below.
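For illustration only, here is a minimal sketch of deriving hard alignments from a soft attention matrix by taking, for each target position, the most-attended source position. This is not part of the OpenNMT-py API; the function name hard_alignments and the (tgt_len, src_len) shape are assumptions.

import torch

def hard_alignments(attn):
    # attn: FloatTensor of shape (tgt_len, src_len), rows summing to ~1.
    # For every target position, pick the most-attended source position.
    return attn.argmax(dim=1).tolist()

# Toy example: 2 target words attending over 3 source words.
attn = torch.tensor([[0.7, 0.2, 0.1],
                     [0.1, 0.3, 0.6]])
print(hard_alignments(attn))  # -> [0, 2]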

There is a thread on GitHub discussing it. Here is a snippet from the discussion. When you get the translations from the model, the attention matrices are in the attn field.

import onmt
import onmt.io
import onmt.translate
import onmt.ModelConstructor
from collections import namedtuple

# Load the model. A minimal options object stands in for the parsed
# command-line options.
Opt = namedtuple('Opt', ['model', 'data_type', 'reuse_copy_attn', 'gpu'])

opt = Opt("PATH_TO_SAVED_MODEL", "text", False, 0)
fields, model, model_opt = onmt.ModelConstructor.load_test_model(
    opt, {"reuse_copy_attn": False})

# Test data: build a dataset from the raw text file and iterate over it
# one sentence per batch.
data = onmt.io.build_dataset(
    fields, "text", "PATH_TO_DATA", None, use_filter_pred=False)
data_iter = onmt.io.OrderedIterator(
    dataset=data, device=0,
    batch_size=1, train=False, sort=False,
    sort_within_batch=True, shuffle=False)

# Translator that runs beam search over each batch.
translator = onmt.translate.Translator(
    model, fields, beam_size=5, n_best=1,
    global_scorer=None, cuda=True)

# Builder that turns raw decoder output into readable Translation objects.
builder = onmt.translate.TranslationBuilder(
    data, translator.fields, 1, False, None)

# Translate the first batch; torchtext iterators are iterated, not
# advanced with next() directly, hence the iter() call.
batch = next(iter(data_iter))
batch_data = translator.translate_batch(batch, data)
translations = builder.from_batch(batch_data)
translations[0].attn # <--- here are the attentions
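As a hedged follow-up, you can inspect or plot the attention of the 1-best hypothesis as shown below. The per-hypothesis (tgt_len, src_len) shape and the .cpu() call are assumptions that depend on your OpenNMT-py version and on whether the model ran on a GPU.

attn = translations[0].attn[0]  # attention of the best hypothesis
print(attn.size())              # e.g. torch.Size([tgt_len, src_len])

# Optional heatmap of the attention weights.
import matplotlib.pyplot as plt
plt.imshow(attn.cpu().numpy(), cmap="gray")
plt.xlabel("source positions")
plt.ylabel("target positions")
plt.show()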