[PyTorch] A PyTorch Implementation of the Transformer for Simple Translation


Paper: https://arxiv.org/pdf/1706.03762.pdf

Code reference: https://wmathor.com/index.php/archives/1455/

Note: the Transformer built in this code contains no dropout layers.

Data Preprocessing

Two German → English sentence pairs are used, and each word's index is hard-coded by hand to keep the code easy to read. We build the encoder input, the decoder input, and the decoder output, i.e. the ground-truth labels.

import math
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as Data

# S: start symbol
# E: end symbol
# P: padding symbol, used when a sentence in the batch is shorter than the maximum length (set manually)
sentences = [
    # enc_input (encoder input)   dec_input (decoder input)   dec_output (decoder ground-truth labels)
    ['ich mochte ein bier P', 'S i want a beer .', 'i want a beer . E'],
    ['ich mochte ein cola P', 'S i want a coke .', 'i want a coke . E']
]

# Build the source and target vocabularies
# Padding should be zero
src_vocab = {'P': 0, 'ich': 1, 'mochte': 2, 'ein': 3, 'bier': 4, 'cola': 5}
src_vocab_size = len(src_vocab)  # 6

tgt_vocab = {'P': 0, 'i': 1, 'want': 2, 'a': 3, 'beer': 4, 'coke': 5, 'S': 6, 'E': 7, '.': 8}
tgt_vocab_size = len(tgt_vocab)  # 9

# Index-to-word mapping: {0: 'P', 1: 'i', 2: 'want', ..., 8: '.'}, used at prediction time
idx2word = {i: w for i, w in enumerate(tgt_vocab)}  # i is the index, w is the key

src_len = 5  # enc_input max sequence length
tgt_len = 6  # dec_input(=dec_output) max sequence length


# Build the encoder inputs enc_inputs, the decoder inputs dec_inputs, and the decoder outputs dec_outputs (the ground-truth labels)
def make_data(sentences):
    enc_inputs, dec_inputs, dec_outputs = [], [], []
    for i in range(len(sentences)):
        enc_input = [[src_vocab[n] for n in sentences[i][0].split()]]
        dec_input = [[tgt_vocab[n] for n in sentences[i][1].split()]]
        dec_output = [[tgt_vocab[n] for n in sentences[i][2].split()]]

        enc_inputs.extend(enc_input)  # [[1, 2, 3, 4, 0], [1, 2, 3, 5, 0]]
        dec_inputs.extend(dec_input)  # [[6, 1, 2, 3, 4, 8], [6, 1, 2, 3, 5, 8]]
        dec_outputs.extend(dec_output)  # [[1, 2, 3, 4, 8, 7], [1, 2, 3, 5, 8, 7]]

    return torch.LongTensor(enc_inputs), torch.LongTensor(dec_inputs), torch.LongTensor(dec_outputs)


enc_inputs, dec_inputs, dec_outputs = make_data(sentences)  # returned as tensors
# enc_inputs: [batch_size, src_len]=[2,5]
# dec_inputs/dec_outputs: [batch_size, tgt_len]=[2,6]

class MyDataSet(Data.Dataset):
    def __init__(self, enc_inputs, dec_inputs, dec_outputs):
        super(MyDataSet, self).__init__()
        self.enc_inputs = enc_inputs
        self.dec_inputs = dec_inputs
        self.dec_outputs = dec_outputs

    def __len__(self):
        return self.enc_inputs.shape[0]  # 2

    def __getitem__(self, idx):
        return self.enc_inputs[idx], self.dec_inputs[idx], self.dec_outputs[idx]


# Since there are only two sentences, batch_size is set to 2
loader = Data.DataLoader(MyDataSet(enc_inputs, dec_inputs, dec_outputs), batch_size=2, shuffle=True)

Positional Encoding

The encoding of each position follows:
$$PE_{(pos,2i)}=\sin(pos/10000^{2i/d_{model}}) \qquad PE_{(pos,2i+1)}=\cos(pos/10000^{2i/d_{model}})$$

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)  # initialize pe
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # build the position column vector pos
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)  # sine for even dimensions
        pe[:, 1::2] = torch.cos(position * div_term)  # cosine for odd dimensions
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        '''
        x: sequence of word embeddings, [seq_len, batch_size, d_model]
        '''
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)
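
The div_term line above uses the identity exp(2i * (-ln 10000) / d_model) = 1/10000^(2i/d_model). A minimal sanity check of that equivalence (the variable names and the toy d_model below are illustrative, not from the original code):

two_i = torch.arange(0, 8, 2).float()                       # 2i for a toy d_model = 8
div_term = torch.exp(two_i * (-math.log(10000.0) / 8))      # the exp/log form used above
direct = 1.0 / torch.pow(torch.tensor(10000.0), two_i / 8)  # the textbook form 1/10000^(2i/d_model)
print(torch.allclose(div_term, direct))                     # True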

Let's plot the positional encoding to see what it looks like:

import matplotlib.pyplot as plt
plt.figure(figsize=(15, 5))
pe = PositionalEncoding(20, 0)
y = pe.forward((torch.zeros(100, 1, 20)))
plt.plot(np.arange(100), y[:, 0, 4:8].data.numpy())
plt.legend(["dim %d"%p for p in [4,5,6,7]])
None

(Figure: sinusoidal positional-encoding curves for dimensions 4-7.)

Model Parameters

d_model = 512  # Embedding Size
d_ff = 2048  # FeedForward dimension
d_k = d_v = 64  # dimension of K(=Q), V
n_layers = 6  # number of Encoder and Decoder layers
n_heads = 8  # number of heads in Multi-Head Attention

get_attn_pad_mask

def get_attn_pad_mask(seq_q, seq_k):
    """
    seq_q: [batch_size, len_q]
    seq_k: [batch_size, len_k]
    seq_q and seq_k are not necessarily the same, so len_q and len_k may differ
    """
    batch_size, len_q = seq_q.size()
    batch_size, len_k = seq_k.size()
    # eq(zero) marks the PAD token: positions equal to 0 are set to True
    pad_attn_mask = seq_k.data.eq(0).unsqueeze(1)  # [batch_size, 1, len_k]; only the pad information of seq_k is used
    return pad_attn_mask.expand(batch_size, len_q, len_k)  # [batch_size, len_q, len_k]

Print it to see the effect:

dec_enc_attn_mask = get_attn_pad_mask(dec_inputs, enc_inputs)
# In the encoder-decoder attention layer, only the pad information of enc_inputs is used; the decoder-side pad information is not needed.
dec_enc_attn_mask

# output shape: (2, 6, 5)
tensor([[[False, False, False, False,  True],
         [False, False, False, False,  True],
         [False, False, False, False,  True],
         [False, False, False, False,  True],
         [False, False, False, False,  True],
         [False, False, False, False,  True]],

        [[False, False, False, False,  True],
         [False, False, False, False,  True],
         [False, False, False, False,  True],
         [False, False, False, False,  True],
         [False, False, False, False,  True],
         [False, False, False, False,  True]]])

get_attn_subsequence_mask

# Source of the mask in the decoder's Masked Multi-Head Attention; it enables parallel training
def get_attn_subsequence_mask(seq):
    """
    seq: dec_inputs, [batch_size, tgt_len]
    """
    attn_shape = [seq.size(0), seq.size(1), seq.size(1)]  # [batch_size, tgt_len, tgt_len]
    subsequence_mask = np.triu(np.ones(attn_shape), k=1)  # upper-triangular matrix of ones; k=1 keeps the diagonal at 0
    subsequence_mask = torch.from_numpy(subsequence_mask).byte()  # convert to a tensor
    return subsequence_mask  # [batch_size, tgt_len, tgt_len]

Print it to see the effect:

get_attn_subsequence_mask(dec_inputs)

# output shape: (2, 6, 6)
tensor([[[0, 1, 1, 1, 1, 1],
         [0, 0, 1, 1, 1, 1],
         [0, 0, 0, 1, 1, 1],
         [0, 0, 0, 0, 1, 1],
         [0, 0, 0, 0, 0, 1],
         [0, 0, 0, 0, 0, 0]],

        [[0, 1, 1, 1, 1, 1],
         [0, 0, 1, 1, 1, 1],
         [0, 0, 0, 1, 1, 1],
         [0, 0, 0, 0, 1, 1],
         [0, 0, 0, 0, 0, 1],
         [0, 0, 0, 0, 0, 0]]], dtype=torch.uint8)

import matplotlib.pyplot as plt
a = torch.randn((5, 20))  # random standard-normal data, [batch_size, len]
plt.figure(figsize=(5, 5))
plt.imshow(get_attn_subsequence_mask(a)[0])  # [batch_size, len, len]; show the 0-th mask
None

Scaled Dot-Product Attention

Scaled Dot-Product Attention is one building block of Multi-Head Attention.

$$Attention(Q,K,V)=softmax\left(\frac{QK^{T}}{\sqrt{d_{k}}}\right)V$$

class ScaledDotProductAttention(nn.Module):
    def __init__(self):
        super(ScaledDotProductAttention, self).__init__()

    def forward(self, Q, K, V, attn_mask):
        """
        Q: [batch_size, n_heads, len_q, d_k]
        K: [batch_size, n_heads, len_k, d_k]
        V: [batch_size, n_heads, len_v(=len_k), d_v]
        attn_mask: [batch_size, n_heads, seq_len, seq_len]
        """
        scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(d_k)  # [,,len_q,d_k]*[,,d_k,len_k]=[,,len_q,len_k]
        # scores : [batch_size, n_heads, len_q, len_k]
        scores.masked_fill_(attn_mask, -1e9)  # positions where the mask is True are filled with a large negative value, so they become ~0 after softmax
        # attn_mask has the same shape as scores: [batch_size, n_heads, len_q, len_k]

        attn = nn.Softmax(dim=-1)(scores)  # [batch_size, n_heads, len_q, len_k]
        # attn is the attention distribution after softmax; each row sums to 1
        context = torch.matmul(attn, V)   # [,,len_q,len_k]*[,,len_v(=len_k),d_v]=[,,len_q,d_v]
        # context: the result of attention over Q, K, V, [batch_size, n_heads, len_q, d_v]
        return context, attn
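
A quick shape check with toy tensors (the values below are hypothetical; only the globals d_k and d_v defined above are assumed):

Q = torch.randn(1, 2, 3, d_k)                  # batch_size=1, n_heads=2, len_q=3
K = torch.randn(1, 2, 3, d_k)
V = torch.randn(1, 2, 3, d_v)
mask = torch.zeros(1, 2, 3, 3).bool()          # nothing masked
ctx, attn = ScaledDotProductAttention()(Q, K, V, mask)
print(ctx.shape, attn.shape)                   # torch.Size([1, 2, 3, 64]) torch.Size([1, 2, 3, 3])
print(attn.sum(dim=-1))                        # every row sums to 1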

Multi-Head Attention

class MultiHeadAttention(nn.Module):
    def __init__(self):
        super(MultiHeadAttention, self).__init__()
        self.W_Q = nn.Linear(d_model, d_k * n_heads, bias=False)
        self.W_K = nn.Linear(d_model, d_k * n_heads, bias=False)
        self.W_V = nn.Linear(d_model, d_v * n_heads, bias=False)
        self.fc = nn.Linear(n_heads * d_v, d_model, bias=False)

    def forward(self, input_Q, input_K, input_V, attn_mask):
        """
        input_Q: [batch_size, len_q, d_model]
        input_K: [batch_size, len_k, d_model]
        input_V: [batch_size, len_v(=len_k), d_model]
        attn_mask: [batch_size, seq_len, seq_len]
        :return: output after multi-head attention + residual connection + LayerNorm, same shape as input_Q
        """
        residual, batch_size = input_Q, input_Q.size(0)
        # (B, S, D) -proj-> (B, S, D_new) -split-> (B, S, H, W) -trans-> (B, H, S, W)
        Q = self.W_Q(input_Q).view(batch_size, -1, n_heads, d_k).transpose(1, 2)  # Q: [batch_size, n_heads, len_q, d_k]
        K = self.W_K(input_K).view(batch_size, -1, n_heads, d_k).transpose(1, 2)  # K: [batch_size, n_heads, len_k, d_k]
        V = self.W_V(input_V).view(batch_size, -1, n_heads, d_v).transpose(1, 2)  # V: [batch_size, n_heads, len_v(=len_k), d_v]

        attn_mask = attn_mask.unsqueeze(1).repeat(1, n_heads, 1, 1)
        # attn_mask : [batch_size, n_heads, seq_len, seq_len]
        # repeat(): replicate n_heads times along dim 1, once along the other dims

        # context: [batch_size, n_heads, len_q, d_v], attn: [batch_size, n_heads, len_q, len_k]
        context, attn = ScaledDotProductAttention()(Q, K, V, attn_mask)
        context = context.transpose(1, 2).reshape(batch_size, -1, n_heads * d_v)  # this step is the concat in the figure
        # context: [batch_size, len_q, n_heads * d_v]

        output = self.fc(context)  # [batch_size, len_q, d_model]
        return nn.LayerNorm(d_model).cuda()(output + residual), attn  # residual + LayerNorm keep the shape unchanged (note: a fresh LayerNorm is created on every forward call, matching the referenced code)
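
A hedged usage sketch (it assumes a CUDA device is available, since the forward above calls nn.LayerNorm(d_model).cuda()); the same module serves both self-attention (Q, K, V from one tensor) and encoder-decoder attention (Q from the decoder, K and V from the encoder). The tensors below are hypothetical:

if torch.cuda.is_available():
    mha = MultiHeadAttention().cuda()
    q_in = torch.randn(2, 6, d_model).cuda()   # e.g. a decoder-side sequence, len_q=6
    kv_in = torch.randn(2, 5, d_model).cuda()  # e.g. an encoder output, len_k=5
    mask = torch.zeros(2, 6, 5).bool().cuda()  # nothing masked
    out, attn = mha(q_in, kv_in, kv_in, mask)
    print(out.shape, attn.shape)               # [2, 6, 512] and [2, 8, 6, 5]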

Feed Forward Net

# Position-wise feed-forward network; input and output shapes are identical
class PoswiseFeedForwardNet(nn.Module):
    def __init__(self):
        super(PoswiseFeedForwardNet, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(d_model, d_ff, bias=False),
            nn.ReLU(),
            nn.Linear(d_ff, d_model, bias=False)
        )

    def forward(self, inputs):
        """
        inputs: [batch_size, seq_len, d_model] 
        """
        residual = inputs
        output = self.fc(inputs)
        return nn.LayerNorm(d_model).cuda()(output + residual)  # [batch_size, seq_len, d_model]

Encoder Layer

# Multi-head self-attention + feed-forward network
class EncoderLayer(nn.Module):
    def __init__(self):
        super(EncoderLayer, self).__init__()
        self.enc_self_attn = MultiHeadAttention()  # naming: encoder self-attention
        self.pos_ffn = PoswiseFeedForwardNet()

    def forward(self, enc_inputs, enc_self_attn_mask):
        """
        enc_inputs: [batch_size, src_len, d_model]
        enc_self_attn_mask: [batch_size, src_len, src_len]
        """
        enc_outputs, attn = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs, enc_self_attn_mask)  # Q, K, V come from the same source
        # enc_outputs: [batch_size, src_len, d_model], attn: [batch_size, n_heads, src_len, src_len]
        enc_outputs = self.pos_ffn(enc_outputs)  # enc_outputs: [batch_size, src_len, d_model]
        return enc_outputs, attn  # enc_outputs has the same shape as enc_inputs

Encoder

# The Encoder has three parts: word embedding, positional encoding, and n_layers stacked EncoderLayers (attention + FFN)
class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        self.src_emb = nn.Embedding(src_vocab_size, d_model)
        self.pos_emb = PositionalEncoding(d_model)
        self.layers = nn.ModuleList([EncoderLayer() for _ in range(n_layers)])  # stack multiple EncoderLayers with ModuleList

    def forward(self, enc_inputs):
        """
        enc_inputs: torch.Size([batch_size, src_len])
        """
        enc_outputs = self.src_emb(enc_inputs)  # [batch_size, src_len, d_model]
        enc_outputs = self.pos_emb(enc_outputs.transpose(0, 1)).transpose(0, 1)  # [batch_size, src_len, d_model]
        # the positional encoding above expects [seq_len, batch_size, d_model], so the first two dims are transposed
        # the positional encoding keeps the shape unchanged
        enc_self_attn_mask = get_attn_pad_mask(enc_inputs, enc_inputs)  # [batch_size, src_len, src_len]

        enc_self_attns = []
        for layer in self.layers:
            enc_outputs, enc_self_attn = layer(enc_outputs, enc_self_attn_mask)
            # enc_outputs: [batch_size, src_len, d_model], enc_self_attn: [batch_size, n_heads, src_len, src_len]
            enc_self_attns.append(enc_self_attn)  # a list of length n_layers
        return enc_outputs, enc_self_attns

Decoder Layer

# Three parts: masked multi-head self-attention + encoder-decoder multi-head attention + FFN
class DecoderLayer(nn.Module):
    def __init__(self):
        super(DecoderLayer, self).__init__()
        self.dec_self_attn = MultiHeadAttention()  # naming: decoder self-attention
        self.dec_enc_attn = MultiHeadAttention()  # naming: decoder-encoder attention
        self.pos_ffn = PoswiseFeedForwardNet()

    def forward(self, dec_inputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask):
        """
        dec_inputs: [batch_size, tgt_len, d_model]
        enc_outputs: [batch_size, src_len, d_model]
        dec_self_attn_mask: [batch_size, tgt_len, tgt_len]
        dec_enc_attn_mask: [batch_size, tgt_len, src_len]
        return: dec_outputs has the same shape as dec_inputs
        """
        dec_outputs, dec_self_attn = self.dec_self_attn(dec_inputs, dec_inputs, dec_inputs, dec_self_attn_mask)  # Q, K, V come from the same source
        # dec_outputs: [batch_size, tgt_len, d_model], dec_self_attn: [batch_size, n_heads, tgt_len, tgt_len]

        dec_outputs, dec_enc_attn = self.dec_enc_attn(dec_outputs, enc_outputs, enc_outputs, dec_enc_attn_mask)
        # Q comes from the decoder's masked self-attention output; K and V come from the output after the 6 encoder layers
        # dec_outputs: [batch_size, tgt_len, d_model], dec_enc_attn: [batch_size, n_heads, tgt_len, src_len]

        dec_outputs = self.pos_ffn(dec_outputs)  # [batch_size, tgt_len, d_model]
        return dec_outputs, dec_self_attn, dec_enc_attn

Decoder

# The Decoder has three parts: word embedding, positional encoding, and n_layers stacked DecoderLayers
class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        self.tgt_emb = nn.Embedding(tgt_vocab_size, d_model)
        self.pos_emb = PositionalEncoding(d_model)
        self.layers = nn.ModuleList([DecoderLayer() for _ in range(n_layers)])

    def forward(self, dec_inputs, enc_inputs, enc_outputs):
        """
        dec_inputs: [batch_size, tgt_len]
        enc_inputs: [batch_size, src_len]
        enc_outputs: [batch_size, src_len, d_model]
        """
        dec_outputs = self.tgt_emb(dec_inputs)  # [batch_size, tgt_len, d_model]
        dec_outputs = self.pos_emb(dec_outputs.transpose(0, 1)).transpose(0, 1).cuda()  # [batch_size, tgt_len, d_model]


        # dec_self_attn_pad_mask: the pad part of the self-attention mask (a bool tensor):
        dec_self_attn_pad_mask = get_attn_pad_mask(dec_inputs, dec_inputs).cuda()  # [batch_size, tgt_len, tgt_len]

        # dec_self_attn_subsequence_mask: the look-ahead part of the self-attention mask, i.e. each position cannot see the words after it; an upper-triangular matrix of ones
        dec_self_attn_subsequence_mask = get_attn_subsequence_mask(dec_inputs).cuda()  # [batch_size, tgt_len, tgt_len]

        # add the two matrices: entries greater than 0 become True, others False; True positions are later filled with a large negative value
        dec_self_attn_mask = torch.gt((dec_self_attn_pad_mask + dec_self_attn_subsequence_mask), 0).cuda()
        # a bool tensor, [batch_size, tgt_len, tgt_len]

        # build the mask for the encoder-decoder (cross) attention
        dec_enc_attn_mask = get_attn_pad_mask(dec_inputs, enc_inputs)  # [batch_size, tgt_len, src_len] = [2, 6, 5]
        # in short: the self-attention layer uses dec_self_attn_mask, and the cross-attention layer uses dec_enc_attn_mask

        dec_self_attns, dec_enc_attns = [], []
        for layer in self.layers:
            dec_outputs, dec_self_attn, dec_enc_attn = layer(dec_outputs, enc_outputs, dec_self_attn_mask,
                                                             dec_enc_attn_mask)
            # dec_outputs: [batch_size, tgt_len, d_model], dec_self_attn: [batch_size, n_heads, tgt_len, tgt_len], dec_enc_attn: [batch_size, n_heads, tgt_len, src_len]
            dec_self_attns.append(dec_self_attn)
            dec_enc_attns.append(dec_enc_attn)
        return dec_outputs, dec_self_attns, dec_enc_attns

Transformer

# Encoder + Decoder + final linear projection
class Transformer(nn.Module):
    def __init__(self):
        super(Transformer, self).__init__()
        self.encoder = Encoder().cuda()
        self.decoder = Decoder().cuda()
        self.projection = nn.Linear(d_model, tgt_vocab_size, bias=False).cuda()

    def forward(self, enc_inputs, dec_inputs):
        """
        enc_inputs: [batch_size, src_len]
        dec_inputs: [batch_size, tgt_len]
        """
        enc_outputs, enc_self_attns = self.encoder(enc_inputs)
        # enc_outputs: [batch_size, src_len, d_model], enc_self_attns: [n_layers, batch_size, n_heads, src_len, src_len]
        dec_outputs, dec_self_attns, dec_enc_attns = self.decoder(dec_inputs, enc_inputs, enc_outputs)
        # dec_outputs: [batch_size, tgt_len, d_model], dec_self_attns: [n_layers, batch_size, n_heads, tgt_len, tgt_len]
        # dec_enc_attn: [n_layers, batch_size, n_heads, tgt_len, src_len]
        dec_logits = self.projection(dec_outputs)  # dec_logits: [batch_size, tgt_len, tgt_vocab_size]
        return dec_logits.view(-1, dec_logits.size(-1)), enc_self_attns, dec_self_attns, dec_enc_attns
        # flattened to (batch_size * tgt_len, tgt_vocab_size) for computing the loss

Model, Loss Function, Optimizer

In the loss function, ignore_index=0 is set because the index of the pad token 'P' is 0. With this setting, the loss at pad positions is ignored (pad carries no meaning, so there is no need to compute a loss for it).
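
A minimal illustration of this behavior (the logits and labels below are hypothetical):

logits = torch.randn(3, 9)                       # 3 target positions, tgt_vocab_size = 9
labels = torch.tensor([4, 8, 0])                 # the last target is the pad index 0
loss_ignore = nn.CrossEntropyLoss(ignore_index=0)(logits, labels)
loss_no_pad = nn.CrossEntropyLoss()(logits[:2], labels[:2])
print(torch.allclose(loss_ignore, loss_no_pad))  # True: the pad position contributes nothing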

model = Transformer().cuda()
criterion = nn.CrossEntropyLoss(ignore_index=0)  # the final softmax is folded into the cross-entropy loss here
optimizer = optim.SGD(model.parameters(), lr=1e-3, momentum=0.99)  # stochastic gradient descent with momentum

Training

for epoch in range(1000):
    for enc_inputs, dec_inputs, dec_outputs in loader:
        # enc_inputs: [batch_size, src_len], a tensor
        # dec_inputs: [batch_size, tgt_len]
        # dec_outputs: [batch_size, tgt_len]

        enc_inputs, dec_inputs, dec_outputs = enc_inputs.cuda(), dec_inputs.cuda(), dec_outputs.cuda()
        outputs, enc_self_attns, dec_self_attns, dec_enc_attns = model(enc_inputs, dec_inputs)
        # outputs: [batch_size * tgt_len, tgt_vocab_size]
        loss = criterion(outputs, dec_outputs.view(-1))  # dec_outputs is flattened to [batch_size * tgt_len]
        print('Epoch:', '%04d' % (epoch + 1), 'loss =', '{:.6f}'.format(loss))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

Output:

Epoch: 0001 loss = 1.058965
Epoch: 0002 loss = 0.938208
Epoch: 0003 loss = 0.738537
Epoch: 0004 loss = 0.628805
Epoch: 0005 loss = 0.472079
Epoch: 0006 loss = 0.394795
......

Testing

Set a breakpoint and step through to observe the prediction process:

# At inference time the target sequence is unknown, so the target input is generated word by word and fed back into the Transformer.
# The decoder starts from start_symbol as its first input.
# Each round's prediction is appended as the input for the next round, until the index of '.' is predicted.
def greedy_decoder(model, enc_input, start_symbol):  # start_symbol = 6 (int)
    """
    :param model: Transformer Model
    :param enc_input: The encoder input [1, src_len] 
    :param start_symbol: The start symbol. In this example it is 'S' which corresponds to index 6
    :return: The target input
    """
    enc_outputs, enc_self_attns = model.encoder(enc_input)
    # after the encoder, enc_input: (1, src_len) -> enc_outputs: (1, src_len, 512)
    dec_input = torch.zeros(1, 0).type_as(enc_input.data)  # tensor([])
    terminal = False
    next_symbol = start_symbol
    while not terminal:  # loop, starting from ["S"], whose vocabulary index is tensor(6)
        dec_input = torch.cat([dec_input.detach(), torch.tensor([[next_symbol]], dtype=enc_input.dtype).cuda()], -1)
        # shape/data: (1,1)/[[6]] -> (1,2)/[[6,1]] -> (1,3)/[[6,1,2]] -> ...
        # the previous round's prediction becomes part of the next round's input
        dec_outputs, _, _ = model.decoder(dec_input, enc_input, enc_outputs)
        # after the decoder, dec_outputs: (1,1,512) -> (1,2,512) -> (1,3,512) -> ...
        projected = model.projection(dec_outputs)  # (1,1,9) -> (1,2,9) -> (1,3,9) -> ...
        prob = projected.squeeze(0).max(dim=-1, keepdim=False)[1]  # argmax along the last dimension, i.e. the index of the predicted word
        # shape/data: (1,)/tensor([1]) -> (2,)/tensor([1,2]) -> (3,)/tensor([1,2,3]) -> ...
        # [1] selects the indices of the maximum values
        next_word = prob.data[-1]  # take the last index in prob: tensor(1) -> tensor(2) -> tensor(3) -> ...
        next_symbol = next_word
        if next_symbol == tgt_vocab["."]:  # stop once '.' (vocabulary index 8) is predicted
            terminal = True
        print(next_word)  # tensor(1) -> (2) -> (3) -> (4) -> (8)
    return dec_input


# Test
enc_inputs, _, _ = next(iter(loader))  # (2,5)
enc_inputs = enc_inputs.cuda()
for i in range(len(enc_inputs)):  # length 2
    greedy_dec_input = greedy_decoder(model, enc_inputs[i].view(1, -1), start_symbol=tgt_vocab["S"])  # e.g. [[6, 1, 2, 3, 4]]
    predict, _, _, _ = model(enc_inputs[i].view(1, -1), greedy_dec_input)  # input shape (1, 5); prediction shape (5, 9)
    predict = predict.data.max(1, keepdim=True)[1]  # argmax indices, shape (5, 1)
    print(enc_inputs[i], '->', [idx2word[n.item()] for n in predict.squeeze()])

Output:

tensor(1, device='cuda:0')
tensor(2, device='cuda:0')
tensor(3, device='cuda:0')
tensor(4, device='cuda:0')
tensor(8, device='cuda:0')
tensor([1, 2, 3, 4, 0], device='cuda:0') -> ['i', 'want', 'a', 'beer', '.']
tensor(1, device='cuda:0')
tensor(2, device='cuda:0')
tensor(3, device='cuda:0')
tensor(5, device='cuda:0')
tensor(8, device='cuda:0')
tensor([1, 2, 3, 5, 0], device='cuda:0') -> ['i', 'want', 'a', 'coke', '.']

The full code is not reproduced here; simply copy the code above (excluding the snippets that only print or plot intermediate results) and it will run.

Reprinted from blog.csdn.net/qq_45670134/article/details/128005237