Liger (LinkedIn GPU Efficient Runtime) Kernel is a collection of Triton kernels designed specifically for LLM training. We have implemented Hugging Face-compatible RMSNorm, RoPE, SwiGLU, CrossEntropy, FusedLinearCrossEntropy, and more to come. It can effectively increase multi-GPU training throughput by 20% and reduce memory usage by 60%. The kernel works out of the box with Flash Attention, PyTorch FSDP, and Microsoft DeepSpeed. We welcome contributions from the community to gather the best kernels for LLM training.
Liger Kernel speeds up training and reduces peak training memory. It builds directly on the transformers package: you import an apply_liger_kernel_to_xxx() function from the liger_kernel.transformers package and use it to replace the corresponding components of the original transformers implementation with versions rewritten and accelerated in triton.
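In practice the patch is applied before the model is instantiated, so that the replaced classes are the ones that actually get constructed. Below is a minimal usage sketch (the checkpoint name is just an example and the training loop is omitted):

import torch
from transformers import AutoModelForCausalLM
from liger_kernel.transformers import apply_liger_kernel_to_llama

# Monkey-patch the Llama modules inside transformers with Liger's Triton kernels.
apply_liger_kernel_to_llama()

# Any Llama model loaded afterwards picks up the patched RMSNorm / SwiGLU / RoPE / loss.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",
    torch_dtype=torch.bfloat16,
)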
As an example, here is the body of apply_liger_kernel_to_llama(): it swaps out the RoPE, RMSNorm, SwiGLU, CrossEntropy, and related modules of the original modeling_xxx.
# Excerpt from Liger Kernel; liger_rotary_pos_emb, LigerRMSNorm, LigerSwiGLUMLP, etc.
# are module-level imports in the original source.
def apply_liger_kernel_to_llama(
    rope: bool = True,
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = True,
    rms_norm: bool = True,
    swiglu: bool = True,
    model: PreTrainedModel = None,
) -> None:
    """
    Apply Liger kernels to replace original implementation in HuggingFace Llama models (2 and 3)

    Args:
        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
        fused_linear_cross_entropy (bool):
            Whether to apply Liger's fused linear cross entropy loss. Default is True.
            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
            If `fused_linear_cross_entropy` is True, the logits will not be materialized,
            which is more memory efficient.
        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
            loaded. Default is None.
    """
    assert not (
        cross_entropy and fused_linear_cross_entropy
    ), "cross_entropy and fused_linear_cross_entropy cannot both be True."

    from transformers.models.llama import modeling_llama

    if rope:
        modeling_llama.apply_rotary_pos_emb = liger_rotary_pos_emb
    if rms_norm:
        modeling_llama.LlamaRMSNorm = LigerRMSNorm
    if swiglu:
        modeling_llama.LlamaMLP = LigerSwiGLUMLP
    if cross_entropy:
        modeling_llama.CrossEntropyLoss = LigerCrossEntropyLoss
    if fused_linear_cross_entropy:
        modeling_llama.LlamaForCausalLM.forward = llama_lce_forward

    if model is not None:
        # The model instance already exists, so we need to additionally patch the
        # instance variables that reference already-instantiated modules (e.g. LlamaRMSNorm or LlamaMLP)
        config: PretrainedConfig = model.config

        if hasattr(model, "model"):
            # The case for LlamaForCausalLM or LlamaForSequenceClassification, for example
            base_model = model.model
        elif hasattr(model, "transformer"):
            # LlamaForQuestionAnswering uses "transformer" instead of "model"
            base_model = model.transformer
        else:
            # Direct LlamaModel
            base_model = model

        torch_dtype = config.torch_dtype

        if rms_norm:
            base_model.norm = LigerRMSNorm(
                config.hidden_size, eps=config.rms_norm_eps
            ).to(torch_dtype)

        for decoder_layer in base_model.layers:
            if swiglu:
                decoder_layer.mlp = LigerSwiGLUMLP(config).to(torch_dtype)
            if rms_norm:
                decoder_layer.input_layernorm = LigerRMSNorm(
                    config.hidden_size, eps=config.rms_norm_eps
                ).to(torch_dtype)
                decoder_layer.post_attention_layernorm = LigerRMSNorm(
                    config.hidden_size, eps=config.rms_norm_eps
                ).to(torch_dtype)
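The model argument covers the case where the model has already been loaded before patching; the `model is not None` branch above then swaps the already-instantiated norm and MLP submodules in place. A minimal sketch of that second path:

from transformers import AutoModelForCausalLM
from liger_kernel.transformers import apply_liger_kernel_to_llama

model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B")

# Patching after loading: in addition to patching the classes, the function
# replaces the existing LlamaRMSNorm / LlamaMLP instances on this model object.
apply_liger_kernel_to_llama(model=model)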
So where exactly does the speedup come from in the patched model? Let's dig into the MLP as an example: Liger's MLP implementation leaves __init__ unchanged, and in forward it applies its own custom op, which optimizes the original multiplication of the gate_proj and up_proj branches (a sketch of the idea follows below).
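Conceptually, the patched MLP looks roughly like the sketch below. This is a simplified rewrite rather than the verbatim Liger class; the gate_proj / up_proj / down_proj layout follows HuggingFace's LlamaMLP, and LigerSiLUMulFunction is the autograd op shown further down.

import torch.nn as nn

class SwiGLUMLPSketch(nn.Module):
    # Same projection layout as HuggingFace's LlamaMLP.
    def __init__(self, config):
        super().__init__()
        self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)

    def forward(self, x):
        # Instead of computing silu(gate_proj(x)) and then multiplying by up_proj(x)
        # as two separate elementwise passes, the fused op computes silu(a) * b in a
        # single Triton kernel launch, avoiding an extra intermediate tensor.
        return self.down_proj(
            LigerSiLUMulFunction.apply(self.gate_proj(x), self.up_proj(x))
        )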
Inside the custom op, triton is imported and the forward and backward passes are rewritten by hand:
class LigerSiLUMulFunction(torch.autograd.Function):
    @staticmethod
    @ensure_contiguous  # makes sure inputs are contiguous before the kernel launch
    def forward(ctx, a, b):
        # a is the gate_proj output, b is the up_proj output; swiglu_forward
        # launches the Triton kernel that computes c = silu(a) * b in one pass.
        a, b, c = swiglu_forward(a, b)
        ctx.save_for_backward(a, b)
        return c

    @staticmethod
    @ensure_contiguous
    def backward(ctx, dc):
        a, b = ctx.saved_tensors
        # Given the upstream gradient dc, swiglu_backward returns the gradients
        # with respect to a and b (only a and b were saved, so silu(a) is recomputed).
        a, b = swiglu_backward(a, b, dc)
        return a, b
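For reference, a fused elementwise kernel of this kind only takes a few lines of Triton. The sketch below shows the general shape of what a swiglu_forward-style launcher could do; it is not the actual Liger kernel, and the names _silu_mul_fwd_kernel / silu_mul_forward are made up for illustration.

import torch
import triton
import triton.language as tl

@triton.jit
def _silu_mul_fwd_kernel(a_ptr, b_ptr, c_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    # Load both inputs, compute silu(a) * b in registers, and write the result back
    # in a single pass; no intermediate silu(a) tensor ever touches global memory.
    a = tl.load(a_ptr + offsets, mask=mask).to(tl.float32)
    b = tl.load(b_ptr + offsets, mask=mask).to(tl.float32)
    c = a * tl.sigmoid(a) * b
    tl.store(c_ptr + offsets, c, mask=mask)

def silu_mul_forward(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # One program per BLOCK_SIZE chunk of the flattened tensors.
    c = torch.empty_like(a)
    n_elements = c.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    _silu_mul_fwd_kernel[grid](a, b, c, n_elements, BLOCK_SIZE=1024)
    return c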