PyTorch + LSTM + Attention 实现 Seq2Seq

# !/usr/bin/env Python3
# -*- coding: utf-8 -*-
# @version: v1.0
# @Author   : Meng Li
# @contact: [email protected]
# @FILE     : torch_seq2seq.py
# @Time     : 2022/6/8 11:11
# @Software : PyCharm
# @site:
# @Description : 将Seq2Seq网络采用编码器和解码器两个类进行融合
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchsummary
from torch.utils.data import Dataset, DataLoader
import numpy as np
import os

# Use the CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


class my_dataset(Dataset):
    """Map-style dataset that indexes three tensors in parallel.

    Item ``k`` is the triple ``(enc_input[k], dec_input[k], dec_output[k])``.
    """

    def __init__(self, enc_input, dec_input, dec_output):
        super().__init__()
        self.enc_input, self.dec_input, self.dec_output = enc_input, dec_input, dec_output

    def __len__(self):
        # The number of samples is the size of the first axis of ``enc_input``.
        sample_count = self.enc_input.size(0)
        return sample_count

    def __getitem__(self, index):
        fields = (self.enc_input, self.dec_input, self.dec_output)
        return tuple(field[index] for field in fields)


class Encoder(nn.Module):
    """Single-layer LSTM encoder.

    Args:
        in_features: Size of each input vector (last axis of ``enc_input``).
        hidden_size: Number of features in the LSTM hidden state.
    """

    def __init__(self, in_features, hidden_size):
        super().__init__()
        self.in_features = in_features
        self.hidden_size = hidden_size
        # nn.LSTM only applies ``dropout`` between stacked layers, so with
        # ``num_layers=1`` it had no effect other than emitting a UserWarning.
        self.encoder = nn.LSTM(input_size=in_features, hidden_size=hidden_size, num_layers=1)

    def forward(self, enc_input):
        """Encode a batch of sequences.

        Args:
            enc_input: Tensor of shape [seq_len, batch_size, in_features].

        Returns:
            ``(encode_output, (h_n, c_n))`` where ``encode_output`` is
            [seq_len, batch_size, hidden_size] and ``h_n`` / ``c_n`` are the final
            hidden and cell states, each [1, batch_size, hidden_size].

        Raises:
            ValueError: If ``enc_input`` is not a 3-D tensor.
        """
        if enc_input.dim() != 3:
            raise ValueError(
                "enc_input must be [seq_len, batch_size, in_features], got shape {0}".format(
                    tuple(enc_input.size())))
        # Follow the module's own placement instead of a module-level global so the
        # input always ends up on the same device as the LSTM weights.
        enc_input = enc_input.to(next(self.parameters()).device)
        # No explicit initial state: nn.LSTM then starts from zeros. The previous
        # torch.rand() states made every forward pass (including inference)
        # non-deterministic.
        encode_output, (h_n, c_n) = self.encoder(enc_input)
        return encode_output, (h_n, c_n)


class Decoder(nn.Module):
    """One-step attention decoder built around a single-layer LSTM.

    Args:
        in_features: Size of each decoder input vector.
        enc_hid_size: Feature size of the encoder outputs.
        dec_hid_size: Feature size of the decoder LSTM hidden state.
        Attn: Callable ``Attn(s, enc_output)`` returning attention weights of
            shape [batch_size, seq_len].
    """

    def __init__(self, in_features, enc_hid_size, dec_hid_size, Attn):
        super().__init__()
        self.in_features = in_features
        self.Attn = Attn
        self.enc_hid_size = enc_hid_size
        self.dec_hid_size = dec_hid_size
        # Not used by forward(); kept so the attribute remains available.
        self.crition = nn.CrossEntropyLoss()
        # Projects [context ; decoder input] back to the LSTM input size.
        self.fc = nn.Linear(in_features + enc_hid_size, in_features)
        # ``dropout`` is omitted: with ``num_layers=1`` nn.LSTM ignores it and warns.
        self.decoder = nn.LSTM(input_size=in_features, hidden_size=dec_hid_size, num_layers=1)

    def forward(self, enc_output, dec_input, s):
        """Run one decoding step.

        Args:
            enc_output: Encoder outputs, [seq_len, batch_size, enc_hid_size].
            dec_input: Decoder input for this step, [batch_size, in_features].
            s: Previous decoder hidden state, [1, batch_size, dec_hid_size].

        Returns:
            ``(de_output, s)``: the LSTM output for this step and the new hidden
            state, both [1, batch_size, dec_hid_size].
        """
        s = s.to(self.fc.weight.device)
        batch_size = enc_output.size(1)

        atten = self.Attn(s, enc_output).unsqueeze(1)  # [batch_size, 1, seq_len]
        # Attention-weighted sum of the encoder outputs.
        context = torch.bmm(atten, enc_output.transpose(0, 1))  # [batch_size, 1, enc_hid_size]
        context = context.transpose(0, 1)  # [1, batch_size, enc_hid_size]

        step_input = dec_input.unsqueeze(0)  # [1, batch_size, in_features]
        lstm_input = self.fc(torch.cat((context, step_input), dim=2))  # [1, batch_size, in_features]

        # The cell state must match the decoder hidden size. It was previously sized
        # from the encoder feature width, which only worked when both sizes were equal.
        # NOTE: the cell state is re-initialised to zeros at every step because the
        # interface only carries the hidden state ``s`` between calls.
        c0 = s.new_zeros(1, batch_size, self.dec_hid_size)
        de_output, (s, _) = self.decoder(lstm_input, (s, c0))
        return de_output, s


class Attention(nn.Module):
    """Additive attention over encoder outputs.

    Args:
        enc_hid_dim: Feature size of the encoder outputs.
        dec_hid_dim: Feature size of the decoder hidden state.
    """

    def __init__(self, enc_hid_dim, dec_hid_dim):
        super().__init__()
        self.fc1 = torch.nn.Linear(dec_hid_dim + enc_hid_dim, dec_hid_dim)
        self.fc2 = torch.nn.Linear(dec_hid_dim, 1)

    def forward(self, s, enc_output):
        """Score every encoder position against the decoder state.

        Args:
            s: Decoder hidden state, [1, batch_size, dec_hid_dim].
            enc_output: Encoder outputs, [seq_len, batch_size, enc_hid_dim].

        Returns:
            Attention weights of shape [batch_size, seq_len]; each row sums to 1
            over the source positions.
        """
        seq_len = enc_output.size(0)
        s = s.expand(seq_len, -1, -1)  # [seq_len, batch_size, dec_hid_dim]
        # The non-linearity sits between the two projections; applying tanh to the
        # raw concatenation left fc1 and fc2 as two stacked linear maps.
        energy = torch.tanh(self.fc1(torch.cat((s, enc_output), dim=2)))  # [seq_len, batch_size, dec_hid_dim]
        scores = self.fc2(energy).squeeze(2)  # [seq_len, batch_size]
        # Normalise over the sequence axis (dim 0 here). The previous dim=1 softmax
        # normalised across the batch, so weights did not sum to 1 per sample.
        return F.softmax(scores, dim=0).transpose(0, 1)  # [batch_size, seq_len]


class Seq2seq(nn.Module):
    """Encoder/decoder wrapper that decodes step by step and computes the loss.

    Args:
        encoder: Callable ``encoder(enc_input)`` returning
            ``(enc_output, (hidden, cell))``.
        decoder: Callable ``decoder(enc_output, step_input, s)`` returning
            ``(step_output, s)``.
        in_features: Size of the output vocabulary / input vectors.
        hidden_size: Feature size of the decoder step outputs.
    """

    def __init__(self, encoder, decoder, in_features, hidden_size):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.in_features = in_features
        self.hidden_size = hidden_size
        self.fc = nn.Linear(hidden_size, in_features)
        self.crition = nn.CrossEntropyLoss()

    def forward(self, enc_input, dec_input, dec_output):
        """Run the full model on a batch.

        Args:
            enc_input: [batch_size, seq_len, in_features] encoder inputs.
            dec_input: [batch_size, target_len, in_features] decoder inputs.
            dec_output: [batch_size, target_len] target class indices.

        Returns:
            ``(output, loss)`` where ``output`` holds the logits,
            [batch_size, target_len, in_features], and ``loss`` is the sum over the
            batch of each sample's mean cross-entropy.
        """
        # Use the module's own placement rather than a module-level global.
        param_device = self.fc.weight.device
        enc_input = enc_input.to(param_device).permute(1, 0, 2)  # [seq_len, batch_size, in_features]
        dec_input = dec_input.to(param_device).permute(1, 0, 2)  # [target_len, batch_size, in_features]
        dec_output = dec_output.to(param_device)
        batch_size = dec_input.size(1)

        # The encoder's final hidden state seeds the decoder; all encoder outputs
        # are kept for the attention step.
        enc_output, (s, _) = self.encoder(enc_input)

        # Decode every time step, including step 0. The target has no start token,
        # so skipping step 0 left its logits equal to the fc bias while that position
        # still contributed to the loss.
        step_outputs = []
        for step_input in dec_input:  # step_input: [batch_size, in_features]
            dec_output_i, s = self.decoder(enc_output, step_input, s)
            step_outputs.append(dec_output_i)  # [1, batch_size, hidden_size]
        if step_outputs:
            outputs = torch.cat(step_outputs, dim=0)  # [target_len, batch_size, hidden_size]
        else:
            outputs = dec_input.new_zeros(0, batch_size, self.hidden_size)

        output = self.fc(outputs).permute(1, 0, 2)  # [batch_size, target_len, in_features]
        # One multi-class cross-entropy (mean over time steps) per sample, summed.
        loss = sum(self.crition(logits, target) for logits, target in zip(output, dec_output))
        return output, loss


def make_data(seq_data):
    """Turn (source, target) string pairs into model tensors.

    Every string is right-padded with '?' to the length of the longest string in
    ``seq_data`` and terminated with 'E'. The decoder input is a row of '?'
    placeholders of the same length as the encoder input.

    Args:
        seq_data: Sequence of ``[source, target]`` string pairs. It is not modified.

    Returns:
        ``(enc_input, dec_input, dec_output)``: two float tensors of one-hot vectors,
        each [n_pairs, max_len + 1, vocab_size], and a long tensor of target
        indices, [n_pairs, max_len + 1].

    Raises:
        ValueError: If ``seq_data`` is empty.
        KeyError: If a character is not in the vocabulary.
    """
    vocab = list("SE?abcdefghijklmnopqrstuvwxyz上下人低国女孩王男白色高黑")
    word2idx = {ch: idx for idx, ch in enumerate(vocab)}
    pad_len = max(len(text) for pair in seq_data for text in pair)  # longest string
    one_hot = np.eye(len(vocab), dtype=np.float32)

    enc_input_all, dec_input_all, dec_output_all = [], [], []
    for pair in seq_data:
        # Pad into new strings so the caller's data is left untouched.
        source = pair[0] + '?' * (pad_len - len(pair[0]))  # e.g. 'man??'
        target = pair[1] + '?' * (pad_len - len(pair[1]))

        enc_ids = [word2idx[ch] for ch in source + 'E']
        dec_in_ids = [word2idx['?']] * len(enc_ids)
        dec_out_ids = [word2idx[ch] for ch in target + 'E']

        enc_input_all.append(one_hot[enc_ids])
        dec_input_all.append(one_hot[dec_in_ids])
        dec_output_all.append(dec_out_ids)  # class indices, not one-hot

    # Stack into one ndarray first; building a tensor from a list of ndarrays is slow.
    return (torch.from_numpy(np.stack(enc_input_all)),
            torch.from_numpy(np.stack(dec_input_all)),
            torch.tensor(dec_output_all, dtype=torch.long))


def translate(word):
    """Translate ``word`` with the saved model and print the result.

    Trains the model first (via ``train()``) when ``translate.pt`` does not exist.

    Args:
        word: Source string; every character must be in the vocabulary.
    """
    vocab = list("SE?abcdefghijklmnopqrstuvwxyz上下人低国女孩王男白色高黑")
    idx2word = dict(enumerate(vocab))
    V = 5
    x, y, z = make_data([[word, "?" * V]])
    if not os.path.exists("translate.pt"):
        train()
    # The checkpoint is a whole pickled module written by train(); only load files
    # you produced yourself. ``weights_only=False`` is needed on PyTorch versions
    # whose default rejects pickled modules; older versions do not know the
    # argument, hence the fallback. ``map_location`` lets a GPU-trained checkpoint
    # load on a CPU-only machine.
    try:
        net = torch.load("translate.pt", map_location=device, weights_only=False)
    except TypeError:
        net = torch.load("translate.pt", map_location=device)
    net.eval()
    with torch.no_grad():  # inference only, no autograd graph needed
        pre, _ = net(x, y, z)
    pre = torch.argmax(pre, 2)[0]
    # Move to the CPU before converting; .numpy() on a CUDA tensor raises.
    pre_word = "".join(idx2word[i] for i in pre.cpu().tolist()).replace("?", "")
    # Cut at the first end marker; if none was predicted, print the whole string
    # instead of raising ValueError.
    print(word, "->  ", pre_word.partition('E')[0])


def train(epochs=10000, save_path="translate.pt"):
    """Train the attention Seq2Seq model on the built-in word pairs and save it.

    Args:
        epochs: Number of passes over the training pairs.
        save_path: File the whole trained model is written to with ``torch.save``.
    """
    vocab = list("SE?abcdefghijklmnopqrstuvwxyz上下人低国女孩王男白色高黑")
    seq_data = [['man', '男人'], ['black', '黑色'], ['king', '国王'], ['girl', '女孩'], ['up', '上'],
                ['high', '高'], ['women', '女人'], ['white', '白色'], ['boy', '男孩'], ['down', '下'], ['low', '低'],
                ['queen', '女王']]
    enc_input, dec_input, dec_output = make_data(seq_data)
    batch_size = 3
    in_features = len(vocab)
    hidden_size = 128

    train_data = my_dataset(enc_input, dec_input, dec_output)
    train_iter = DataLoader(train_data, batch_size, shuffle=True)

    atten = Attention(enc_hid_dim=hidden_size, dec_hid_dim=hidden_size)
    encoder = Encoder(in_features, hidden_size).to(device)
    decoder = Decoder(in_features, hidden_size, hidden_size, atten).to(device)
    net = Seq2seq(encoder, decoder, in_features, hidden_size).to(device)
    learning_rate = 0.001
    optimizer = optim.Adam(net.parameters(), lr=learning_rate)
    loss = 0

    for epoch in range(epochs):
        for en_input, de_input, de_output in train_iter:
            # Only the loss is needed here; the logits are not used during training.
            _, loss = net(en_input, de_input, de_output)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        if epoch % 100 == 0:
            # Reports the loss of the last batch of this epoch.
            print("step {0} loss {1}".format(epoch, loss))
    torch.save(net, save_path)


if __name__ == '__main__':
    # Sample words for a manual check of the trained model; the last two entries
    # are not among the training pairs.
    before_test = [
        'man', 'black', 'king', 'girl', 'up', 'high', 'women',
        'white', 'boy', 'down', 'low', 'queen', 'mman', 'woman',
    ]
    # To translate the samples instead of training, call translate(word) for each
    # word in before_test.
    train()

基于 LSTM+Attention 的 Seq2Seq 模型训练难以收敛:迭代约 10000 次后,模型的损失(loss)仍大于 1。后来我采用双向 LSTM 对模型进行了改进,但收敛效果仍然不够理想。

猜你喜欢

转载自blog.csdn.net/linxizi0622/article/details/125303661