# Copyright (c) 2023, Baichuan Intelligent Technology. All rights reserved.
from .configuration_baichuan import BaichuanConfig
from .generation_utils import build_chat_input, TextIterStreamer
import math
from threading import Thread
from typing import List, Optional, Tuple, Union
import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from torch.nn import functional as F
from transformers import PreTrainedModel, PretrainedConfig
from transformers.activations import ACT2FN
from transformers.generation.utils import GenerationConfig
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.utils import logging, ContextManagers
import os
from contextlib import contextmanager
from accelerate import init_empty_weights
logger = logging.get_logger(__name__)
try:
from xformers import ops as xops
except ImportError:
xops = None
logger.warning(
"Xformers is not installed correctly. If you want to use memory_efficient_attention to accelerate training use the following command to install Xformers\npip install xformers."
)
def _get_interleave(n):
def _get_interleave_power_of_2(n):
start = 2 ** (-(2 ** -(math.log2(n) - 3)))
ratio = start
return [start * ratio**i for i in range(n)]
if math.log2(n).is_integer():
return _get_interleave_power_of_2(n)
else:
closest_power_of_2 = 2 ** math.floor(math.log2(n))
return (
_get_interleave_power_of_2(closest_power_of_2)
+ _get_interleave(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
)
def _fill_with_neg_inf(t):
"""FP16-compatible function that fills a tensor with -inf."""
return t.float().fill_(float("-inf")).type_as(t)
def _buffered_future_mask(tensor, maxpos, alibi, attn_heads):
_future_mask = torch.triu(_fill_with_neg_inf(torch.zeros([maxpos, maxpos])), 1)
_future_mask = _future_mask.unsqueeze(0) + alibi
new_future_mask = _future_mask.to(tensor)
return new_future_mask[: tensor.shape[0] * attn_heads, :maxpos, :maxpos]
def _gen_alibi_mask(tensor, n_head, max_pos):
slopes = torch.Tensor(_get_interleave(n_head))
position_point = torch.arange(max_pos) - max_pos + 1
position_point = position_point.unsqueeze(0).unsqueeze(0).expand(n_head, -1, -1)
diag = torch.diag(position_point[0])
position_point = position_point - diag.unsqueeze(0).unsqueeze(0).transpose(-1, -2)
alibi = slopes.unsqueeze(1).unsqueeze(1) * position_point
alibi = alibi.view(n_head, 1, max_pos)
alibi_mask = torch.triu(_fill_with_neg_inf(torch.zeros([max_pos, max_pos])), 1)
alibi_mask = alibi_mask.unsqueeze(0) + alibi
return alibi_mask
class RMSNorm(torch.nn.Module):
def __init__(self, hidden_size, epsilon=1e-6):
super().__init__()
self.weight = torch.nn.Parameter(torch.empty(hidden_size))
self.epsilon = epsilon
def forward(self, hidden_states):
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.epsilon)
# convert into half-precision
if self.weight.dtype in [torch.float16, torch.bfloat16]:
hidden_states = hidden_states.to(self.weight.dtype)
return self.weight * hidden_states
class MLP(torch.nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
):
super().__init__()
self.gate_proj = torch.nn.Linear(hidden_size, intermediate_size, bias=False)
self.down_proj = torch.nn.Linear(intermediate_size, hidden_size, bias=False)
self.up_proj = torch.nn.Linear(hidden_size, intermediate_size, bias=False)
self.act_fn = ACT2FN[hidden_act]
def forward(self, x):
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
class BaichuanAttention(torch.nn.Module):
def __init__(self, config: BaichuanConfig):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
self.max_position_embeddings = config.model_max_length
if (self.head_dim * self.num_heads) != self.hidden_size:
raise ValueError(
f"hidden_size {self.hidden_size} is not divisible by num_heads {self.num_heads}"
)
self.W_pack = torch.nn.Linear(
self.hidden_size, 3 * self.hidden_size, bias=False
)
self.o_proj = torch.nn.Linear(
self.num_heads * self.head_dim, self.hidden_size, bias=False
)
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return (
tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
.transpose(1, 2)
.contiguous()
)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
proj = self.W_pack(hidden_states)
proj = (
proj.unflatten(-1, (3, self.hidden_size))
.unsqueeze(0)
.transpose(0, -2)
.squeeze(-2)
)
query_states = (
proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
)
key_states = (
proj[1].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
)
value_states = (
proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
if xops is not None and self.training:
attn_weights = None
# query_states = query_states.transpose(1, 2)
# key_states = key_states.transpose(1, 2)
# value_states = value_states.transpose(1, 2)
# attn_output = xops.memory_efficient_attention(
# query_states, key_states, value_states, attn_bias=attention_mask
# )
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True):
attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask = attention_mask)
attn_output = attn_output.transpose(1, 2)
else:
attn_weights = torch.matmul(
query_states, key_states.transpose(2, 3)
) / math.sqrt(self.head_dim)
if attention_mask is not None:
if q_len == 1: # inference with cache
if len(attention_mask.size()) == 4:
attention_mask = attention_mask[:, :, -1:, :]
else:
attention_mask = attention_mask[:, -1:, :]
attn_weights = attn_weights + attention_mask
attn_weights = torch.max(
attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
)
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
- `from .configuration_baichuan import BaichuanConfig`: imports the `BaichuanConfig` class from the `configuration_baichuan` module in the current package.
- `from .generation_utils import build_chat_input, TextIterStreamer`: imports `build_chat_input` and `TextIterStreamer` from the `generation_utils` module in the current package.
- `import math`: imports Python's built-in math library.
- `from threading import Thread`: imports the `Thread` class from Python's threading library.
- `from typing import List, Optional, Tuple, Union`: imports the `List`, `Optional`, `Tuple`, and `Union` type annotations from Python's typing library.
- `import torch`: imports the PyTorch framework.
- `from torch import nn`: imports the neural network module from PyTorch.
- `from torch.nn import CrossEntropyLoss`: imports the cross-entropy loss function from PyTorch's neural network library.
- `from torch.nn import functional as F`: imports the functional module from PyTorch's neural network library under the alias `F`.
- `from transformers import PreTrainedModel, PretrainedConfig`: imports `PreTrainedModel` and `PretrainedConfig` from the `transformers` library.
- `from transformers.activations import ACT2FN`: imports the activation-function mapping table `ACT2FN` from the `transformers` library.
- `from transformers.generation.utils import GenerationConfig`: imports `GenerationConfig` from the `generation.utils` module of the `transformers` library.
- `from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast`: imports `BaseModelOutputWithPast` and `CausalLMOutputWithPast` from the `modeling_outputs` module of the `transformers` library.
- `from transformers.utils import logging, ContextManagers`: imports `logging` and `ContextManagers` from the `utils` module of the `transformers` library.
- `import os`: imports Python's built-in `os` module for operating-system-related tasks.
- `from contextlib import contextmanager`: imports `contextmanager` from Python's `contextlib` module; it is used to create context managers.
- `from accelerate import init_empty_weights`: imports the `init_empty_weights` function from the `accelerate` library.
- `logger = logging.get_logger(__name__)`: creates a logger object using the logging utilities provided by `transformers`; `__name__` is the name of the current module.
19-26. `try` ... `except ImportError` block: attempts to import the `ops` module from the `xformers` library under the alias `xops`. If the import fails (i.e. `xformers` is not installed correctly), `xops` is set to `None` and a warning is logged.
28-39. `_get_interleave(n)`: defines a helper function `_get_interleave`. Its nested function `_get_interleave_power_of_2` computes and returns a list of length `n` whose elements form a geometric progression starting from a computed start value; these are the per-head ALiBi slopes. The outer function calls the nested one directly when `n` is an integer power of 2; otherwise it takes the slopes for the closest smaller power of 2 and interleaves additional slopes so that the returned list still has `n` elements.
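For intuition, here is a small sanity check of what the helper produces (a sketch assuming the helpers above are in scope; the values were worked out by hand from the formula, not taken from the source):

```python
# For n = 8 heads: start = 2 ** (-(2 ** -(log2(8) - 3))) = 2 ** -1 = 0.5,
# so the slopes form the geometric sequence 2^-1, 2^-2, ..., 2^-8.
print(_get_interleave(8))
# [0.5, 0.25, 0.125, 0.0625, 0.03125, 0.015625, 0.0078125, 0.00390625]

# For a non-power-of-two head count, slopes derived from the two nearest
# powers of two are interleaved, so the list still has exactly n entries.
print(len(_get_interleave(12)))  # 12
```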
41-43. `_fill_with_neg_inf(t)`: defines a helper function `_fill_with_neg_inf` that fills a tensor `t` with negative infinity (performing the fill in float32 for FP16 compatibility) and returns it.
45-50. `_buffered_future_mask(tensor, maxpos, alibi, attn_heads)`: defines a helper function `_buffered_future_mask` that builds a future (causal) mask of the kind used in Transformer self-attention: an upper-triangular matrix of `-inf` values with the ALiBi bias added, moved to the device and dtype of `tensor`, and sliced to the required batch-times-heads and sequence-length dimensions.
52-64. `_gen_alibi_mask(tensor, n_head, max_pos)`: defines a helper function `_gen_alibi_mask` that generates the ALiBi ("attention with linear biases") mask: one `max_pos x max_pos` bias matrix per attention head, combining a head-specific linear position bias with a causal upper-triangular `-inf` mask.
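As a rough illustration (a sketch assuming the helpers above are in scope; the example values were traced by hand from the code, and note that the `tensor` argument is not used inside this particular snippet):

```python
mask = _gen_alibi_mask(torch.empty(0), n_head=4, max_pos=8)
print(mask.shape)   # torch.Size([4, 8, 8]) -- one bias matrix per head

# Row for query position 2 of head 0: a linear bias over the visible keys and
# -inf over future positions (with 4 heads, the slope of head 0 is 0.25).
print(mask[0, 2])
# tensor([0.0000, 0.2500, 0.5000, -inf, -inf, -inf, -inf, -inf])
```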
66-77. Class `RMSNorm`: defines a variant of layer normalization, RMSNorm. It is a neural network module whose core operation is normalizing by the root mean square of the input (the mean of the squared values), followed by scaling with a learned weight; the result is cast back to half precision when the weight is FP16 or BF16.
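A minimal usage sketch (not from the original file; in the full model the weight comes from a checkpoint, so it is initialized to ones here, which reduces the layer to plain RMS normalization):

```python
norm = RMSNorm(hidden_size=8, epsilon=1e-6)
torch.nn.init.ones_(norm.weight)  # weight is created with torch.empty, so initialize it for the demo

x = torch.randn(2, 4, 8)
y = norm(x)

# Each output vector now has (approximately) unit root mean square.
print(y.pow(2).mean(-1).sqrt())   # values close to 1.0
```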
79-93. Class `MLP`: defines a multilayer perceptron (MLP) class, a neural network module consisting of three linear layers and an activation function. In its forward pass, the input `x` first goes through the `gate_proj` layer and the activation function, is then multiplied element-wise by the output of `up_proj`, and finally passes through the `down_proj` layer.
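A short usage sketch (illustrative sizes, not taken from the original file); with `hidden_act="silu"` this is the SwiGLU-style feed-forward block used by LLaMA-like models:

```python
mlp = MLP(hidden_size=8, intermediate_size=32, hidden_act="silu")
x = torch.randn(2, 4, 8)

# down_proj(silu(gate_proj(x)) * up_proj(x)) maps back to hidden_size
out = mlp(x)
print(out.shape)  # torch.Size([2, 4, 8])
```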
This code mainly contains some auxiliary functions and two neural network modules, `RMSNorm` and `MLP`. These components may be used as building blocks in larger Transformer models or other neural networks.
BaichuanAttention:
This is a class named `BaichuanAttention` that defines a self-attention module, similar to the attention mechanism in BERT, the Transformer, and related models. Here is a line-by-line explanation of the code:
- `class BaichuanAttention(torch.nn.Module):` defines a class named `BaichuanAttention` that inherits from `torch.nn.Module`, meaning it is a PyTorch neural network module.
- `def __init__(self, config: BaichuanConfig):` defines the constructor, which accepts a parameter of type `BaichuanConfig`.
- `super().__init__()` calls the parent class's constructor, as is standard when defining a custom layer in PyTorch.
- `self.config = config` saves the incoming configuration as an attribute of the class.
- `self.hidden_size = config.hidden_size` reads `hidden_size` from the configuration and saves it as an attribute of the class.
- `self.num_heads = config.num_attention_heads` reads the number of attention heads from the configuration and saves it as an attribute of the class.
- `self.head_dim = self.hidden_size // self.num_heads` computes the dimension of each attention head and saves it as an attribute of the class.
- `self.max_position_embeddings = config.model_max_length` reads the model's maximum sequence length from the configuration and saves it as an attribute of the class.
9-12. `if (self.head_dim * self.num_heads) != self.hidden_size:` verifies that the hidden size is divisible by the number of attention heads.
13-15. `self.W_pack = torch.nn.Linear(...)` defines a single linear layer that projects the input hidden states into the packed query, key, and value.
16-18. `self.o_proj = torch.nn.Linear(...)` defines the output linear layer applied after the attention mechanism.
19-22. `def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):` defines a helper function that reshapes a given tensor into the layout expected by the attention computation.
23-32. `def forward(...):` defines the forward pass of the module.
`bsz, q_len, _ = hidden_states.size()` reads the batch size, sequence length, and hidden size of the input tensor.
34-38. This part of the code applies the `W_pack` linear transformation to the input hidden states and splits the result into query, key, and value.
39-48. This part reshapes the query, key, and value tensors to fit the attention computation.
49-54. If `past_key_value` is provided, the cached keys and values are concatenated with the current ones.
55-56. If `use_cache` is `True`, the keys and values are saved for the next step.
57-68. Checks whether `xformers` is installed and, depending on that and on training mode, chooses the attention implementation (in this version the training path actually calls PyTorch's `F.scaled_dot_product_attention`; the original `xops.memory_efficient_attention` call is left commented out).
69-81. If the fast path is not used, the conventional scaled dot-product self-attention computation is performed: scores via matmul, masking, softmax, then matmul with the values.
82-83. `attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)` reshapes the attention output.
`attn_output = self.o_proj(attn_output)` passes the attention output through the output linear layer.
85-87. Depending on the value of `output_attentions`, the attention weights are either returned or set to `None`.
`return attn_output, attn_weights, past_key_value` returns the attention output, the attention weights, and the past key-value pairs.
This module implements scaled dot-product self-attention, a key component of the Transformer architecture.
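To make the interface concrete, here is a hedged usage sketch (not from the original repository). Since only three config attributes are read in `__init__`, a `SimpleNamespace` stands in for `BaichuanConfig`, and the module is put in eval mode so the plain matmul/softmax branch runs:

```python
from types import SimpleNamespace

cfg = SimpleNamespace(hidden_size=32, num_attention_heads=4, model_max_length=64)
attn = BaichuanAttention(cfg).eval()  # eval mode -> the non-xformers branch

hidden = torch.randn(1, 5, 32)
out, _, past = attn(hidden, attention_mask=None, use_cache=True)
print(out.shape)                     # torch.Size([1, 5, 32])
print(past[0].shape, past[1].shape)  # cached keys/values: torch.Size([1, 4, 5, 8]) each

# Incremental decoding: feed one new token and reuse the cached keys/values.
next_token = torch.randn(1, 1, 32)
out2, _, past2 = attn(next_token, past_key_value=past, use_cache=True)
print(past2[0].shape)                # torch.Size([1, 4, 6, 8])
```

Note that `W_pack` fuses the separate query/key/value projections found in LLaMA-style code into a single matrix multiply; the `unflatten`/`transpose` sequence at the top of `forward` then splits the packed result back into the three tensors.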