官网实例详解4.4(babi_rnn.py)-keras学习笔记四

基于一个故事和一个问题来训练两个循环神经网络。

Keras实例目录

代码注释

'''Trains two recurrent neural networks based upon a story and a question.
基于一个故事和一个问题来训练两个循环神经网络。
The resulting merged vector is then queried to answer a range of bAbI tasks.
然后,对得到的合并向量进行查询,以回答一系列的Babi任务。
The results are comparable to those for an LSTM model provided in Weston et al.:
其结果与Weston等人提供的LSTM模型的结果相当。
"Towards AI-Complete Question Answering: A Set of Prerequisite Toy Tasks"
面向AI完整问答:一组必备的任务.
http://arxiv.org/abs/1502.05698

Task Number                  | FB LSTM Baseline | Keras QA
---                          | ---              | ---
QA1 - Single Supporting Fact | 50               | 100.0
QA2 - Two Supporting Facts   | 20               | 50.0
QA3 - Three Supporting Facts | 20               | 20.5
QA4 - Two Arg. Relations     | 61               | 62.9
QA5 - Three Arg. Relations   | 70               | 61.9
QA6 - yes/No Questions       | 48               | 50.7
QA7 - Counting               | 49               | 78.9
QA8 - Lists/Sets             | 45               | 77.2
QA9 - Simple Negation        | 64               | 64.0
QA10 - Indefinite Knowledge  | 44               | 47.7
QA11 - Basic Coreference     | 72               | 74.9
QA12 - Conjunction           | 74               | 76.4
QA13 - Compound Coreference  | 94               | 94.4
QA14 - Time Reasoning        | 27               | 34.8
QA15 - Basic Deduction       | 21               | 32.4
QA16 - Basic Induction       | 23               | 50.6
QA17 - Positional Reasoning  | 51               | 49.1
QA18 - Size Reasoning        | 52               | 90.8
QA19 - Path Finding          | 8                | 9.0
QA20 - Agent's Motivations   | 91               | 90.7

For the resources related to the bAbI project, refer to:
与bAbI项目相关的资源,请参见:
https://research.facebook.com/researchers/1543934539189348

# Notes
注意
- With default word, sentence, and query vector sizes, the GRU model achieves:
- 在默认单词、句子和查询向量大小的情况下,GRU模型实现:
  - 100% test accuracy on QA1 in 20 epochs (2 seconds per epoch on CPU)
  - 100%测试准确率在问题1,20个周期(2秒/周期,基于CPU)
  - 50% test accuracy on QA2 in 20 epochs (16 seconds per epoch on CPU)
  - 50%测试准确率在问题2,20个周期(16秒/周期,基于CPU)
In comparison, the Facebook paper achieves 50% and 20% for the LSTM baseline.
相比之下,脸谱论文达到LSTM的50%和20%。
- The task does not traditionally parse the question separately. This likely
improves accuracy and is a good example of merging two RNNs.
- 任务并不是传统地单独分析问题。这可能提高精度,是合并两个RNNs的一个很好的例子。

- The word vector embeddings are not shared between the story and question RNNs.
- 词嵌入不在故事和问题RNNs之间共享。
- See how the accuracy changes given 10,000 training samples (en-10k) instead
of only 1000. 1000 was used in order to be comparable to the original paper.
- 注意提供样本10,000替换1000后准确率咱们改变的。1000是为了与原论文张相比。

- Experiment with GRU, LSTM, and JZS1-3 as they give subtly different results.
- 使用GRU、LSTM和JZS1-3进行实验,因为它们给出了细微不同的结果。

- The length and noise (i.e. 'useless' story components) impact the ability for
LSTMs / GRUs to provide the correct answer. Given only the supporting facts,
these RNNs can achieve 100% accuracy on many tasks. Memory networks and neural
networks that use attentional processes can efficiently search through this
noise to find the relevant statements, improving performance substantially.
This becomes especially obvious on QA2 and QA3, both far longer than QA1.
- 长度和噪声(即“无用”的故事成分)影响LSTMS/GRU提供正确答案的能力。仅给出支持事实,这些RNN可以在许多任
务上达到100%的精度。使用注意过程的记忆网络和神经网络可以有效地通过这种噪声搜索来寻找相关的陈述,从而显著
地提高性能。这在QA2和QA3上变得特别明显,两者都远胜于QA1。
'''

from __future__ import print_function
from functools import reduce
import re
import tarfile

import numpy as np

from keras.utils.data_utils import get_file
from keras.layers.embeddings import Embedding
from keras import layers
from keras.layers import recurrent
from keras.models import Model
from keras.preprocessing.sequence import pad_sequences


def tokenize(sent):
    '''Return the tokens of a sentence including punctuation.
    返回包含标点符号的句子的分词。

    >>> tokenize('Bob dropped the apple. Where is the apple?')
    ['Bob', 'dropped', 'the', 'apple', '.', 'Where', 'is', 'the', 'apple', '?']
    '''
    return [x.strip() for x in re.split('(\W+)?', sent) if x.strip()]


def parse_stories(lines, only_supporting=False):
    '''Parse stories provided in the bAbi tasks format
    在BABI任务格式中解析故事

    If only_supporting is ,
    only the sentences that support the answer are kept.
    如果仅仅支持,则只保留支持答案的句子。
    '''
    data = []
    story = []
    for line in lines:
        line = line.decode('utf-8').strip()
        nid, line = line.split(' ', 1)
        nid = int(nid)
        if nid == 1:
            story = []
        if '\t' in line:
            q, a, supporting = line.split('\t')
            q = tokenize(q)
            substory = None
            if only_supporting:
                # Only select the related substory
                # 只选择相关的小故事
                supporting = map(int, supporting.split())
                substory = [story[i - 1] for i in supporting]
            else:
                # Provide all the substories
                # 提供所有的小故事
                substory = [x for x in story if x]
            data.append((substory, q, a))
            story.append('')
        else:
            sent = tokenize(line)
            story.append(sent)
    return data


def get_stories(f, only_supporting=False, max_length=None):
    '''Given a file name, read the file, retrieve the stories,
    and then convert the sentences into a single story.
    给定文件名,读取文件,检索故事,然后将句子转换成一个故事。

    If max_length is supplied,
    any stories longer than max_length tokens will be discarded.
    如果提供了max_length,则将丢弃比max_length标记长的故事。
    '''
    data = parse_stories(f.readlines(), only_supporting=only_supporting)
    flatten = lambda data: reduce(lambda x, y: x + y, data)
    data = [(flatten(story), q, answer) for story, q, answer in data if not max_length or len(flatten(story)) < max_length]
    return data


def vectorize_stories(data, word_idx, story_maxlen, query_maxlen):
    xs = []
    xqs = []
    ys = []
    for story, query, answer in data:
        x = [word_idx[w] for w in story]
        xq = [word_idx[w] for w in query]
        # let's not forget that index 0 is reserved
        # 索引0是预留的。
        y = np.zeros(len(word_idx) + 1)
        y[word_idx[answer]] = 1
        xs.append(x)
        xqs.append(xq)
        ys.append(y)
    return pad_sequences(xs, maxlen=story_maxlen), pad_sequences(xqs, maxlen=query_maxlen), np.array(ys)

RNN = recurrent.LSTM
EMBED_HIDDEN_SIZE = 50
SENT_HIDDEN_SIZE = 100
QUERY_HIDDEN_SIZE = 100
BATCH_SIZE = 32
EPOCHS = 40
print('RNN / Embed / Sent / Query = {}, {}, {}, {}'.format(RNN,
                                                           EMBED_HIDDEN_SIZE,
                                                           SENT_HIDDEN_SIZE,
                                                           QUERY_HIDDEN_SIZE))

try:
    path = get_file('babi-tasks-v1-2.tar.gz', origin='https://s3.amazonaws.com/text-datasets/babi_tasks_1-20_v1-2.tar.gz')
except:
    print('Error downloading dataset, please download it manually:\n'
          '$ wget http://www.thespermwhale.com/jaseweston/babi/tasks_1-20_v1-2.tar.gz\n'
          '$ mv tasks_1-20_v1-2.tar.gz ~/.keras/datasets/babi-tasks-v1-2.tar.gz')
    raise

# Default QA1 with 1000 samples
# 默认QA1有1000个样本
# challenge = 'tasks_1-20_v1-2/en/qa1_single-supporting-fact_{}.txt'
# QA1 with 10,000 samples
# QA1有10,000个样本
# challenge = 'tasks_1-20_v1-2/en-10k/qa1_single-supporting-fact_{}.txt'
# QA2 with 1000 samples
# QA2有1000个样本
challenge = 'tasks_1-20_v1-2/en/qa2_two-supporting-facts_{}.txt'
# QA2 with 10,000 samples
# QA2有10,000个样本
# challenge = 'tasks_1-20_v1-2/en-10k/qa2_two-supporting-facts_{}.txt'
with tarfile.open(path) as tar:
    train = get_stories(tar.extractfile(challenge.format('train')))
    test = get_stories(tar.extractfile(challenge.format('test')))

vocab = set()
for story, q, answer in train + test:
    vocab |= set(story + q + [answer])
vocab = sorted(vocab)

# Reserve 0 for masking via pad_sequences
# 通过pad_sequences掩膜保留0
vocab_size = len(vocab) + 1
word_idx = dict((c, i + 1) for i, c in enumerate(vocab))
story_maxlen = max(map(len, (x for x, _, _ in train + test)))
query_maxlen = max(map(len, (x for _, x, _ in train + test)))

x, xq, y = vectorize_stories(train, word_idx, story_maxlen, query_maxlen)
tx, txq, ty = vectorize_stories(test, word_idx, story_maxlen, query_maxlen)

print('vocab = {}'.format(vocab))
print('x.shape = {}'.format(x.shape))
print('xq.shape = {}'.format(xq.shape))
print('y.shape = {}'.format(y.shape))
print('story_maxlen, query_maxlen = {}, {}'.format(story_maxlen, query_maxlen))

print('Build model...')

sentence = layers.Input(shape=(story_maxlen,), dtype='int32')
encoded_sentence = layers.Embedding(vocab_size, EMBED_HIDDEN_SIZE)(sentence)
encoded_sentence = layers.Dropout(0.3)(encoded_sentence)

question = layers.Input(shape=(query_maxlen,), dtype='int32')
encoded_question = layers.Embedding(vocab_size, EMBED_HIDDEN_SIZE)(question)
encoded_question = layers.Dropout(0.3)(encoded_question)
encoded_question = RNN(EMBED_HIDDEN_SIZE)(encoded_question)
encoded_question = layers.RepeatVector(story_maxlen)(encoded_question)

merged = layers.add([encoded_sentence, encoded_question])
merged = RNN(EMBED_HIDDEN_SIZE)(merged)
merged = layers.Dropout(0.3)(merged)
preds = layers.Dense(vocab_size, activation='softmax')(merged)

model = Model([sentence, question], preds)
model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=['accuracy'])

print('Training')
model.fit([x, xq], y,
          batch_size=BATCH_SIZE,
          epochs=EPOCHS,
          validation_split=0.05)
loss, acc = model.evaluate([tx, txq], ty,
                           batch_size=BATCH_SIZE)
print('Test loss / test accuracy = {:.4f} / {:.4f}'.format(loss, acc))

代码执行

 

Keras详细介绍

英文:https://keras.io/

中文:http://keras-cn.readthedocs.io/en/latest/

实例下载

https://github.com/keras-team/keras

https://github.com/keras-team/keras/tree/master/examples

完整项目下载

方便没积分童鞋,请加企鹅452205574,共享文件夹。

包括:代码、数据集合(图片)、已生成model、安装库文件等。

 

Keras实例目录

代码注释

代码执行

 

Keras详细介绍

英文:https://keras.io/

中文:http://keras-cn.readthedocs.io/en/latest/

实例下载

https://github.com/keras-team/keras

https://github.com/keras-team/keras/tree/master/examples

完整项目下载

方便没积分童鞋,请加企鹅452205574,共享文件夹。

包括:代码、数据集合(图片)、已生成model、安装库文件等。

 

猜你喜欢

转载自blog.csdn.net/wyx100/article/details/80789297
今日推荐