O código-fonte para alcançar mecanismo Máscara Transformer

mecanismo de máscara princípio de que, no lado do descodificador, a previsão é baseada em informações codificador ea palavra previsto, e em fase de codificador, Self_Attention não têm esse mecanismo, é essencialmente uma máscara para a atenção dele, então olhamos para a Atenção implementação:

def attention(query, key, value, mask=None, dropout=None):
    "Compute 'Scaled Dot Product Attention'"
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1))  / math.sqrt(d_k)
    # 这里是对应公式的  Q* K的转秩矩阵
    """
    Queries张量,形状为[B, H, L_q, D_q]
    Keys张量,形状为[B, H, L_k, D_k]
    Values张量,形状为[B, H, L_v, D_v],一般来说就是k
    """
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    p_attn = F.softmax(scores, dim = -1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    return torch.matmul(p_attn, value), p_attn

Sabemos que em formação, que são unidades batch_size, em seguida, haverá preenchimento, geralmente tomamos o pad == 0, em seguida, ele fará com que a atenção do tempo, consulta o valor é 0, consulta um valor de 0 , o valor pontuação correspondente de nossos cálculos é 0, ele provavelmente vai levar a softmax atribuído à palavra não é uma proporção relativamente pequena, portanto, vamos preencher o valor de pontuação correspondente é infinito negativo , a fim de reduzir afectar pequeno bloco que está no acima scores = scores.masked_fill(mask == 0, -1e9)significativo. de modo que pode facilmente imaginar, no descodificador, não previsto palavra é adicionada ao lote com o enchimento da forma, de modo que o mecanismo de máscara quando o mecanismo de máscara e preenchimento usado é o mesmo, é essencialmente consulta é 0, mas uma matriz máscara diferente, podemos encontrar nesta parte do decodificador de acordo com o código.

class DecoderLayer(nn.Module):
    "Decoder is made of self-attn, src-attn, and feed forward (defined below)"
    def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
        super(DecoderLayer, self).__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(size, dropout), 3)

    def forward(self, x, memory, src_mask, tgt_mask):
        "Follow Figure 1 (right) for connections."
        m = memory
        # 对源语言与目标语言的 mask 机制
        x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask))
        # Self_Attention 机制, 是针对目标语言的, 因此需要引入 tgt_mask, 这个mask 矩阵是由已预测出的单词构成的, 
        x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask))
        # 这个是对 encoder 的结果的 Attention, 由于 encoder 阶段有 Padding, 所以这个 mask 矩阵和 encoder 阶段的mask 矩阵是一样的
        return self.sublayer[2](x, self.feed_forward)

Em seguida vamos dar uma olhada retrospectiva, mascarar aqui é como é que, finalmente construído módulo é Encoder_Decoder,

class EncoderDecoder(nn.Module):
    """
    A standard Encoder-Decoder architecture. Base for this and many
    other models.
    """
    def __init__(self, encoder, decoder, src_embed, tgt_embed, generator):
        super(EncoderDecoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        # 将源语言的单词 embedding 放在一起, position embedding
        self.tgt_embed = tgt_embed
        # 将目标语言的单词 embedding 放在一起, position embedding
        self.generator = generator
        # 就是最后产生结果的地方

    def forward(self, src, tgt, src_mask, tgt_mask):
        "Take in and process masked src and target sequences."
        return self.decode(self.encode(src, src_mask), src_mask, tgt, tgt_mask)

    def encode(self, src, src_mask):
        return self.encoder(self.src_embed(src), src_mask)

    def decode(self, memory, src_mask, tgt, tgt_mask):
        return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)

Quando treinado, usando model.forward, nesta parte:

def run_epoch(args, data_iter, model, loss_compute, valid_params=None, epoch_num=0,
              is_valid=False, is_test=False, logger=None):
    "Standard Training and Logging Function"
    start = time.time()
    total_tokens = 0
    total_loss = 0
    tokens = 0
    if valid_params is not None:
        src_dict, tgt_dict, valid_iter = valid_params
        hist_valid_scores = []

    bleu_all = 0
    count_all = 0

    for i, batch in enumerate(data_iter):
        model.train()

        out = model.forward(batch.src, batch.trg ,batch.src_mask, batch.trg_mask)
        # 参数来自 batch
        loss = loss_compute(out, batch.trg_y, batch.ntokens)
        # 这一步既计算了损失, 又更新了参数
        total_loss += loss
        total_tokens += batch.ntokens
        tokens += batch.ntokens

Estes são a etapa de treinamento, os dados é como é que, a matriz máscara do lote, por isso o mais crítico é o lote é como é que, recuando até achar em função train.py, encontramos

 _, logger_file = train_utils.run_epoch(args, (train_utils.rebatch(pad_idx, b) for b in train_iter),
                                  model_parallel if args.multi_gpu else model, train_loss_fn,
                                  valid_params=valid_params,
                                  epoch_num=epoch, logger=logger_file)

lote de função rebatch e iteradores de treinamento de dados, este train_iter foi baseada em torchtext, não entrar aqui, assim que a chave é seguinte função rebatch,

def rebatch(pad_idx, batch):
    "Fix order in torchtext"
    src, trg = batch.src.transpose(0, 1), batch.trg.transpose(0, 1)
    # 读的数据是 sequence * batch_size 的吗, 是在torchtext 中的Filed 决定的
    # 所以需要转换为 bacth * sequence
    return Batch(src, trg, pad_idx)

Finalmente encontrou classe de lote, o máximo de informação crítica a partir daqui:

class Batch:
    "Object for holding a batch of data with mask during training."

    def __init__(self, src, trg=None, pad=0):
        self.src = src
        self.src_mask = (src != pad).unsqueeze(-2)
        # 在预测的时候是没有 tgt 的,此时为 None
        if trg is not None:
            self.trg = trg[:, :-1]
            # 每次迭代的时候, 去掉最后一个单词
            self.trg_y = trg[:, 1:]
            # 去掉第一个单词
            self.trg_mask = self.make_std_mask(self.trg, pad)
            self.ntokens = (self.trg_y != pad).sum().item()
            # target 语言中单词的个数

    @staticmethod
    def make_std_mask(tgt, pad):
        "Create a mask to hide padding and future words."
        tgt_mask = (tgt != pad).unsqueeze(-2)
        tgt_mask = tgt_mask & transformer.subsequent_mask(tgt.size(-1)).type_as(tgt_mask)
        # tgt.size(-1) 表示的是序列的长度
        return tgt_mask

No lote classe, trg para Nenhum quando bem compreendido, ou seja, prever quando a língua-alvo não é, de fato, no momento previsto, Batch única entrada, em seguida, processo de previsão Máscara atenção e como alcançá-lo ? vamos colocar isso de volta novamente, olhada src_mask aqui, mascarar o idioma de origem, isto é, quando a máscara quando codificador self_Attention, isso é bem compreendido, é tornar-se um não-zero de números, para se obter uma matriz 0/1 , self.trg = trg[:, :-1]a última palavra aqui para se livrar, nem uma palavra real, mas sinais '<eos>', entrada e saída são ainda um '<sos>' no início de uma frase, self.trg_y = trg[:, 1:]remover o início torna-se resultado final. acesso para baixo é a matriz máscara alvo de aquisição mais linguagem crítica,

def subsequent_mask(size):
    "Mask out subsequent positions."
    attn_shape = (1, size, size)
    subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')
    return torch.from_numpy(subsequent_mask) == 0

Esta função fez Shane?

Primeiro, escreveu o seguinte:

def subsequentmask(size):
    "Mask out subsequent positions."
    attn_shape = (1, size, size)
    subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')
    return subsequent_mask == 0

print(subsequentmask(5))

>>

[[[ True False False False False]
  [ True  True False False False]
  [ True  True  True False False]
  [ True  True  True  True False]
  [ True  True  True  True  True]]]

Quando esta matriz em tensor numpy quando a configuração é a dimensão (1, 5, 5) de matriz, um meio de uma frase

Acho que você gosta

Origin www.cnblogs.com/wevolf/p/12484972.html
Recomendado
Clasificación