Baichuan2 source code analysis: Baichuan2-13B-Chat/modeling_baichuan.py

# Copyright (c) 2023, Baichuan Intelligent Technology. All rights reserved.

from .configuration_baichuan import BaichuanConfig
from .generation_utils import build_chat_input, TextIterStreamer

import math
from threading import Thread
from typing import List, Optional, Tuple, Union

import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from torch.nn import functional as F
from transformers import PreTrainedModel, PretrainedConfig
from transformers.activations import ACT2FN
from transformers.generation.utils import GenerationConfig
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.utils import logging, ContextManagers

import os
from contextlib import contextmanager
from accelerate import init_empty_weights

logger = logging.get_logger(__name__)

try:
    from xformers import ops as xops
except ImportError:
    xops = None
    logger.warning(
        "Xformers is not installed correctly. If you want to use memory_efficient_attention to accelerate training use the following command to install Xformers\npip install xformers."
    )


def _get_interleave(n):
    def _get_interleave_power_of_2(n):
        start = 2 ** (-(2 ** -(math.log2(n) - 3)))
        ratio = start
        return [start * ratio**i for i in range(n)]

    if math.log2(n).is_integer():
        return _get_interleave_power_of_2(n)
    else:
        closest_power_of_2 = 2 ** math.floor(math.log2(n))
        return (
            _get_interleave_power_of_2(closest_power_of_2)
            + _get_interleave(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
        )


def _fill_with_neg_inf(t):
    """FP16-compatible function that fills a tensor with -inf."""
    return t.float().fill_(float("-inf")).type_as(t)


def _buffered_future_mask(tensor, maxpos, alibi, attn_heads):
    _future_mask = torch.triu(_fill_with_neg_inf(torch.zeros([maxpos, maxpos])), 1)
    _future_mask = _future_mask.unsqueeze(0) + alibi
    new_future_mask = _future_mask.to(tensor)
    return new_future_mask[: tensor.shape[0] * attn_heads, :maxpos, :maxpos]


def _gen_alibi_mask(tensor, n_head, max_pos):
    slopes = torch.Tensor(_get_interleave(n_head))
    position_point = torch.arange(max_pos) - max_pos + 1
    position_point = position_point.unsqueeze(0).unsqueeze(0).expand(n_head, -1, -1)
    diag = torch.diag(position_point[0])
    position_point = position_point - diag.unsqueeze(0).unsqueeze(0).transpose(-1, -2)
    alibi = slopes.unsqueeze(1).unsqueeze(1) * position_point
    alibi = alibi.view(n_head, 1, max_pos)
    alibi_mask = torch.triu(_fill_with_neg_inf(torch.zeros([max_pos, max_pos])), 1)
    alibi_mask = alibi_mask.unsqueeze(0) + alibi
    return alibi_mask


class RMSNorm(torch.nn.Module):
    def __init__(self, hidden_size, epsilon=1e-6):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.empty(hidden_size))
        self.epsilon = epsilon

    def forward(self, hidden_states):
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.epsilon)

        # convert into half-precision
        if self.weight.dtype in [torch.float16, torch.bfloat16]:
            hidden_states = hidden_states.to(self.weight.dtype)

        return self.weight * hidden_states


class MLP(torch.nn.Module):
    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
    ):
        super().__init__()
        self.gate_proj = torch.nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = torch.nn.Linear(intermediate_size, hidden_size, bias=False)
        self.up_proj = torch.nn.Linear(hidden_size, intermediate_size, bias=False)
        self.act_fn = ACT2FN[hidden_act]

    def forward(self, x):
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))


class BaichuanAttention(torch.nn.Module):
    def __init__(self, config: BaichuanConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.max_position_embeddings = config.model_max_length

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size {self.hidden_size} is not divisible by num_heads {self.num_heads}"
            )
        self.W_pack = torch.nn.Linear(
            self.hidden_size, 3 * self.hidden_size, bias=False
        )
        self.o_proj = torch.nn.Linear(
            self.num_heads * self.head_dim, self.hidden_size, bias=False
        )

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return (
            tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
            .transpose(1, 2)
            .contiguous()
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()

        proj = self.W_pack(hidden_states)
        proj = (
            proj.unflatten(-1, (3, self.hidden_size))
            .unsqueeze(0)
            .transpose(0, -2)
            .squeeze(-2)
        )
        query_states = (
            proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        )
        key_states = (
            proj[1].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        )
        value_states = (
            proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        )

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]

        if past_key_value is not None:
            # reuse k, v, self_attention
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        past_key_value = (key_states, value_states) if use_cache else None
        if xops is not None and self.training:
            attn_weights = None
            # query_states = query_states.transpose(1, 2)
            # key_states = key_states.transpose(1, 2)
            # value_states = value_states.transpose(1, 2)
            # attn_output = xops.memory_efficient_attention(
            #     query_states, key_states, value_states, attn_bias=attention_mask
            # )
            with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True):
                attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask = attention_mask)
            attn_output = attn_output.transpose(1, 2)
        else:
            attn_weights = torch.matmul(
                query_states, key_states.transpose(2, 3)
            ) / math.sqrt(self.head_dim)

            if attention_mask is not None:
                if q_len == 1:  # inference with cache
                    if len(attention_mask.size()) == 4:
                        attention_mask = attention_mask[:, :, -1:, :]
                    else:
                        attention_mask = attention_mask[:, -1:, :]
                attn_weights = attn_weights + attention_mask
                attn_weights = torch.max(
                    attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
                )

            attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
            attn_output = torch.matmul(attn_weights, value_states)

            attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value

Line-by-line explanation:

  1. from .configuration_baichuan import BaichuanConfig

    Import the BaichuanConfig class from the configuration_baichuan module in the current package.

  2. from .generation_utils import build_chat_input, TextIterStreamer

    Import build_chat_input and TextIterStreamer from the generation_utils module in the current package.

  3. import math

    Import Python's built-in math module.

  4. from threading import Thread

    Import the Thread class from Python's threading library.

  5. from typing import List, Optional, Tuple, Union

    Import List, Optional, Tuple, and Union from Python's typing module for type annotations.

  6. import torch

    Import the PyTorch framework.

  7. from torch import nn

    Import the neural network library from PyTorch.

  8. from torch.nn import CrossEntropyLoss

    Import the cross-entropy loss function from PyTorch's neural network library.

  9. from torch.nn import functional as F

    Import the function module from PyTorch's neural network library and give it an alias F.

  10. from transformers import PreTrainedModel, PretrainedConfig

    Import PreTrainedModel and PretrainedConfig from the transformers library.

  11. from transformers.activations import ACT2FN

    Import the activation function mapping table ACT2FN from the transformers library.

  12. from transformers.generation.utils import GenerationConfig

    Import GenerationConfig from the transformers.generation.utils module.

  13. from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast

    Import BaseModelOutputWithPast and CausalLMOutputWithPast from the transformers.modeling_outputs module.

  14. from transformers.utils import logging, ContextManagers

    Import logging and ContextManagers from the transformers.utils module.

  15. import os

    Import Python's built-in os module to handle operating system-related tasks.

  16. from contextlib import contextmanager

    Import contextmanager from Python's contextlib module; it is used to create context managers.

  17. from accelerate import init_empty_weights

    Import the init_empty_weights function from the accelerate library.

  18. logger = logging.get_logger(__name__)

    Create a logger object using the logging utilities provided by transformers; __name__ is the name of the current module.

19-26. try ... except ImportError block:

    Attempt to import the ops module from the xformers library and alias it as xops. If the import fails (i.e., xformers is not installed correctly), set xops to None and emit a warning message.

28-39. _get_interleave(n):

    Define a helper function _get_interleave. It contains a nested function _get_interleave_power_of_2, which computes and returns a list of length n whose elements form a geometric sequence starting from a computed start value. The outer function _get_interleave calls the nested one directly when n is an integer power of 2, and otherwise builds the list from the closest power of 2. These values serve as the per-head slopes of the ALiBi mask (see the sketch below).
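
For intuition, here is a minimal sketch of the returned values, assuming _get_interleave from the listing above is in scope (for example, by running the snippet inside the same file):

    # For n = 8 heads (a power of two), the start value is
    # 2 ** (-(2 ** -(log2(8) - 3))) = 0.5 and the ratio equals the start,
    # so each head receives a progressively smaller ALiBi slope.
    slopes = _get_interleave(8)
    print(slopes)
    # [0.5, 0.25, 0.125, 0.0625, 0.03125, 0.015625, 0.0078125, 0.00390625]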

41-43. _fill_with_neg_inf(t):

    Define a helper function _fill_with_neg_inf, which takes a tensor t, fills it with negative infinity, and returns it.

45-50. _buffered_future_mask(tensor, maxpos, alibi, attn_heads):

    Define a helper function _buffered_future_mask, which generates a future (causal) mask, commonly used in the self-attention mechanism of Transformers.

52-64. _gen_alibi_mask(tensor, n_head, max_pos):

    Define a helper function _gen_alibi_mask, which generates an "ALiBi" mask (see the sketch below).
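
As a rough sketch of the mask it produces, assuming the helper functions above are in scope (the tensor argument is unused by this function, so a dummy tensor is passed):

    import torch

    # For each head h, entry (i, j) is slope_h * j when j <= i and -inf when
    # j > i, i.e. a causal mask combined with a per-head linear ALiBi bias.
    mask = _gen_alibi_mask(torch.zeros(1), n_head=4, max_pos=5)
    print(mask.shape)  # torch.Size([4, 5, 5])
    print(mask[0])     # head 0: row i holds [0, s, 2*s, ..., i*s, -inf, ...]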

66-77. Class RMSNorm:

    Define a variant of layer normalization: RMSNorm. This is a neural network module whose core operation is normalizing by the mean of the squared values (the root mean square), without subtracting the mean.
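
A minimal usage sketch, assuming the RMSNorm class above is in scope; note that the module allocates its weight with torch.empty, so it is filled in here purely for the demo (in the real model the weight is loaded from the checkpoint):

    import torch

    norm = RMSNorm(hidden_size=8, epsilon=1e-6)
    with torch.no_grad():
        norm.weight.fill_(1.0)          # demo only; normally loaded from the checkpoint

    x = torch.randn(2, 4, 8)            # (batch, seq_len, hidden_size)
    y = norm(x)                         # y = weight * x / sqrt(mean(x**2) + eps)
    print(y.shape)                      # torch.Size([2, 4, 8])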

79-93. Class MLP:

    Define a multilayer perceptron (MLP) class, a neural network module with three linear layers and an activation function. In its forward pass, the input x first goes through the gate_proj layer and the activation function, is then multiplied element-wise by the output of up_proj, and finally passes through the down_proj layer.
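
A minimal usage sketch of the gated MLP, assuming the MLP class above is in scope; the sizes are illustrative only, and "silu" is the activation used by the Baichuan2 configs:

    import torch

    mlp = MLP(hidden_size=8, intermediate_size=32, hidden_act="silu")

    x = torch.randn(2, 4, 8)            # (batch, seq_len, hidden_size)
    y = mlp(x)                          # down_proj(silu(gate_proj(x)) * up_proj(x))
    print(y.shape)                      # torch.Size([2, 4, 8])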

This part of the code mainly contains some helper functions and two neural network modules: RMSNorm and MLP. These components may be used in larger Transformer models or other neural network models.

BaichuanAttention:

This is a class named BaichuanAttention, which defines a self-attention module similar to the attention mechanism in BERT, the Transformer, and other models. Here is a line-by-line explanation of the code:

  1. class BaichuanAttention(torch.nn.Module): Define a class named BaichuanAttention that inherits from torch.nn.Module, meaning this is a PyTorch neural network module.

  2. def __init__(self, config: BaichuanConfig): Define a constructor that accepts a parameter of type BaichuanConfig.

  3. super().__init__() Call the parent class's constructor, which is standard practice when defining your own layer in PyTorch.

  4. self.config = config Save the incoming configuration as an attribute of the class.

  5. self.hidden_size = config.hidden_size Get hidden_size from the configuration and save it as an attribute of the class.

  6. self.num_heads = config.num_attention_heads Get the number of attention heads from the configuration and save it as an attribute of the class.

  7. self.head_dim = self.hidden_size // self.num_heads Compute the dimension of each attention head and save it as an attribute of the class.

  8. self.max_position_embeddings = config.model_max_length Get the model's maximum sequence length from the configuration and save it as an attribute of the class.

9-12. if (self.head_dim * self.num_heads) != self.hidden_size: Check that the hidden size is divisible by the number of attention heads, and raise a ValueError if it is not.

13-15. self.W_pack = torch.nn.Linear(...) Define a single packed linear layer that projects the input hidden states into the query, key, and value.

16-18. self.o_proj = torch.nn.Linear(...) Define a linear layer for the output projection applied after the attention computation.

19-22. def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): Define a helper method that reshapes the given tensor to fit the attention computation.

23-32. def forward(...): Define the forward pass of the module.

33. bsz, q_len, _ = hidden_states.size() Get the batch size and sequence length of the input tensor (the hidden size is discarded).

34-38. This part of the code linearly transforms the input hidden state to obtain the query, key, and value.

39-48. This part of the code reshapes the query, key, and value tensors to fit the attention computation (see the shape sketch below).
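
A shape-only sketch of the unflatten/transpose trick above, with illustrative sizes (not the real model dimensions):

    import torch

    bsz, q_len, hidden_size, num_heads = 2, 5, 8, 4
    head_dim = hidden_size // num_heads

    proj = torch.randn(bsz, q_len, 3 * hidden_size)     # stands in for W_pack's output
    proj = (
        proj.unflatten(-1, (3, hidden_size))            # (bsz, q_len, 3, hidden)
        .unsqueeze(0)                                   # (1, bsz, q_len, 3, hidden)
        .transpose(0, -2)                               # (3, bsz, q_len, 1, hidden)
        .squeeze(-2)                                    # (3, bsz, q_len, hidden)
    )
    query = proj[0].view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
    print(query.shape)  # torch.Size([2, 4, 5, 2]) = (bsz, num_heads, q_len, head_dim)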

49-54. If past_key_value is provided, concatenate the cached keys and values with the current ones along the sequence dimension (see the sketch below).
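
A quick sketch of the cache concatenation (shapes only; the head count and head dimension are illustrative):

    import torch

    bsz, num_heads, head_dim = 1, 4, 2
    past_key = torch.randn(bsz, num_heads, 7, head_dim)   # 7 cached positions
    new_key = torch.randn(bsz, num_heads, 1, head_dim)    # 1 newly generated token
    key_states = torch.cat([past_key, new_key], dim=2)
    print(key_states.shape)  # torch.Size([1, 4, 8, 2]); kv_seq_len grows to 8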

55-56. If use_cache is True, save the current keys and values as past_key_value.

57-68. Check whether xformers is installed (and whether the module is in training mode) and use a different attention computation path accordingly.

69-81. If the accelerated path is not taken, the conventional scaled dot-product self-attention computation is used (see the sketch below).
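
A small numerical sketch (illustrative shapes, not the model's) showing that this manual path, softmax(Q K^T / sqrt(head_dim) + mask) V, matches PyTorch's F.scaled_dot_product_attention when both are given the same additive mask:

    import math
    import torch
    import torch.nn.functional as F

    bsz, num_heads, q_len, head_dim = 1, 2, 4, 8
    q = torch.randn(bsz, num_heads, q_len, head_dim)
    k = torch.randn(bsz, num_heads, q_len, head_dim)
    v = torch.randn(bsz, num_heads, q_len, head_dim)

    # additive mask: 0 on and below the diagonal, -inf above (like the ALiBi mask)
    mask = torch.triu(torch.full((q_len, q_len), float("-inf")), diagonal=1)

    weights = torch.softmax(q @ k.transpose(2, 3) / math.sqrt(head_dim) + mask, dim=-1)
    manual = weights @ v

    sdpa = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
    print(torch.allclose(manual, sdpa, atol=1e-5))  # True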

82-83. attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) Reshape the attention output back to (batch, seq_len, hidden_size).

84. attn_output = self.o_proj(attn_output) Pass the attention output through the output projection layer.

85-87. Depending on the value of output_attentions, the attention weights are either returned or set to None.

88. return attn_output, attn_weights, past_key_value Return the attention output, the attention weights, and the past key-value pairs.

This module implements scaled dot-product self-attention, a key component of the Transformer architecture.
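
Finally, a hedged end-to-end sketch of the module with toy sizes; a SimpleNamespace stands in for BaichuanConfig here and supplies only the three fields read by __init__ (hidden_size, num_attention_heads, model_max_length):

    import torch
    from types import SimpleNamespace

    config = SimpleNamespace(hidden_size=32, num_attention_heads=4, model_max_length=64)
    attn = BaichuanAttention(config)
    attn.eval()                                  # take the non-xformers branch

    hidden = torch.randn(2, 6, 32)               # (batch, seq_len, hidden_size)
    out, attn_weights, past_kv = attn(hidden, use_cache=True)
    print(out.shape)                             # torch.Size([2, 6, 32])
    print(past_kv[0].shape)                      # keys: torch.Size([2, 4, 6, 8])
    print(attn_weights)                          # None (output_attentions=False)

Since the linear layers here are randomly initialized, this demonstrates only the interface and output shapes, not real model behavior.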
