seq2seq的实现方式(1)

应用场景

seq2seq是自然语言处理应用中的常用模型,一般的机器翻译,文本摘要,对话生成(虽然之前实现过基于语言模型+关键词的生成方式,但这才是正道),文本摘要等任务。更高级的模型也是从基础的模型进行迭代的模型架构相对统一。

其具体的模型原理就不讲了,有很多博客已经有很好的说明,在这里只是趁着周末更新一下seq2seq在机器翻译方面的实验,更新上来供同行们参考。

嗯,seq2seq有几种模式:
(1)最简单的一种是Encoder的隐层向量复制后直接作为decoder的输入,也就是decoder对不需要序列输入。
(2)在一个是Encoder的隐层向量作为decoder的初始化,并且decoder的有输入序列,并且和输出序列错位,用于启发。
(3)就是把(1)和(2)结合起来,即要参考Decoder输入序列,又要参考Encoder的最后的隐层向量,为启发获取更多的信息。
(4)因为在翻译每一个词的时候,输入端各个词的贡献其实是不一样的,所以用Encoder的最后隐层没有多样性,所以改用attention替换(3)中的Encoder隐层向量。
几种方式一脉相承,逐步深化。

这里实现了(2)的方法。

    def build_model(self):
        
        encoder_input = layers.Input(shape=(self.input_seq_len,))
        encoder_embeding = layers.Embedding(input_dim=len(self.en_word_id_dict),
                                            output_dim=self.encode_embeding_len,
                                            mask_zero=True
                                            )(encoder_input)
        encoder_lstm, state_h, state_c = layers.LSTM(units=self.encode_embeding_len,
                                                     return_state=True)(encoder_embeding)

        encoder_state = [state_h, state_c]

        decoder_input = layers.Input(shape=(self.output_seq_len,))
        decoder_embeding = layers.Embedding(input_dim=len(self.ch_word_id_dict),
                                            output_dim=self.decode_embeding_len,
                                            mask_zero=True
                                            )(decoder_input)
        decoder_lstm, _, _ = layers.LSTM(units=self.encode_embeding_len,
                                         return_state=True,
                                         return_sequences=True)(decoder_embeding, initial_state=encoder_state)
        decoder_out = layers.Dense(len(self.ch_word_id_dict), activation="softmax")(decoder_lstm)

        model = Model([encoder_input, decoder_input], decoder_out)
        model.compile(optimizer='rmsprop', loss='categorical_crossentropy')
        # model.compile(optimizer='rmsprop', loss='sparse_categorical_crossentropy')
        model.summary()
        return model


猜你喜欢

转载自blog.csdn.net/cyinfi/article/details/88375608
今日推荐