ValueError: Expected target size (64, 31), got torch.Size([64, 63])

折月煮酒 提交于 2020-01-27 18:49:52

Pred-Esti介绍

Preictor-Estimator是一个两阶段的神经质量评估模型,它包括两个神经模型:

  • a predictor:词预测器,使用额外的大规模平行语料进行训练
  • an estimator:质量评估器,使用质量标注了的平行语料(QE data)训练

问题

在训练estimator模型时出现了这样的问题:
在这里插入图片描述
这是因为pred_tagstags的维度不同的。
在这里插入图片描述
如图所示:使用的wmt19数据中给出的标注的tags文件中包括MT tagsGap tags,而预测的tags中只有MT tags

pred_tags = [1, 0, 0, 0, 0, 0, 1, 0, 0]		#	OK:0;BAD:1
tags = [BAD, BAD, OK, BAD, BAD, OK, OK, OK, OK, OK, OK, OK, OK, BAD, OK, BAD, OK, OK, OK]

解决方法

# path_tags:MT tags + Gap tags
# path_target_tags:生成的MT tags
path_tags = "openkiwi/data/WMT19/wordsent_level/dev.tags"
path_target_tags = "openkiwi/data/WMT19/wordsent_level/dev.target_tags"

def target_tags(path1, path2):
    with open(path1, "r") as file:
        for line in file:
            array = line.strip().split(" ")[1::2]
            string = ' '.join(array)
            with open(path2, "a") as f:
                f.write(string + '\n')

target_tags(path_tags, path_target_tags)
pred_tags = [1, 0, 0, 0, 0, 0, 1, 0, 0]		#	OK:0;BAD:1
target_tags = [BAD, BAD, OK, OK, OK, OK, BAD, BAD, OK]
标签
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!