低秩分解(Low-rank Decomposition)是一种矩阵分解技术,旨在将一个矩阵分解为两个或多个矩阵的乘积,从而将高维数据压缩为低维表示,以减少参数量。在大模型中,权重矩阵通常非常大,尤其是在全连接层和注意力机制中,运用低秩分解可以减少参数量和计算量。
权重矩阵 W W W 的大小为 m × n m \times n m×n。通过低秩分解,可以将 W W W 分解为两个较小的矩阵 W 1 , W 2 W_1, W_2 W1,W2 的乘积:
W ≈ W 1 × W 2 W \approx W_1 \times W_2 W≈W1×W2
其中 W 1 W_1 W1 的大小为 m × k m \times k m×k, W 2 W_2 W2 的大小为 k × n k \times n k×n, k k k 是矩阵 W W W 秩(rank),通常 k k k 远小于 m i n ( m , n ) min(m, n) min(m,n)。
常见的低秩分解方法包括**奇异值分解(SVD)**和矩阵分解(如CP分解、Tucker分解、BTD 分解, https://zhuanlan.zhihu.com/p/490455377)。
1 低秩分解压缩模型
基本步骤:
- 选择需要分解的权重矩阵:通常选择全连接层和注意力机制中的权重矩阵进行低秩分解。
- 确定秩 k k k:选择合适的秩,通常通过实验或根据问题的具体要求来确定。
- 分解矩阵:使用 SVD 或其他低秩分解方法对权重矩阵进行分解。
- 重建权重矩阵:根据分解结果重建低秩近似矩阵。
- 微调模型:在低秩分解后,通常需要对模型进行微调,以恢复模型的性能。
2 实验
使用 pytorch 对 BERT 模型中的全连接层进行低秩分解,达到压缩模型的目的。
2.1 安装依赖
pip install transformers torch
2.2 加载预训练的BERT模型
import torch
from transformers import BertModel, BertConfig
# 加载预训练的BERT模型
model_name = 'bert-base-uncased'
config = BertConfig.from_pretrained(model_name)
model = BertModel.from_pretrained(model_name, config=config)
# 打印模型结构,找到需要分解的全连接层
print(model)
2.3 选择需要低秩分解的全连接层
对BERT的第一个注意力头中的线性变换层进行低秩分解:
# 选择第一个注意力头中的线性变换层
layer = model.encoder.layer[0].attention.self.query
print(layer)
2.4 对全连接层进行低秩分解
使用SVD(奇异值分解)对全连接层的权重矩阵进行分解。将权重矩阵分解为两个低秩矩阵的乘积。
import torch.nn.functional as F
# 获取全连接层的权重矩阵
weight = layer.weight.data
# 对权重矩阵进行SVD分解
U, S, V = torch.svd(weight)
# 选择低秩分解的秩(rank)
rank = 10 # 选择一个较小的秩
# 截取前rank个奇异值和对应的奇异向量
U_low = U[:, :rank]
S_low = torch.diag(S[:rank])
V_low = V[:, :rank]
# 计算低秩分解后的权重矩阵
weight_low = U_low @ S_low @ V_low.t()
# 将低秩分解后的权重矩阵赋值回原层
layer.weight.data = weight_low
2.5 分解后的微调
分解后的模型可能需要进一步微调,以恢复部分损失的性能。
2.6 验证分解效果
通过比较分解前后的模型输出,来验证低秩分解的效果。
# 创建一个输入张量
input_ids = torch.tensor([[31, 51, 99, 1]])
attention_mask = torch.tensor([[1, 1, 1, 1]])
# 获取分解前的输出
with torch.no_grad():
output_before = model(input_ids=input_ids, attention_mask=attention_mask)
# 对模型进行低秩分解
layer = model.encoder.layer[0].attention.self.query
weight = layer.weight.data
U, S, V = torch.svd(weight)
rank = 10
U_low = U[:, :rank]
S_low = torch.diag(S[:rank])
V_low = V[:, :rank]
weight_low = U_low @ S_low @ V_low.t()
layer.weight.data = weight_low
# 获取分解后的输出
with torch.no_grad():
output_after = model(input_ids=input_ids, attention_mask=attention_mask)
# 比较分解前后的输出
print("Output before decomposition:", output_before)
print("Output after decomposition:", output_after)