菜鸟笔记-DuReader阅读理解基线模型代码阅读笔记(九)—— 预测与校验

系列目录:

  1. 菜鸟笔记-DuReader阅读理解基线模型代码阅读笔记(一)——数据
  2. 菜鸟笔记-DuReader阅读理解基线模型代码阅读笔记(二)——
    介绍及分词
  3. 菜鸟笔记-DuReader阅读理解基线模型代码阅读笔记(三)—— 预处理
  4. 菜鸟笔记-DuReader阅读理解基线模型代码阅读笔记(四)—— 段落抽取
  5. 菜鸟笔记-DuReader阅读理解基线模型代码阅读笔记(五)—— 准备数据
  6. 菜鸟笔记-DuReader阅读理解基线模型代码阅读笔记(六)—— 模型构建
  7. 菜鸟笔记-DuReader阅读理解基线模型代码阅读笔记(七)—— 模型训练-数据准备
  8. 菜鸟笔记-DuReader阅读理解基线模型代码阅读笔记(八)—— 模型训练-训练

未完待续 … …

上一篇文章对模型进行了训练,模型训练后就是使用模型对新的数据进行预测以及在测试集上进行验证了。

验证函数

模型验证的代码在run.py文件中,为evaluate函数,具体代码如下:

def evaluate(args):
    """
    在验证集上对训练好的模型进行验证
    """
    logger = logging.getLogger("brc")
    logger.info('Load data_set and vocab...')
    # 加载词典
    with open(os.path.join(args.vocab_dir, 'vocab.data'), 'rb') as fin:
        vocab = pickle.load(fin)
    assert len(args.dev_files) > 0, 'No dev files are provided.'
    # 加载数据集
    brc_data = BRCDataset(args.max_p_num, args.max_p_len, args.max_q_len, dev_files=args.dev_files)
    logger.info('Converting text into ids...')
    # 将数据集中文本转换为ids
    brc_data.convert_to_ids(vocab)
    logger.info('Restoring the model...')
    # 构建模型
    rc_model = RCModel(vocab, args)
    # 加载已经训练好的模型
    rc_model.restore(model_dir=args.model_dir, model_prefix=args.algo)
    logger.info('Evaluating the model on dev set...')
    # 生成批次数据
    dev_batches = brc_data.gen_mini_batches('dev', args.batch_size,
                                            pad_id=vocab.get_id(vocab.pad_token), shuffle=False)
    # 校验
    dev_loss, dev_bleu_rouge = rc_model.evaluate(
        dev_batches, result_dir=args.result_dir, result_prefix='dev.predicted')
    logger.info('Loss on dev set: {}'.format(dev_loss))
    logger.info('Result on dev set: {}'.format(dev_bleu_rouge))
    logger.info('Predicted answers are saved to {}'.format(os.path.join(args.result_dir)))

有代码可知,这部分代码与train中的代码大部分相同,不同点在创建模型后加载了训练好的模型,最后是调用了rc_model.evaluate函数对模型进行验证而不是训练。rc_model.evaluate函数在系列笔记(八)中进行了介绍,此处不再赘述。
另外,从rc_model.evaluate函数的输入参数result_dirresult_prefix可知,验证时生成的结果存放在了args.result_dir下的dev.predicted.json文件中。

预测函数

预测的代码在run.py文件中,为predict函数,具体代码如下:

def predict(args):
    """
    为测试数据预测答案
    """
    logger = logging.getLogger("brc")
    logger.info('Load data_set and vocab...')
    # 加载词典
    with open(os.path.join(args.vocab_dir, 'vocab.data'), 'rb') as fin:
        vocab = pickle.load(fin)
    assert len(args.test_files) > 0, 'No test files are provided.'
    # 加载测试数据
    brc_data = BRCDataset(args.max_p_num, args.max_p_len, args.max_q_len,
                          test_files=args.test_files)
    logger.info('Converting text into ids...')
    # 将数据中文本转化为ids
    brc_data.convert_to_ids(vocab)
    logger.info('Restoring the model...')
    # 加载模型
    rc_model = RCModel(vocab, args)
    rc_model.restore(model_dir=args.model_dir, model_prefix=args.algo)
    logger.info('Predicting answers for test set...')
    # 生成数据批次
    test_batches = brc_data.gen_mini_batches('test', args.batch_size,
                                             pad_id=vocab.get_id(vocab.pad_token), shuffle=False)
    # 预测
    rc_model.evaluate(test_batches,
                      result_dir=args.result_dir, result_prefix='test.predicted')

由代码可见,预测函数与验证函数的代码基本一致,只是参数不同,由于输入没有真实答案,也没有计算评估分数。从rc_model.evaluate函数的输入参数result_dirresult_prefix可知,预测时生成的结果存放在了args.result_dir下的test.predicted.json文件中。

运行

import

import sys
import pickle
from run import *
WARNING:tensorflow:
The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
  * https://github.com/tensorflow/io (for I/O related ops)
If you depend on functionality not listed there, please file an issue.

args

import logging
# 准备参数
sys.argv = []
args = parse_args()
# 设定日志输出
logger = logging.getLogger("brc")
logger.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
if args.log_path:
    file_handler = logging.FileHandler(args.log_path)
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
else:
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)

logger.info('Running with args : {}'.format(args))
2020-03-28 00:04:21,919 - brc - INFO - Running with args : Namespace(algo='BIDAF', batch_size=32, brc_dir='../data/baidu', dev_files=['../data/demo/devset/search.dev.json'], dropout_keep_prob=1, embed_size=300, epochs=10, evaluate=False, gpu='0', hidden_size=150, learning_rate=0.001, log_path=None, max_a_len=200, max_p_len=500, max_p_num=5, max_q_len=60, model_dir='../data/models/', optim='adam', predict=False, prepare=False, result_dir='../data/results/', summary_dir='../data/summary/', test_files=['../data/demo/testset/search.test.json'], train=False, train_files=['../data/demo/trainset/search.train.json'], vocab_dir='../data/vocab/', weight_decay=0)
#设定运行环境
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

evaluate

调用校验函数对模型进行校验。

evaluate(args)
2020-03-28 00:04:22,279 - brc - INFO - Load data_set and vocab...
2020-03-28 00:04:22,426 - brc - INFO - Dev set size: 100 questions.
2020-03-28 00:04:22,428 - brc - INFO - Converting text into ids...
2020-03-28 00:04:22,446 - brc - INFO - Restoring the model...

2020-03-28 00:04:26,137 - brc - INFO - Time to build graph: 3.4807958602905273 s
2020-03-28 00:04:34,473 - brc - INFO - There are 4995603 parameters in the model

2020-03-28 00:04:35,917 - brc - INFO - Model restored from ../data/models/, with prefix BIDAF
2020-03-28 00:04:35,919 - brc - INFO - Evaluating the model on dev set...
2020-03-28 00:04:41,299 - brc - INFO - Saving dev.predicted results to ../data/results/dev.predicted.json

2020-03-28 00:04:41,903 - brc - INFO - Loss on dev set: 13.766514587402344
2020-03-28 00:04:41,904 - brc - INFO - Result on dev set: {'Bleu-1': 0.19132287017014682, 'Bleu-2': 0.14104624846358704, 'Bleu-3': 0.11170956630975126, 'Bleu-4': 0.09381119777647412, 'Rouge-L': 0.19897168820075256}
2020-03-28 00:04:41,905 - brc - INFO - Predicted answers are saved to ../data/results/

从输出结果可以看到loss是13.7,Bleu-nRouge-L等分数也很低,这是因为加载的模型只是在示例数据上训练,数据量很小,模型的训练效果很差。

接下来可以看一下具体的问题、真实答案、预测答案对比:

# 加载校验数据
brc_data = BRCDataset(args.max_p_num, args.max_p_len, args.max_q_len, dev_files=args.dev_files)
# 查看校验数据集属性
brc_data.dev_set[0].keys()
2020-03-28 00:10:19,908 - brc - INFO - Dev set size: 100 questions.
dict_keys(['documents', 'answer_spans', 'answer_docs', 'fake_answers', 'question', 'segmented_answers', 'answers', 'entity_answers', 'segmented_question', 'question_type', 'match_scores', 'fact_or_opinion', 'question_id', 'answer_passages', 'question_tokens', 'passages'])
# 加载验证数据预测结果
import json
result_file = os.path.join(args.result_dir, 'dev.predicted' + '.json')
with open(result_file, 'r') as f:
    dev_res = f.readlines()

# 定义样本输出的字段
def sample_info(i):
    print('question_id:  ',brc_data.dev_set[i]['question_id'])
    print('question:     ',brc_data.dev_set[i]['question'])
    print('answers:      ',brc_data.dev_set[i]['answers'])
    print('fake_answers: ',brc_data.dev_set[i]['fake_answers'])

输出样本0问题、答案、伪答案信息:

s_i = 0
sample_info(s_i)
question_id:   186572
question:      2017有什么好看的小说
answers:       ['1.《将夜》2.《择天记》3.《冒牌大英雄》4.《无限恐怖》5.《恐怖搞校》6.《大国医》7.《龙魔导》。', '《大唐悬疑录:长恨歌密码》、《风雪追击》、《草原动物园》、《有匪2:离恨楼》。', '我们住在一起、月都花落,沧海花开、天定风华、寻找爱情的邹小姐、应许之日、星光的彼端、他来了,请闭眼。']
fake_answers:  ['《大唐悬疑录》中的裴玄静与《有匪》']

其对应预测结果:

json.loads(dev_res[s_i])
{'question_id': 186572,
 'question_type': 'ENTITY',
 'answers': ['2017年的第一季度刚刚过去,大家在这三个月里有读到什么好书吗?培根曾说过:“孤独寂寞时,阅读可以消遣。高谈阔论时,知识可供装饰。处世行事时,知识意味着才干。”不论你是一个尚且在学海中遨游的学生,还是一个已经工作多年的成熟社会人,读书从来不该带有功利的意味,而是该从中看到世界万事万物的运行轨迹,也该从中体味生活和成长进步。'],
 'entity_answers': [[]],
 'yesno_answers': []}

输出样本1问题、答案、伪答案信息:

s_i = 1
sample_info(s_i)
question_id:   186573
question:      截至和截止区别
answers:       ['“截止”与“截至”的区别:一、名次解释:“截止”表示到某个时间停止,强调“停止”;“截至”表示停止于某个时间,强调“时间”。', '截止:某期限停止;截至:停止于某期限.', '截止:……止明确终止、结束含义表间点语词般其前面,截至:某间点没明确终止、结束含义延续延续表间点语词般其面。']
fake_answers:  ['“截止”与“截至”的区别:一、名次解释:“截止”表示到某个时间停止,强调“停止”;“截至”表示停止于某个时间,强调“时间”。']

其对应的预测结果:

json.loads(dev_res[s_i])
{'question_id': 186573,
 'question_type': 'DESCRIPTION',
 'answers': ['。1是完毕,不再继续。2是装束;打扮。3是收拾;处置。4是拘束.截止,表示到某个时间停止,强调"停止";截至,表示停止于某个时间,强调"时间"。'],
 'entity_answers': [[]],
 'yesno_answers': []}

从两个样本的预测结果来看,样本0的预测与原答案完全无关,样本1的预测与原答案相关,但是不完全相关。这说明在示例数据集上训练的样本尽管效果很差,但是已经学到了一点点有效的特征。

predict

调用模型对测试数据集进行预测。

# 需要把加载过得计算图重置一下
import tensorflow as tf
tf.reset_default_graph()
predict(args)
2020-03-28 00:24:33,181 - brc - INFO - Load data_set and vocab...
2020-03-28 00:24:33,534 - brc - INFO - Test set size: 100 questions.
2020-03-28 00:24:33,536 - brc - INFO - Converting text into ids...
2020-03-28 00:24:33,556 - brc - INFO - Restoring the model...
2020-03-28 00:24:36,974 - brc - INFO - Time to build graph: 3.412377119064331 s
2020-03-28 00:24:45,358 - brc - INFO - There are 4995603 parameters in the model

INFO:tensorflow:Restoring parameters from ../data/models/BIDAF

2020-03-28 00:24:46,307 - brc - INFO - Model restored from ../data/models/, with prefix BIDAF
2020-03-28 00:24:46,308 - brc - INFO - Predicting answers for test set...
2020-03-28 00:24:51,139 - brc - INFO - Saving test.predicted results to ../data/results/test.predicted.json

模型没有评估分数的输出,只是把预测结果保存在test.predicted.json文件中。

# 加载测试数据
brc_data = BRCDataset(args.max_p_num, args.max_p_len, args.max_q_len, test_files=args.test_files)
# 输出测试集样本属性
brc_data.test_set[0].keys()
2020-03-28 00:31:03,681 - brc - INFO - Test set size: 100 questions.

dict_keys(['documents', 'question', 'segmented_question', 'question_type', 'fact_or_opinion', 'question_id', 'question_tokens', 'passages'])

由输出可以看到,样本中没有answer相关的字段,这是因为测试样本不提供真实答案。

# 加载测试数据预测结果
result_file = os.path.join(args.result_dir, 'test.predicted' + '.json')
with open(result_file, 'r') as f:
    test_res = f.readlines()

查看下测试集样本问题:

s_i = 10
print('question:     ',brc_data.test_set[s_i]['question'])
question:      薛之谦偷吻郭雪芙哪期

其对应的预测结果:

json.loads(test_res[s_i])
{'question_id': 221584,
 'question_type': 'ENTITY',
 'answers': ['和郭雪芙是什么关系?薛之谦(JokerXue),1983年7月17日出生于上海,华语流行乐男歌手、影视演员、音乐制作人,毕业于格里昂酒店管理学院。'],
 'entity_answers': [[]],
 'yesno_answers': []}

查看下测试集样本问题:

s_i = 6
print('question:     ',brc_data.test_set[s_i]['question'])
question:      机动车违反禁止标线指示的扣几分

其对应的预测结果:

json.loads(test_res[s_i])
{'question_id': 221580,
 'question_type': 'ENTITY',
 'answers': ['违法行为:违反禁令标志、禁止标线指示的。\u3000\u30002013年新处罚标准:罚款100元,扣3分。\u3000\u3000原处罚标准:罚款100元,扣2分。'],
 'entity_answers': [[]],
 'yesno_answers': []}

由例子可见,示例数据上训练模型效果比较差,挑了两个看上去有点关系的预测,可以了解下预测结果形式。

参考文献:

猜你喜欢

转载自blog.csdn.net/wmq104/article/details/105156593
今日推荐