Learning Large Models from Scratch (4): Implementing Alternating Dense and Local Sparse Attention Patterns in Code

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        # Positional encoding matrix of shape (max_len, d_model)
        pe = torch.zeros(max_len, d_model)
        # Position indices (0, 1, 2, ..., max_len-1) as a column vector
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # Divisor terms for the sine and cosine functions
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(np.log(10000.0) / d_model))
        # Apply sine to even indices and cosine to odd indices
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        # Add a leading batch dimension: (1, max_len, d_model), matching the
        # batch-first (batch_size, seq_len, d_model) layout used by the encoder layers
        pe = pe.unsqueeze(0)
        # Register as a buffer so it is saved with the module but not updated during training
        self.register_buffer('pe', pe)

    def forward(self, x):
        # Add the positional encoding to the input embeddings; x: (batch_size, seq_len, d_model)
        return x + self.pe[:, :x.size(1), :]

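As a quick sanity check of the module above (a minimal sketch; the d_model=16 and toy shapes here are illustration values, not from the original post), the encoding can be inspected on its own:

pos_enc = PositionalEncoding(d_model=16, max_len=100)
dummy = torch.zeros(2, 10, 16)   # (batch_size, seq_len, d_model)
encoded = pos_enc(dummy)
print(encoded.shape)        # torch.Size([2, 10, 16]) -- the shape is unchanged
print(encoded[0, 0, :4])    # position 0: sin(0)=0 and cos(0)=1 alternate -> [0., 1., 0., 1.]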

class TransformerEncoderLayerWithAlternatingAttention(nn.Module):
    def __init__(self, d_model, nhead, dim_feedforward, dropout=0.1):
        super(TransformerEncoderLayerWithAlternatingAttention, self).__init__()
        self.d_model = d_model
        self.nhead = nhead
        self.dropout = nn.Dropout(dropout)
        # Feed-forward network: two linear layers with a ReLU activation in between
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        # Layer normalization
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        # Activation function
        self.activation = F.relu

    def dense_attention(self, query, key, value, mask=None):
        # Scaled dot-product attention scores, scaled by sqrt(head_dim)
        scores = torch.matmul(query, key.transpose(-2, -1)) / np.sqrt(self.d_model // self.nhead)
        # Where a mask is provided, set disallowed positions (mask == 0) to -inf
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        # Attention weights
        attn_weights = F.softmax(scores, dim=-1)
        # Weighted sum of the values
        return torch.matmul(attn_weights, value)

    def sparse_attention(self, query, key, value, mask=None, window_size=3):
        # Local (sliding-window) sparse attention: each position attends only to a
        # window of neighbouring positions
        batch_size, nhead, seq_len, head_dim = query.size()
        attn_output = torch.zeros_like(query)
        # Compute local attention for each time step
        for i in range(seq_len):
            start = max(0, i - window_size)
            end = min(seq_len, i + window_size + 1)
            local_query = query[:, :, i:i+1, :]
            local_key = key[:, :, start:end, :]
            local_value = value[:, :, start:end, :]
            # Slice the mask with ellipsis so both 2-D (seq_len, seq_len) and 4-D masks work
            local_mask = None if mask is None else mask[..., i:i+1, start:end]
            # Reuse dense attention on the local window
            local_attn = self.dense_attention(local_query, local_key, local_value, mask=local_mask)
            attn_output[:, :, i:i+1, :] = local_attn
        return attn_output

    def forward(self, src, src_mask=None, attention_type='dense'):
        batch_size, seq_len, _ = src.size()
        head_dim = self.d_model // self.nhead

        # Split the input into heads: (batch_size, nhead, seq_len, head_dim).
        # This simplified layer reuses the input as query, key and value without
        # separate projection matrices.
        query = key = value = src.view(batch_size, seq_len, self.nhead, head_dim).transpose(1, 2)

        # Choose the attention pattern for this layer
        if attention_type == 'dense':
            attn_output = self.dense_attention(query, key, value, mask=src_mask)
        elif attention_type == 'sparse':
            attn_output = self.sparse_attention(query, key, value, mask=src_mask)
        else:
            raise ValueError("Unsupported attention type: choose 'dense' or 'sparse'")

        # Merge the heads back to (batch_size, seq_len, d_model) and apply dropout
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        attn_output = self.dropout(attn_output)

        # Residual connection and layer normalization
        src = self.norm1(src + attn_output)
        # Feed-forward network
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        # Residual connection and layer normalization
        src = self.norm2(src + src2)
        return src

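The layer can also be exercised on its own to confirm that both attention patterns preserve the input shape (a minimal sketch with arbitrary toy dimensions, added for illustration):

layer = TransformerEncoderLayerWithAlternatingAttention(d_model=32, nhead=4, dim_feedforward=64)
x = torch.randn(2, 8, 32)   # (batch_size, seq_len, d_model)
dense_out = layer(x, attention_type='dense')
sparse_out = layer(x, attention_type='sparse')
print(dense_out.shape, sparse_out.shape)   # both torch.Size([2, 8, 32])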

class AlternatingAttentionTransformer(nn.Module):
    def __init__(self, vocab_size, d_model, nhead, num_encoder_layers, dim_feedforward, max_len):
        super(AlternatingAttentionTransformer, self).__init__()
        # Embedding layer: maps token indices to dense vectors
        self.embedding = nn.Embedding(vocab_size, d_model)
        # Positional encoding adds sequence-order information to the embeddings
        self.pos_encoder = PositionalEncoding(d_model, max_len)
        # Stack of encoder layers; dense and sparse attention alternate across layers
        self.encoder_layers = nn.ModuleList([
            TransformerEncoderLayerWithAlternatingAttention(
                d_model, nhead, dim_feedforward, dropout=0.1
            ) for _ in range(num_encoder_layers)
        ])
        self.d_model = d_model
        # Output projection from d_model to the vocabulary size
        self.decoder = nn.Linear(d_model, vocab_size)

    def forward(self, src, src_mask):
        # Embed the source tokens and scale by sqrt(d_model)
        src = self.embedding(src) * np.sqrt(self.d_model)
        # Add positional encodings to the embedded tokens
        src = self.pos_encoder(src)
        # Pass through the encoder layers, alternating attention patterns:
        # even-indexed layers use dense attention, odd-indexed layers use sparse attention
        for i, layer in enumerate(self.encoder_layers):
            attention_type = 'dense' if i % 2 == 0 else 'sparse'
            src = layer(src, src_mask, attention_type=attention_type)
        # Project the output to the vocabulary size
        output = self.decoder(src)
        return output

    def generate_square_subsequent_mask(self, sz):
        # Causal mask: position i may attend only to positions j <= i.
        # dense_attention masks out entries where the mask is 0/False, so return a
        # lower-triangular boolean matrix (True = attention allowed).
        mask = torch.tril(torch.ones(sz, sz, dtype=torch.bool))
        return mask

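generate_square_subsequent_mask returns a lower-triangular boolean matrix in which True marks the positions a token may attend to. As a small illustration (an addition, not part of the original post), the same pattern for a length-5 sequence can be built and inspected directly:

causal_mask = torch.tril(torch.ones(5, 5, dtype=torch.bool))
print(causal_mask.int())   # lower-triangular pattern: row i has ones in columns 0..i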

# Example usage
vocab_size = 10000  # vocabulary size
d_model = 128  # embedding dimension
nhead = 4  # number of attention heads
num_encoder_layers = 6  # number of encoder layers
dim_feedforward = 512  # hidden size of the feed-forward network
max_len = 50  # maximum input sequence length

# Instantiate the model
model = AlternatingAttentionTransformer(vocab_size, d_model, nhead, num_encoder_layers, dim_feedforward, max_len)

# A batch containing one sequence of length 10, shaped (batch_size, seq_len)
input_sequence = torch.randint(0, vocab_size, (1, 10))
src_mask = model.generate_square_subsequent_mask(input_sequence.size(1))

# Forward pass
output = model(input_sequence, src_mask)
print(output.shape)  # expected shape: (batch_size, seq_len, vocab_size) = (1, 10, 10000)
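
The per-position Python loop in sparse_attention is easy to read but slow. The same local-window pattern can also be expressed as a band-shaped mask fed to dense attention; the sketch below is an alternative formulation added for illustration (the helper name local_window_mask is mine, not from the original post), assuming the same convention of window_size positions on each side:

def local_window_mask(seq_len, window_size):
    # True where |i - j| <= window_size, i.e. key j lies in query i's local window
    idx = torch.arange(seq_len)
    return (idx.unsqueeze(0) - idx.unsqueeze(1)).abs() <= window_size

# Combined with the causal mask, this is the same local pattern that
# sparse_attention computes position by position in its loop
band_mask = local_window_mask(10, window_size=3)
combined_mask = band_mask & model.generate_square_subsequent_mask(10)
print(combined_mask.int())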

Reposted from blog.csdn.net/red_guy/article/details/143166423