Deep Attention Matching Network (DAM)

Submitted by 柔情痞子 on 2020-04-27 15:35:30


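The Deep Attention Matching Network (DAM) is a response-matching model for multi-turn, open-domain dialogue. Given the multi-turn dialogue history and a set of candidate responses, it ranks the candidates to select the most suitable response. The network architecture is shown below: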

[Figure: DAM network architecture]

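To make the matching step concrete, here is a minimal pure-Python sketch of the scaled dot-product attention that DAM-style models are built on, plus the word-by-word matching matrix between one utterance and a response. It is a simplified illustration, not the PaddlePaddle implementation: it omits the residual connection, layer normalization and feed-forward sublayer of DAM's attentive module, uses made-up toy vectors, and all function names are hypothetical. In DAM, such attention layers are stacked several times (presumably the stack_num setting in the configuration output below), and the resulting matching matrices are aggregated with 3D convolution (presumably the channel1_num/channel2_num settings).

```python
import math


def softmax(row):
    """Softmax over one row of scores; subtracting the max keeps exp() stable."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def matmul(a, b):
    """Plain list-of-lists matrix product: (n x k) @ (k x m) -> (n x m)."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(a):
    return [list(col) for col in zip(*a)]


def scaled_dot_attention(query, key, value):
    """softmax(Q K^T / sqrt(d)) V, where each row is one word vector of size d."""
    d = len(query[0])
    scores = matmul(query, transpose(key))
    weights = [softmax([s / math.sqrt(d) for s in row]) for row in scores]
    return matmul(weights, value)


def match_matrix(utterance, response):
    """Word-by-word dot-product similarity between an utterance and a response."""
    return matmul(utterance, transpose(response))


# Toy 2-dimensional word vectors (made-up values for illustration only).
utterance = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # 3 words
response = [[1.0, 0.0], [0.5, 0.5]]               # 2 words

self_repr = scaled_dot_attention(utterance, utterance, utterance)  # 3 x 2
cross_repr = scaled_dot_attention(utterance, response, response)   # 3 x 2
print(match_matrix(utterance, response))  # [[1.0, 0.5], [0.0, 0.5], [1.0, 1.0]]
```

Calling scaled_dot_attention(u, u, u) is self-attention within one utterance, while scaled_dot_attention(u, r, r) lets each utterance word attend to the response words, which is the cross-attention idea.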
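The data used for training, prediction, and evaluation looks like the examples below. Each line has three columns separated by tabs ('\t'): the first column is the context, given as space-separated token ids; the second is the candidate response, also as space-separated token ids; the third is the label.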

286 642 865 36    87 25 693       0
17 54 975         512 775 54 6    1
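A minimal sketch of how one such line could be turned into the fixed-size arrays that appear in the inference log further down (a 9 x 50 turns array and a zero-padded response). Several assumptions are involved and labeled: the context is split into turns at the _EOS_ id passed on the command line (28270 for the Ubuntu data), 0 is the padding id, short contexts are filled with leading all-zero turns, and the limits come from the max_turn_num: 9 and max_turn_len: 50 defaults printed in the configuration. The function names are hypothetical; this is not the project's own data reader.

```python
def parse_line(line):
    """Split one tab-separated line into (context_ids, response_ids, label)."""
    context, response, label = line.rstrip("\n").split("\t")
    return ([int(t) for t in context.split()],
            [int(t) for t in response.split()],
            int(label))


def split_turns(context_ids, eos_id):
    """Cut the flat context into turns at every eos_id (assumed turn separator)."""
    turns = [[]]
    for tok in context_ids:
        if tok == eos_id:
            turns.append([])
        else:
            turns[-1].append(tok)
    if len(turns) > 1 and not turns[-1]:
        turns.pop()  # drop the empty turn left by a trailing separator
    return turns


def pad(ids, length, pad_id=0):
    """Truncate or right-pad a token list to exactly `length` ids (0 = padding)."""
    return (ids + [pad_id] * length)[:length]


def to_model_input(context_ids, response_ids, eos_id, max_turn_num=9, max_turn_len=50):
    """Build fixed-size (max_turn_num x max_turn_len) turns and a padded response."""
    turns = split_turns(context_ids, eos_id)[-max_turn_num:]  # keep the most recent turns
    turns = [pad(t, max_turn_len) for t in turns]
    # Prepend all-zero turns so every sample has max_turn_num rows (assumption).
    filler = [[0] * max_turn_len for _ in range(max_turn_num - len(turns))]
    return filler + turns, pad(response_ids, max_turn_len)


ctx, resp, label = parse_line("286 642 865 36\t87 25 693\t0")
turns, response = to_model_input(ctx, resp, eos_id=28270)
print(len(turns), len(turns[0]), len(response), label)  # 9 50 50 0
```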

Note: the project also ships a word-segmentation preprocessing script (in the preprocess directory) that you can use as follows:

python tokenizer.py \
  --test_data_dir ./test.txt.utf8 \
  --batch_size 1 > test.txt.utf8.seg
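The segmentation script produces words, while the model consumes token ids. Below is a hypothetical sketch (not part of the project) of one way to bridge the two: build a vocabulary from segmented text, reserving id 0 for padding and id 1 for unknown words, and emit lines in the three-column format described above. It assumes the segmented output separates words with spaces; the vocabulary scheme and sample sentences are also assumptions. A real setup has to reuse the exact vocabulary the model was trained with (compare the --vocab_size flag in the commands below).

```python
def build_vocab(segmented_lines, reserved=("<pad>", "<unk>")):
    """Map every token to an integer id; reserved tokens get ids 0 and 1 (assumption)."""
    vocab = {tok: i for i, tok in enumerate(reserved)}
    for line in segmented_lines:
        for tok in line.split():  # assumes the segmenter separates words with spaces
            if tok not in vocab:
                vocab[tok] = len(vocab)
    return vocab


def encode(text, vocab):
    """Replace each token with its id, falling back to the <unk> id."""
    unk = vocab["<unk>"]
    return " ".join(str(vocab.get(tok, unk)) for tok in text.split())


def make_example(context, response, label, vocab):
    """Produce one line in the context<TAB>response<TAB>label format."""
    return "\t".join([encode(context, vocab), encode(response, vocab), str(label)])


vocab = build_vocab(["how are you", "fine thanks"])
print(repr(make_example("how are you", "fine thanks !", 1, vocab)))  # '2 3 4\t5 6 1\t1'
```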









deep_attention_matching/: contains the main executable files of the DAM model





!cd data/data11447/ && unzip -qo ubuntu.zip
!python deep_attention_matching/main.py \
  --do_train True \
  --use_cuda \
  --data_path data/data11447/ubuntu/data_small.pkl \
  --save_path deep_attention_matching/model_files/ubuntu \
  --vocab_size 434512 \
  --_EOS_ 28270 \
  --batch_size 32
-----------  Configuration Arguments -----------
_EOS_: 28270
batch_size: 32
channel1_num: 32
channel2_num: 16
data_path: data/data11447/ubuntu/data_small.pkl
do_infer: False
do_train: True
emb_size: 200
ext_eval: False
learning_rate: 0.001
max_turn_len: 50
max_turn_num: 9
model_path: None
num_scan_data: 2
save_path: deep_attention_matching/model_files/ubuntu
stack_num: 5
use_cuda: True
vocab_size: 434512
word_emb_init: None
begin memory optimization ...
2019-09-05 15:58:35
end memory optimization ...
2019-09-05 15:58:35
device count 1
theoretical memory usage: 
(8378.70401058197, 8777.689915847779, 'MB')
W0905 15:58:37.603806  1186 device_context.cc:259] Please NOTE: device: 0, CUDA Capability: 70, Driver API Version: 9.2, Runtime API Version: 9.0
W0905 15:58:37.611320  1186 device_context.cc:267] device: 0, cuDNN Version: 7.3.
     You can try our memory optimize feature to save your memory usage:
         # create a build_strategy variable to set memory optimize option
         build_strategy = compiler.BuildStrategy()
         build_strategy.enable_inplace = True
         build_strategy.memory_optimize = True
         # pass the build_strategy to with_data_parallel API
         compiled_prog = compiler.CompiledProgram(main).with_data_parallel(
             loss_name=loss.name, build_strategy=build_strategy)
     !!! Memory optimize is our experimental feature !!!
         some variables may be removed/reused internal to save memory usage, 
         in order to fetch the right value of the fetch_list, please set the 
         persistable property to true for each variable in fetch_list

         # Sample
         conv1 = fluid.layers.conv2d(data, 4, 5, 1, act=None) 
         # if you need to fetch conv1, then:
         conv1.persistable = True

I0905 15:58:37.884307  1186 parallel_executor.cc:329] The number of CUDAPlace, which is used in ParallelExecutor, is 1. And the Program will be copied 1 copies
I0905 15:58:38.161104  1186 build_strategy.cc:340] SeqOnlyAllReduceOps:0, num_trainers:1
share_vars_from is set, scope is ignored.
I0905 15:58:38.416013  1186 parallel_executor.cc:329] The number of CUDAPlace, which is used in ParallelExecutor, is 1. And the Program will be copied 1 copies
I0905 15:58:38.489773  1186 build_strategy.cc:340] SeqOnlyAllReduceOps:0, num_trainers:1
start loading data ...
finish loading data ...
begin model training ...
2019-09-05 15:58:39
[1676 9116 5609 ... 1722 2436 5949]
processed: [0.00961538461538] ave loss: [0.7781140208244324]
processed: [0.0192307692308] ave loss: [0.7981151739756266]
processed: [0.0288461538462] ave loss: [0.6933611432711283]
processed: [0.0384615384615] ave loss: [0.7092911005020142]
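The first average losses reported above (roughly 0.69 to 0.80) sit near ln 2 ≈ 0.693, which is what binary cross-entropy gives when an untrained model outputs a logit close to 0, i.e. a probability of 0.5 for every pair. The sketch below reproduces that calculation, assuming the 0/1 label column is trained with a sigmoid cross-entropy on a single logit per (context, response) pair; it is a plain-Python illustration, not the project's training code.

```python
import math


def sigmoid_cross_entropy_with_logits(logit, label):
    """Binary cross-entropy on a raw logit, in the numerically stable form
    max(x, 0) - x * z + log(1 + exp(-|x|))."""
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))


def average_loss(logits, labels):
    """Mean loss over a batch, analogous to the 'ave loss' values in the log."""
    losses = [sigmoid_cross_entropy_with_logits(x, z) for x, z in zip(logits, labels)]
    return sum(losses) / len(losses)


# An untrained model that outputs logit 0 for every pair scores ln 2 per example.
print(round(average_loss([0.0, 0.0, 0.0], [1, 0, 0]), 4))  # 0.6931
```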
!python deep_attention_matching/main.py \
  --do_infer True \
  --use_cuda \
  --data_path ./data/data11447/ubuntu/data_small.pkl \
  --save_path deep_attention_matching/infer_result \
  --model_path deep_attention_matching/model_files/ubuntu/ \
  --vocab_size 434512 \
  --_EOS_ 28270 \
  --batch_size 1
-----------  Configuration Arguments -----------
_EOS_: 28270
batch_size: 1
channel1_num: 32
channel2_num: 16
data_path: ./data/data11447/ubuntu/data_small.pkl
do_infer: True
do_train: False
emb_size: 200
ext_eval: False
learning_rate: 0.001
max_turn_len: 50
max_turn_num: 9
model_path: deep_attention_matching/model_files/ubuntu/
num_scan_data: 2
save_path: deep_attention_matching/infer_result
stack_num: 5
use_cuda: True
vocab_size: 434512
word_emb_init: None
W0905 15:46:10.321343   919 device_context.cc:259] Please NOTE: device: 0, CUDA Capability: 70, Driver API Version: 9.2, Runtime API Version: 9.0
W0905 15:46:10.325507   919 device_context.cc:267] device: 0, cuDNN Version: 7.3.
start loading data ...
finish loading data ...
test batch num: 1000
begin inference ...
2019-09-05 15:46:11
('turns:', array([[[393704,      0,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0,
              0,      0],
        [250191,  34296, 350284,  30835,  59150,  74395,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0,
              0,      0],
        [180037,  28847,  88281, 115692, 413324, 279504, 354176,  20481,
         418397, 418397, 177048, 197682, 115692, 373516, 192382, 285320,
          20484,  20494, 229901,   9751,  20494,  11317,  20484, 347085,
              0,      0,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0,
              0,      0],
        [ 97341, 291041,  14781, 414881, 126529, 174798,   1828, 324795,
         324507, 227764,  20484,  54259,   7198, 296758, 259553, 354176,
         123155,  20484, 149834, 343709,    238,  20484, 106788,      0,
              0,      0,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0,
              0,      0],
        [131277, 326026, 146729, 170184, 180037, 418453,  20484,  20494,
         229901,      0,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0,
              0,      0],
        [131277, 326026, 207572, 307055, 284678, 285320, 180037, 404776,
         364101,      0,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0,
              0,      0],
        [395666,  75844, 233777, 195724,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0,
              0,      0],
        [ 81947, 427943, 257613,  20484, 373516,   1598, 395666, 233777,
         195724,      0,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0,
              0,      0],
        [262553, 368624,  15331, 115692, 107043, 343709,  27044, 307801,
         367146, 309360,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0,
              0,      0]]]))
('response:', array([[[131277],
        [     0],
        [     0],
        [     0],
        [     0],
        [     0],
        [     0],
        [     0],
        [     0],
        [     0],
        [     0],
        [     0],
        [     0],
        [     0],
        [     0],
        [     0],
        [     0],
        [     0],
        [     0],
        [     0],
        [     0],
        [     0],
        [     0],
        [     0],
        [     0],
        [     0],
        [     0],
        [     0],
        [     0],
        [     0],
        [     0],
        [     0],
        [     0],
        [     0],
        [     0],
        [     0],
        [     0],
        [     0],
        [     0],
        [     0],
        [     0],
        [     0],
        [     0],
        [     0],
        [     0],
        [     0]]]))
('scores:', array([[-0.2420864]], dtype=float32))
finish test
2019-09-05 15:46:36
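Inference prints one score per (context, candidate) pair, such as scores: [[-0.2420864]] above. Selecting a response means sorting the candidates of one context by this score; because the sigmoid is monotonic, ranking by the raw score (assumed here to be a logit) gives the same order as ranking by probability. The sketch below computes the recall-at-k metric commonly used for response selection (on Ubuntu, often R10@k with one correct and nine wrong candidates per context). It assumes scores and labels are laid out in consecutive groups of group_size candidates per context with exactly one positive; that layout and the function names are assumptions, not the project's evaluation script.

```python
import math


def sigmoid(x):
    """Map a raw logit to a probability in (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))


def recall_at_k(scores, labels, group_size, k):
    """Fraction of contexts whose correct candidate is ranked within the top k."""
    hits, groups = 0, 0
    for start in range(0, len(scores), group_size):
        s = scores[start:start + group_size]
        y = labels[start:start + group_size]
        ranked = sorted(range(len(s)), key=lambda i: s[i], reverse=True)
        hits += any(y[i] == 1 for i in ranked[:k])
        groups += 1
    return hits / groups


scores = [0.9, 0.1, 0.3, 0.2, 0.8, 0.5]  # two contexts, three candidates each
labels = [1, 0, 0, 0, 0, 1]
print(recall_at_k(scores, labels, group_size=3, k=1))  # 0.5
print(recall_at_k(scores, labels, group_size=3, k=2))  # 1.0
print(round(sigmoid(-0.2420864), 4))                   # 0.4398
```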

Click the link to try this project hands-on in AI Studio with one click: https://aistudio.baidu.com/aistudio/projectdetail/122287
