Relational recurrent neural networks: paper and code walkthrough with an implementation example

Paper link:

https://arxiv.org/abs/1806.01822

Background:

If we treat an agent's experiences and trajectories as a kind of memory, we can define a memory-network structure whose job is to infer, from those experiences, the possible states and the goal-directed operations (the operations required to reach a given state).

The simplest example is the RNN: viewed from the angle of seq2seq attention summarization, the context vector can be seen as a summary of the decoding-relevant memory on the encoder side; see the earlier post "End-To-End Memory Networks 论文阅读" for details. That post already described a memory network that imitates the LSTM memory structure; there the memory can be regarded as information supplied from outside, i.e. the task is framed as a QA problem with candidate context information, and the memory part corresponds exactly to the context. A similar solution is a structure that uses the query to extract directly from the context, as in Bi-Directional Attention Flow for Machine Comprehension.

Now switch to a different setting, say multi-turn QA. From the dialogue agent's point of view, the memory should correspond to the context produced by the whole conversation, and that context is generated by the agent itself rather than given externally; it is a dynamically updated memory structure. The paper discussed here can be seen as one network architecture for updating an agent's experience memory in this way.

From the standpoint of basic neural networks there already seem to be candidate solutions: conceptually, the hidden state of an RNN should be exactly this. That is right in principle, but there are problems with network capacity and design. In an LSTM the "memory" is a single vector, and c and h each carry their own specific meaning; statistically they mostly hold single-step information, and the next step depends only on the previous one. What we need is a "box" that accumulates all past memories (hidden states, or equivalently inputs), rather than filtering through a single-step memory update.

So the starting point of many papers can be summarized as changing the LSTM hidden-state vector into a matrix. I strongly recommend reading the following paper first: Hybrid computing using a neural network with dynamic external memory.
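Purely as a shape intuition (toy sizes made up for illustration, nothing from the papers): an LSTM keeps one hidden vector per example, while a relational memory keeps a whole matrix of memory slots per example.

import numpy as np

batch, hidden = 2, 8
lstm_state = np.zeros((batch, hidden))                      # LSTM: one vector per example

mem_slots, mem_size = 4, 8
relational_memory = np.zeros((batch, mem_slots, mem_size))  # RMC: one matrix of slots per example
print(lstm_state.shape, relational_memory.shape)            # (2, 8) (2, 4, 8)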

Once you have read that one, the rest will pose no difficulty. The authors construct a memory matrix and liken it to a block of RAM, design read/write heads for the memory (and along the way an eraser wipes out useless memories), and treat the network as a CPU; readers with a computer-science background will probably love this kind of fine-grained design, which also pays off when validating the individual sub-blocks. It is just somewhat complex. What is rarer still, the optimization there even uses reinforcement learning.

Compared with that paper, this one is much simpler. In one sentence: Attention Is All You Need. Since the problem does not require remembering positions, all that is used is multi-head self-attention. Building the memory matrix directly from that structure looks as follows:

[Formula image: multi-head self-attention over the memory rows, M~ = softmax((M W^q)(M W^k)^T / sqrt(d_k)) (M W^v)]

Its single-step update works as follows:

[Formula image: the new input is appended to the memory before attending, M~ = softmax((M W^q)([M; x] W^k)^T / sqrt(d_k)) ([M; x] W^v)]

The input x is directly "compressed" (fused) into the memory.
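To make the fusion concrete, here is a minimal single-head NumPy sketch of attending from the memory M over [M; x] (toy shapes and random weights, purely illustrative; the real model uses multiple heads, layer normalization, and a residual MLP, as in the code further below):

import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

mem_slots, mem_size, d_k = 4, 6, 6           # toy sizes, not the paper's settings
M = np.random.randn(mem_slots, mem_size)     # current memory matrix
x = np.random.randn(1, mem_size)             # new input, already projected to mem_size
Mx = np.concatenate([M, x], axis=0)          # [M; x]

Wq, Wk, Wv = (np.random.randn(mem_size, d_k) for _ in range(3))
q, k, v = M @ Wq, Mx @ Wk, Mx @ Wv           # queries come from M, keys/values from [M; x]
M_tilde = softmax(q @ k.T / np.sqrt(d_k)) @ v
print(M_tilde.shape)                         # (4, 6): same shape as the memory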

Constructing and updating the memory is relatively simple; the important part is how to define, for a memory matrix, an LSTM-style gated "read" that filters the memory, so that information can be stored, retrieved, and integrated dynamically. The formulas given in the paper are fairly involved, and the details only become clear together with the code the authors provide:

[Formula image: LSTM-style row-wise gating. In the code below it amounts to M_{t+1} = sigmoid(i) * tanh(M~_{t+1}) + sigmoid(f) * M_t, where the pre-activations of the input gate i and forget gate f are computed from the projected input x_t and tanh(M_t).]

Before actually reading the code, though, the overall diagram of the model matters even more:

[Figure: overall schematic of the relational memory core]
The only thing the formulas do not give explicitly is the definition of the function g, which in my view is the most important part: in the code it is realized with a row-wise MLP plus an element-wise (residual) sum.
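As a cross-check against the code below, here is a minimal NumPy sketch of the row-wise gated update (toy shapes, random weights; the gate parameterization is simplified, but it keeps the same broadcast-over-rows trick the code uses):

import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

mem_slots, mem_size = 4, 6                      # toy sizes
M = np.random.randn(mem_slots, mem_size)        # previous memory M_t
M_tilde = np.random.randn(mem_slots, mem_size)  # attended memory from the step above
x = np.random.randn(1, mem_size)                # projected input x_t

# Gates are built from the input and tanh(M); the input term broadcasts over the rows.
Wi, Ui = np.random.randn(mem_size, mem_size), np.random.randn(mem_size, mem_size)
Wf, Uf = np.random.randn(mem_size, mem_size), np.random.randn(mem_size, mem_size)
input_gate = sigmoid(x @ Wi + np.tanh(M) @ Ui + 0.0)    # input_bias = 0.0
forget_gate = sigmoid(x @ Wf + np.tanh(M) @ Uf + 1.0)   # forget_bias = 1.0

M_next = input_gate * np.tanh(M_tilde) + forget_gate * M
print(M_next.shape)                             # (4, 6)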

For convenience, here is a copy of the implementation code:

from sonnet.python.modules import basic
from sonnet.python.modules import layer_norm
from sonnet.python.modules import rnn_core
from sonnet.python.modules.nets import mlp

import tensorflow as tf

class RelationalMemory(rnn_core.RNNCore):
    def __init__(self, mem_slots = 10, head_size = 10, num_heads = 3, num_blocks = 1,
                 forget_bias = 1.0, input_bias = 0.0, gate_style = "unit",
                 attention_mlp_layers = 2, key_size = None, name = "relational_memory"):
        super(RelationalMemory, self).__init__(name=name)
        self._mem_slots = mem_slots

        # multi head size
        self._head_size = head_size
        self._num_heads = num_heads

        self._mem_size = self._head_size * self._num_heads

        if num_blocks < 1:
            raise ValueError("num_blocks must be >= 1, Got: {}.".format(num_blocks))
        self._num_blocks = num_blocks
        self._forget_bias = forget_bias
        self._input_bias = input_bias

        if gate_style not in ["unit", "memory", None]:
            raise ValueError(
                r"gate_style must be one of ['unit', 'memory', None] Got {}".format(gate_style)
            )
        self._gate_style = gate_style
        if attention_mlp_layers < 1:
            raise ValueError("attention_mlp_layers must be >= 1, Got: {}".format(
                attention_mlp_layers
            ))

        self._attention_mlp_layers = attention_mlp_layers
        # key size used for attention; defaults to the per-head value size
        self._key_size = key_size if key_size else self._head_size
    # init memory matrix
    def initial_state(self, batch_size, trainable = False):
        '''
        # [batch, mem_slots, mem_slots]
        init_state = tf.eye(self._mem_slots, batch_shape=[batch_size])
        if self._mem_size > self._mem_slots:
            difference = self._mem_size - self._mem_slots
            pad = tf.zeros((batch_size, self._mem_slots, difference))
            init_state = tf.concat([init_state, pad], -1)
        elif self._mem_size < self._mem_slots:
            init_state = init_state[:, :, :self._mem_size]
        return init_state
        '''
        init_state = tf.eye(self._mem_slots, self._mem_size, batch_shape=[batch_size])
        return init_state

    def _multihead_attention(self, memory):
        key_size = self._key_size
        value_size = self._head_size

        qkv_size = 2 * key_size + value_size
        total_size = qkv_size * self._num_heads

        qkv = basic.BatchApply(basic.Linear(total_size))(memory)
        qkv = basic.BatchApply(layer_norm.LayerNorm())(qkv)

        mem_slots = memory.get_shape().as_list()[1]

        qkv_reshape = basic.BatchReshape([mem_slots, self._num_heads,
                                          qkv_size])(qkv)
        qkv_transpose = tf.transpose(qkv_reshape, [0, 2, 1, 3])  # [B, H, N, qkv_size]
        q, k, v = tf.split(qkv_transpose, [key_size, key_size, value_size], -1)
        q *= qkv_size ** -0.5
        dot_product = tf.matmul(q, k, transpose_b=True)  # [B, H, N, N]
        weights = tf.nn.softmax(dot_product)

        output = tf.matmul(weights, v)  # [B, H, N, V]
        output_transpose = tf.transpose(output, [0, 2, 1, 3])  # [B, N, H, V]

        new_memory = basic.BatchFlatten(preserve_dims=2)(output_transpose)  # [B, N, H * V]
        return new_memory

    @property
    def state_size(self):
        return tf.TensorShape([self._mem_slots, self._mem_size])

    @property
    def output_size(self):
        return tf.TensorShape(self._mem_slots * self._mem_size)

    def _calculate_gate_size(self):
        if self._gate_style == "unit":
            return self._mem_size
        elif self._gate_style == "memory":
            return 1
        else:
            return 0

    def _create_gates(self, inputs, memory):
        num_gates = 2 * self._calculate_gate_size()
        memory = tf.tanh(memory)

        # flatten inputs to [batch, input_size]
        inputs = basic.BatchFlatten()(inputs)
        gate_inputs = basic.BatchApply(basic.Linear(num_gates), n_dims=1)(inputs)
        # expand to [batch, 1, num_gates] so it broadcasts over the memory rows
        gate_inputs = tf.expand_dims(gate_inputs, axis=1)
        gate_memory = basic.BatchApply(basic.Linear(num_gates))(memory)

        # broadcast add to every row of memory
        gates = tf.split(gate_memory + gate_inputs, num_or_size_splits=2, axis=2)
        input_gate, forget_gate = gates

        input_gate = tf.sigmoid(input_gate + self._input_bias)
        forget_gate = tf.sigmoid(forget_gate + self._forget_bias)

        return input_gate, forget_gate

    def _attend_over_memory(self, memory):
        attention_mlp = basic.BatchApply(
            mlp.MLP([self._mem_size] * self._attention_mlp_layers)
        )
        for _ in range(self._num_blocks):
            attended_memory = self._multihead_attention(memory)
            memory = basic.BatchApply(layer_norm.LayerNorm())(
                memory + attended_memory
            )
            memory = basic.BatchApply(layer_norm.LayerNorm())(
                attention_mlp(memory) + memory
            )
        return memory

    def _build(self, inputs, memory, treat_input_as_matrix = False):
        if treat_input_as_matrix:
            inputs = basic.BatchFlatten(preserve_dims=2)(inputs)
            inputs_reshape = basic.BatchApply(
                basic.Linear(self._mem_size), n_dims=2
            )(inputs)
        else:
            inputs = basic.BatchFlatten()(inputs)
            inputs = basic.Linear(self._mem_size)(inputs)
            inputs_reshape = tf.expand_dims(inputs, 1)

        memory_plus_input = tf.concat([memory, inputs_reshape], axis=1)
        next_memory = self._attend_over_memory(memory_plus_input)

        # drop the rows that correspond to the appended inputs
        n = inputs_reshape.get_shape().as_list()[1]
        next_memory = next_memory[:, :-n, :]

        if self._gate_style == "unit" or self._gate_style == "memory":
            self._input_gate, self._forget_gate = self._create_gates(
                inputs_reshape, memory
            )
            next_memory = self._input_gate * tf.tanh(next_memory)
            next_memory += self._forget_gate * memory

        output = basic.BatchFlatten()(next_memory)
        return output, next_memory

    @property
    def input_gate(self):
        self._ensure_is_connected()
        return self._input_gate

    @property
    def forget_gate(self):
        self._ensure_is_connected()
        return self._forget_gate


if __name__ == "__main__":
    pass

A word on code style first. What I find most convenient about Sonnet is its batch-level operation modules; with BatchApply there is no longer any need for tf.map_fn. The _build method implemented after inheriting from rnn_core.RNNCore corresponds essentially to the call method of a TensorFlow RNNCell. One more remark: if I were writing self-attention I would probably define separate weights for Q, K, and V and apply them individually, whereas the author uses a single linear transform followed by a split, which nicely follows the writing style of the components inside TensorFlow's RNN code; likewise, the memory update feeds the whole [M; x] through the compression and then slices.
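A tiny NumPy illustration of that "one projection, then split" pattern (made-up shapes, single head, no batch dimension; only a sketch, not the Sonnet code itself):

import numpy as np

# One linear map produces queries, keys and values at once; tf.split in the
# Sonnet code plays the role of np.split here. Sizes below are arbitrary.
mem_slots, mem_size, key_size, value_size = 4, 6, 3, 5
memory = np.random.randn(mem_slots, mem_size)
W = np.random.randn(mem_size, 2 * key_size + value_size)     # single projection
qkv = memory @ W
q, k, v = np.split(qkv, [key_size, 2 * key_size], axis=-1)   # slice into Q, K, V
print(q.shape, k.shape, v.shape)                              # (4, 3) (4, 3) (4, 5)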

With the network structure above understood, we can try it on a simple example: using the memory network to estimate a recursive computation, which is also the simplest experiment in the original paper (see Learning to Execute if interested). For examples of recursion-structured neural networks, see TensorFlow Fold or the earlier post "TensorFlow Fold 初探(一)——TreeLstm情感分类". Here we use exactly the same kind of data with a sum function to see how it performs. The code is as follows:

# RelationalMemory is the class implemented above (imported from the author's local module)
from model.study import RelationalMemory
from sonnet.python.modules import basic
import tensorflow as tf
import random
import numpy as np
import os

max_seq_len = 10

def random_example(fn, length = max_seq_len):
    length = random.randrange(1, length)
    data = [random.uniform(0,1) for _ in range(length)]
    result = fn(data)
    return data, result

def random_generator(batch_num = 2, fn = sum):
    while True:
        req_x, req_mask, req_y = [], [], []

        for _ in range(batch_num):
            data, result = random_example(fn)
            req_x.append(data)
            req_mask.append(len(data))
            req_y.append(result)

        req_x, req_mask, req_y = map(np.array, [req_x, req_mask, req_y])
        yield req_x, req_mask, req_y

# single model without sequence embedding
class RelationalModel(object):
    def __init__(self, max_seq_len = max_seq_len, dnn_size = 100,
                 epsilon = 1.0):
        self.max_seq_len = max_seq_len
        self.dnn_size = dnn_size
        # use epsilon to identify accurate rate
        self.epsilon = epsilon

        self.input = tf.placeholder(tf.float32, [None, max_seq_len])
        self.input_mask = tf.placeholder(tf.int32, [None])
        self.y = tf.placeholder(tf.float32, [None])

        self.model_construct()

    def model_construct(self):
        relationalMemoryCell = RelationalMemory()
        outputs, state = tf.nn.dynamic_rnn(cell=relationalMemoryCell,
                                           inputs=tf.expand_dims(self.input, axis=-1),
                                           sequence_length=self.input_mask,
                                           dtype=tf.float32)

        flatten_outputs = basic.BatchFlatten()(outputs)
        h0 = tf.layers.dense(inputs=flatten_outputs, units=self.dnn_size, name="h0")
        self.prediction = tf.squeeze(tf.layers.dense(inputs=h0, units=1), name="prediction")

        self.accuracy = tf.reduce_mean(tf.cast((self.epsilon - tf.abs(self.prediction - self.y)) > 0, tf.float32))
        self.loss = tf.losses.mean_squared_error(labels=self.y, predictions=self.prediction)
        self.train_op = tf.train.AdamOptimizer(0.001).minimize(self.loss)


def simple_alg_seq_test():
    train_gen = random_generator(batch_num=128)
    valid_gen = random_generator(batch_num=64)

    model = RelationalModel()
    step = 0

    saver = tf.train.Saver()
    with tf.Session() as sess:
        if os.path.exists(r"E:\Coding\python\retionalSonnetStudy\model.ckpt.index"):
            print("restore exists")
            saver.restore(sess, save_path=r"E:\Coding\python\retionalSonnetStudy\model.ckpt")
        else:
            print("init global")
            sess.run(tf.global_variables_initializer())

        #sess.run(tf.global_variables_initializer())

        while True:
            req_x, req_mask, req_y = train_gen.__next__()
            req_x_pad = np.zeros(shape=[128, max_seq_len])
            for e_idx, ele in enumerate(req_x):
                for c_idx, inner_ele in enumerate(ele):
                    req_x_pad[e_idx][c_idx] = inner_ele
            _, loss, train_acc = sess.run([model.train_op, model.loss, model.accuracy],
                     feed_dict={
                         model.input: req_x_pad,
                         model.input_mask: req_mask,
                         model.y : req_y
                     })
            step += 1
            if step % 5 == 0:
                print("train, loss :{} acc : {}".format(loss, train_acc))

                req_x, req_mask, req_y = valid_gen.__next__()
                req_x_pad = np.zeros(shape=[64, max_seq_len])
                for e_idx, ele in enumerate(req_x):
                    for c_idx, inner_ele in enumerate(ele):
                        req_x_pad[e_idx][c_idx] = inner_ele
                loss, valid_acc = sess.run([model.loss, model.accuracy],
                                              feed_dict={
                                                  model.input: req_x_pad,
                                                  model.input_mask: req_mask,
                                                  model.y : req_y
                                              })
                print("valid loss: {}, acc: {}".format(loss, valid_acc))
                saver.save(sess, save_path=r"E:\Coding\python\retionalSonnetStudy\model.ckpt")

if __name__ == "__main__":
    simple_alg_seq_test()
The results are as follows:

restore exists
train, loss :0.24382832646369934 acc : 0.953125
valid loss: 0.2279301881790161, acc: 0.984375
train, loss :0.15681743621826172 acc : 0.9921875
valid loss: 0.10683506727218628, acc: 1.0
train, loss :0.09039808064699173 acc : 1.0
valid loss: 0.12752145528793335, acc: 1.0
Compared with the recursive functions in TensorFlow Fold, one could grumble about how long this memory-network implementation is, but its memory capability is general-purpose.

The memory structure above has also been used in a recent reinforcement-learning paper:

Relational Deep Reinforcement Learning

Paper link: https://arxiv.org/abs/1806.01830

Its basic network structure is shown in the figure below:

[Figure: network architecture from the Relational Deep Reinforcement Learning paper]

The key relational module is essentially the same as the memory structure described in this post; roughly, the idea is to first extract image features, then build a value function and a policy on top and optimize them with reinforcement learning.

One of their examples applies this to StarCraft II mini-games; I have not seen an open-source dataset for it yet... much to my frustration.


Reposted from blog.csdn.net/sinat_30665603/article/details/80645843