pytorch seq2seq模型示例

以下代码可以让你更加熟悉seq2seq模型机制

"""
    test
"""
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable

# 创建字典
seq_data = [['man', 'women'], ['black', 'white'], ['king', 'queen'], ['girl', 'boy'], ['up', 'down'], ['high', 'low']]
char_arr = [c for c in 'SEPabcdefghijklmnopqrstuvwxyz']
num_dict = {n:i for i,n in enumerate(char_arr)}

# 网络参数
n_step = 5
n_hidden = 128
n_class = len(num_dict)
batch_size = len(seq_data)

# 准备数据
def make_batch(seq_data):
    input_batch, output_batch, target_batch =[], [], []

    for seq in seq_data:
        for i in range(2):
            seq[i] = seq[i] + 'P' * (n_step-len(seq[i]))
        input = [num_dict[n] for n in seq[0]]
        ouput = [num_dict[n] for n in ('S'+ seq[1])]
        target = [num_dict[n] for n in (seq[1]) + 'E']

        input_batch.append(np.eye(n_class)[input])
        output_batch.append(np.eye(n_class)[ouput])
        target_batch.append(target)

    return Variable(torch.Tensor(input_batch)), Variable(torch.Tensor(output_batch)), Variable(torch.LongTensor(target_batch))

input_batch, output_batch, target_batch = make_batch(seq_data)


# 创建网络
class Seq2Seq(nn.Module):
    """
    要点:
    1.该网络包含一个encoder和一个decoder,使用的RNN的结构相同,最后使用全连接接预测结果
    2.RNN网络结构要熟知
    3.seq2seq的精髓:encoder层生成的参数作为decoder层的输入
    """
    def __init__(self):
        super().__init__()
        # 此处的input_size是每一个节点可接纳的状态,hidden_size是隐藏节点的维度
        self.enc = nn.RNN(input_size=n_class, hidden_size=n_hidden, dropout=0.5)
        self.dec = nn.RNN(input_size=n_class, hidden_size=n_hidden, dropout=0.5)
        self.fc = nn.Linear(n_hidden, n_class)

    def forward(self, enc_input, enc_hidden, dec_input):
        # RNN要求输入:(seq_len, batch_size, n_class),这里需要转置一下
        enc_input = enc_input.transpose(0,1)
        dec_input = dec_input.transpose(0,1)
        _, enc_states = self.enc(enc_input, enc_hidden)
        outputs, _ = self.dec(dec_input, enc_states)
        pred = self.fc(outputs)

        return pred


# training
model = Seq2Seq()
loss_fun = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

for epoch in range(5000):
    hidden = Variable(torch.zeros(1, batch_size, n_hidden))

    optimizer.zero_grad()
    pred = model(input_batch, hidden, output_batch)
    pred = pred.transpose(0, 1)
    loss = 0
    for i in range(len(seq_data)):
        temp = pred[i]
        tar = target_batch[i]
        loss +=  loss_fun(pred[i], target_batch[i])
    if (epoch + 1) % 1000 == 0:
        print('Epoch: %d   Cost: %f' % (epoch + 1, loss))
    loss.backward()
    optimizer.step()


# 测试
def translate(word):
    input_batch, output_batch, _ = make_batch([[word, 'P' * len(word)]])
    # hidden 形状 (1, 1, n_class)
    hidden = Variable(torch.zeros(1, 1, n_hidden))
    # output 形状(6,1, n_class)
    output = model(input_batch, hidden, output_batch)
    predict = output.data.max(2, keepdim=True)[1]
    decoded = [char_arr[i] for i in predict]
    end = decoded.index('E')
    translated = ''.join(decoded[:end])

    return translated.replace('P', '')

print('girl ->', translate('girl'))

参考:https://blog.csdn.net/weixin_43632501/article/details/98525673

猜你喜欢

转载自www.cnblogs.com/demo-deng/p/11811090.html