一、引言
KAN神经网络(Kolmogorov–Arnold Networks)是一种基于Kolmogorov-Arnold表示定理的新型神经网络架构。该定理指出,任何多元连续函数都可以表示为有限个单变量函数的组合。与传统多层感知机(MLP)不同,KAN通过可学习的激活函数和结构化网络设计,在函数逼近效率和可解释性上展现出潜力。
二、技术与原理简介
1.Kolmogorov-Arnold 表示定理
Kolmogorov-Arnold 表示定理指出,如果 是有界域上的多元连续函数,那么它可以写为单个变量的连续函数的有限组合,以及加法的二进制运算。更具体地说,对于 光滑
其中 和 。从某种意义上说,他们表明唯一真正的多元函数是加法,因为所有其他函数都可以使用单变量函数和 sum 来编写。然而,这个 2 层宽度 - Kolmogorov-Arnold 表示可能不是平滑的由于其表达能力有限。我们通过以下方式增强它的表达能力将其推广到任意深度和宽度。,
2.Kolmogorov-Arnold 网络 (KAN)
Kolmogorov-Arnold 表示可以写成矩阵形式
其中
我们注意到 和 都是以下函数矩阵(包含输入和输出)的特例,我们称之为 Kolmogorov-Arnold 层:
其中。
定义层后,我们可以构造一个 Kolmogorov-Arnold 网络只需堆叠层!假设我们有层,层的形状为 。那么整个网络是
相反,多层感知器由线性层和非线错:
KAN 可以很容易地可视化。(1) KAN 只是 KAN 层的堆栈。(2) 每个 KAN 层都可以可视化为一个全连接层,每个边缘上都有一个1D 函数。
三、代码详解
A. 代码详解
1. 导入库与设置随机种子
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
from .LBFGS import LBFGS
seed = 0
torch.manual_seed(seed)
-
功能: 导入必要的库,并设置全局随机种子确保结果可复现。
-
注释: 使用
LBFGS
自定义优化器,需确保该模块存在。随机种子固定为0。
2. MLP类定义(初始化与设备转移)
class MLP(nn.Module):
def __init__(self, width, act='silu', save_act=True, seed=0, device='cpu'):
super(MLP, self).__init__()
torch.manual_seed(seed)
linears = []
self.width = width
self.depth = depth = len(width) - 1
for i in range(depth):
linears.append(nn.Linear(width[i], width[i+1]))
self.linears = nn.ModuleList(linears)
self.act_fun = torch.nn.SiLU()
self.save_act = save_act
self.acts = None
self.cache_data = None
self.device = device
self.to(device)
def to(self, device):
super(MLP, self).to(device)
self.device = device
return self
-
功能: 初始化MLP网络结构,包含线性层列表和激活函数。
-
参数:
-
width
: 各层神经元数量(如[2,10,10,1]
)。 -
act
: 激活函数(默认SiLU)。 -
save_act
: 是否保存各层激活值。
-
-
注释: 使用
ModuleList
管理线性层,便于参数管理。to
方法确保模型在指定设备上运行。
3. 前向传播与激活保存
def forward(self, x):
self.cache_data = x
self.acts = []
self.acts_scale = []
self.wa_forward = []
self.a_forward = []
for i in range(self.depth):
if self.save_act:
act = x.clone()
act_scale = torch.std(x, dim=0)
wa_forward = act_scale[None, :] * self.linears[i].weight
self.acts.append(act)
if i > 0:
self.acts_scale.append(act_scale)
self.wa_forward.append(wa_forward)
x = self.linears[i](x)
if i < self.depth - 1:
x = self.act_fun(x)
else:
if self.save_act:
act_scale = torch.std(x, dim=0)
self.acts_scale.append(act_scale)
return x
-
功能: 执行前向传播,并保存各层的激活值和相关统计量。
-
流程:
-
缓存输入数据。
-
遍历各层,保存激活值及其标准差。
-
计算权重与激活的乘积(
wa_forward
),用于后续分析。
-
-
注释:
save_act
控制是否保存中间激活值,影响后续正则化和可视化。
4. 重要性评分计算(反向传播)
def attribute(self):
if self.acts == None:
self.get_act()
node_scores = []
edge_scores = []
node_score = torch.ones(self.width[-1]).requires_grad_(True).to(self.device)
node_scores.append(node_score)
for l in range(self.depth,0,-1):
edge_score = torch.einsum('ij,i->ij', torch.abs(self.wa_forward[l-1]), node_score/(self.acts_scale[l-1]+1e-4))
edge_scores.append(edge_score)
node_score = torch.sum(edge_score, dim=0)/torch.sqrt(torch.tensor(self.width[l-1], device=self.device))
node_scores.append(node_score)
self.node_scores = list(reversed(node_scores))
self.edge_scores = list(reversed(edge_scores))
self.wa_backward = self.edge_scores
-
功能: 通过反向传播计算各层节点和边的重要性分数。
-
流程:
-
从输出层开始,逐层计算边的重要性(
edge_score
)。 -
根据边的重要性聚合得到节点的重要性(
node_score
)。
-
-
注释: 用于可视化或结构化正则化,分数反映参数对输出的影响。
5. 网络结构可视化
def plot(self, beta=3, scale=1., metric='w'):
# 设置绘图参数
fig, ax = plt.subplots(figsize=(3*scale,3*y0*depth*scale))
# 绘制节点
for j in range(len(shp)):
# ... 节点绘制代码
# 绘制边(根据metric选择权重)
for ii in range(len(linears)):
# ... 边绘制代码
ax.axis('off')
-
功能: 可视化MLP结构,边透明度反映权重大小。
-
参数:
-
metric
: 选择权重('w')、前向激活('act')或反向重要性('fa')。
-
-
注释: 使用Matplotlib绘制,边的颜色和透明度编码权重信息。
6. 正则化计算
def reg(self, reg_metric, lamb_l1, lamb_entropy):
if reg_metric == 'w':
acts_scale = self.w
# ... 其他metric处理
reg_ = 0.
for i in range(len(acts_scale)):
# 计算L1和熵正则项
return reg_
-
功能: 计算正则化损失,包含L1和熵惩罚。
-
参数:
-
reg_metric
: 正则化依据(权重、激活等)。 -
lamb_l1
: L1正则系数。 -
lamb_entropy
: 熵正则系数。
-
-
注释: 促进权重稀疏性和分布均匀性,防止过拟合。
7. 模型训练
def fit(self, dataset, opt="LBFGS", steps=100, log=1, lamb=0., ...):
# 初始化优化器
if opt == "LBFGS":
optimizer = LBFGS(...)
# 训练循环
for _ in pbar:
# 前向传播、损失计算、反向传播
# 更新参数并记录指标
-
功能: 训练模型,支持多种优化器和正则化。
-
参数:
-
opt
: 优化器类型(Adam或LBFGS)。 -
lamb
: 正则化强度。
-
-
注释: 使用闭包函数处理LBFGS优化,支持批量训练和指标记录。
8. 结构优化方法
def swap(self, l, i1, i2):
# 交换层内神经元权重
def auto_swap_l(self, l):
# 自动调整层内神经元顺序减少连接成本
def auto_swap(self):
# 对所有层执行自动调整
-
功能: 通过交换神经元位置优化网络结构,减少连接成本。
-
注释:
connection_cost
计算基于坐标的距离加权权重,交换神经元以最小化该成本。
9. 其他辅助方法
-
get_act: 获取或计算激活值。
-
connection_cost: 计算网络连接成本(基于权重和虚拟坐标)。
-
tree: 生成树状结构图(需外部函数支持)。
B. 完整代码
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
from .LBFGS import LBFGS
seed = 0
torch.manual_seed(seed)
class MLP(nn.Module):
def __init__(self, width, act='silu', save_act=True, seed=0, device='cpu'):
super(MLP, self).__init__()
torch.manual_seed(seed)
linears = []
self.width = width
self.depth = depth = len(width) - 1
for i in range(depth):
linears.append(nn.Linear(width[i], width[i+1]))
self.linears = nn.ModuleList(linears)
#if activation == 'silu':
self.act_fun = torch.nn.SiLU()
self.save_act = save_act
self.acts = None
self.cache_data = None
self.device = device
self.to(device)
def to(self, device):
super(MLP, self).to(device)
self.device = device
return self
def get_act(self, x=None):
if isinstance(x, dict):
x = x['train_input']
if x == None:
if self.cache_data != None:
x = self.cache_data
else:
raise Exception("missing input data x")
save_act = self.save_act
self.save_act = True
self.forward(x)
self.save_act = save_act
@property
def w(self):
return [self.linears[l].weight for l in range(self.depth)]
def forward(self, x):
# cache data
self.cache_data = x
self.acts = []
self.acts_scale = []
self.wa_forward = []
self.a_forward = []
for i in range(self.depth):
if self.save_act:
act = x.clone()
act_scale = torch.std(x, dim=0)
wa_forward = act_scale[None, :] * self.linears[i].weight
self.acts.append(act)
if i > 0:
self.acts_scale.append(act_scale)
self.wa_forward.append(wa_forward)
x = self.linears[i](x)
if i < self.depth - 1:
x = self.act_fun(x)
else:
if self.save_act:
act_scale = torch.std(x, dim=0)
self.acts_scale.append(act_scale)
return x
def attribute(self):
if self.acts == None:
self.get_act()
node_scores = []
edge_scores = []
# back propagate from the last layer
node_score = torch.ones(self.width[-1]).requires_grad_(True).to(self.device)
node_scores.append(node_score)
for l in range(self.depth,0,-1):
edge_score = torch.einsum('ij,i->ij', torch.abs(self.wa_forward[l-1]), node_score/(self.acts_scale[l-1]+1e-4))
edge_scores.append(edge_score)
# this might be improper for MLPs (although reasonable for KANs)
node_score = torch.sum(edge_score, dim=0)/torch.sqrt(torch.tensor(self.width[l-1], device=self.device))
#print(self.width[l])
node_scores.append(node_score)
self.node_scores = list(reversed(node_scores))
self.edge_scores = list(reversed(edge_scores))
self.wa_backward = self.edge_scores
def plot(self, beta=3, scale=1., metric='w'):
# metric = 'w', 'act' or 'fa'
if metric == 'fa':
self.attribute()
depth = self.depth
y0 = 0.5
fig, ax = plt.subplots(figsize=(3*scale,3*y0*depth*scale))
shp = self.width
min_spacing = 1/max(self.width)
for j in range(len(shp)):
N = shp[j]
for i in range(N):
plt.scatter(1 / (2 * N) + i / N, j * y0, s=min_spacing ** 2 * 5000 * scale ** 2, color='black')
plt.ylim(-0.1*y0,y0*depth+0.1*y0)
plt.xlim(-0.02,1.02)
linears = self.linears
for ii in range(len(linears)):
linear = linears[ii]
p = linear.weight
p_shp = p.shape
if metric == 'w':
pass
elif metric == 'act':
p = self.wa_forward[ii]
elif metric == 'fa':
p = self.wa_backward[ii]
else:
raise Exception('metric = \'{}\' not recognized. Choices are \'w\', \'act\', \'fa\'.'.format(metric))
for i in range(p_shp[0]):
for j in range(p_shp[1]):
plt.plot([1/(2*p_shp[0])+i/p_shp[0], 1/(2*p_shp[1])+j/p_shp[1]], [y0*(ii+1),y0*ii], lw=0.5*scale, alpha=np.tanh(beta*np.abs(p[i,j].cpu().detach().numpy())), color="blue" if p[i,j]>0 else "red")
ax.axis('off')
def reg(self, reg_metric, lamb_l1, lamb_entropy):
if reg_metric == 'w':
acts_scale = self.w
if reg_metric == 'act':
acts_scale = self.wa_forward
if reg_metric == 'fa':
acts_scale = self.wa_backward
if reg_metric == 'a':
acts_scale = self.acts_scale
if len(acts_scale[0].shape) == 2:
reg_ = 0.
for i in range(len(acts_scale)):
vec = acts_scale[i]
vec = torch.abs(vec)
l1 = torch.sum(vec)
p_row = vec / (torch.sum(vec, dim=1, keepdim=True) + 1)
p_col = vec / (torch.sum(vec, dim=0, keepdim=True) + 1)
entropy_row = - torch.mean(torch.sum(p_row * torch.log2(p_row + 1e-4), dim=1))
entropy_col = - torch.mean(torch.sum(p_col * torch.log2(p_col + 1e-4), dim=0))
reg_ += lamb_l1 * l1 + lamb_entropy * (entropy_row + entropy_col)
elif len(acts_scale[0].shape) == 1:
reg_ = 0.
for i in range(len(acts_scale)):
vec = acts_scale[i]
vec = torch.abs(vec)
l1 = torch.sum(vec)
p = vec / (torch.sum(vec) + 1)
entropy = - torch.sum(p * torch.log2(p + 1e-4))
reg_ += lamb_l1 * l1 + lamb_entropy * entropy
return reg_
def get_reg(self, reg_metric, lamb_l1, lamb_entropy):
return self.reg(reg_metric, lamb_l1, lamb_entropy)
def fit(self, dataset, opt="LBFGS", steps=100, log=1, lamb=0., lamb_l1=1., lamb_entropy=2., loss_fn=None, lr=1., batch=-1,
metrics=None, in_vars=None, out_vars=None, beta=3, device='cpu', reg_metric='w', display_metrics=None):
if lamb > 0. and not self.save_act:
print('setting lamb=0. If you want to set lamb > 0, set =True')
old_save_act = self.save_act
if lamb == 0.:
self.save_act = False
pbar = tqdm(range(steps), desc='description', ncols=100)
if loss_fn == None:
loss_fn = loss_fn_eval = lambda x, y: torch.mean((x - y) ** 2)
else:
loss_fn = loss_fn_eval = loss_fn
if opt == "Adam":
optimizer = torch.optim.Adam(self.parameters(), lr=lr)
elif opt == "LBFGS":
optimizer = LBFGS(self.parameters(), lr=lr, history_size=10, line_search_fn="strong_wolfe", tolerance_grad=1e-32, tolerance_change=1e-32, tolerance_ys=1e-32)
results = {}
results['train_loss'] = []
results['test_loss'] = []
results['reg'] = []
if metrics != None:
for i in range(len(metrics)):
results[metrics[i].__name__] = []
if batch == -1 or batch > dataset['train_input'].shape[0]:
batch_size = dataset['train_input'].shape[0]
batch_size_test = dataset['test_input'].shape[0]
else:
batch_size = batch
batch_size_test = batch
global train_loss, reg_
def closure():
global train_loss, reg_
optimizer.zero_grad()
pred = self.forward(dataset['train_input'][train_id].to(self.device))
train_loss = loss_fn(pred, dataset['train_label'][train_id].to(self.device))
if self.save_act:
if reg_metric == 'fa':
self.attribute()
reg_ = self.get_reg(reg_metric, lamb_l1, lamb_entropy)
else:
reg_ = torch.tensor(0.)
objective = train_loss + lamb * reg_
objective.backward()
return objective
for _ in pbar:
if _ == steps-1 and old_save_act:
self.save_act = True
train_id = np.random.choice(dataset['train_input'].shape[0], batch_size, replace=False)
test_id = np.random.choice(dataset['test_input'].shape[0], batch_size_test, replace=False)
if opt == "LBFGS":
optimizer.step(closure)
if opt == "Adam":
pred = self.forward(dataset['train_input'][train_id].to(self.device))
train_loss = loss_fn(pred, dataset['train_label'][train_id].to(self.device))
if self.save_act:
reg_ = self.get_reg(reg_metric, lamb_l1, lamb_entropy)
else:
reg_ = torch.tensor(0.)
loss = train_loss + lamb * reg_
optimizer.zero_grad()
loss.backward()
optimizer.step()
test_loss = loss_fn_eval(self.forward(dataset['test_input'][test_id].to(self.device)), dataset['test_label'][test_id].to(self.device))
if metrics != None:
for i in range(len(metrics)):
results[metrics[i].__name__].append(metrics[i]().item())
results['train_loss'].append(torch.sqrt(train_loss).cpu().detach().numpy())
results['test_loss'].append(torch.sqrt(test_loss).cpu().detach().numpy())
results['reg'].append(reg_.cpu().detach().numpy())
if _ % log == 0:
if display_metrics == None:
pbar.set_description("| train_loss: %.2e | test_loss: %.2e | reg: %.2e | " % (torch.sqrt(train_loss).cpu().detach().numpy(), torch.sqrt(test_loss).cpu().detach().numpy(), reg_.cpu().detach().numpy()))
else:
string = ''
data = ()
for metric in display_metrics:
string += f' {metric}: %.2e |'
try:
results[metric]
except:
raise Exception(f'{metric} not recognized')
data += (results[metric][-1],)
pbar.set_description(string % data)
return results
@property
def connection_cost(self):
with torch.no_grad():
cc = 0.
for linear in self.linears:
t = torch.abs(linear.weight)
def get_coordinate(n):
return torch.linspace(0,1,steps=n+1, device=self.device)[:n] + 1/(2*n)
in_dim = t.shape[0]
x_in = get_coordinate(in_dim)
out_dim = t.shape[1]
x_out = get_coordinate(out_dim)
dist = torch.abs(x_in[:,None] - x_out[None,:])
cc += torch.sum(dist * t)
return cc
def swap(self, l, i1, i2):
def swap_row(data, i1, i2):
data[i1], data[i2] = data[i2].clone(), data[i1].clone()
def swap_col(data, i1, i2):
data[:,i1], data[:,i2] = data[:,i2].clone(), data[:,i1].clone()
swap_row(self.linears[l-1].weight.data, i1, i2)
swap_row(self.linears[l-1].bias.data, i1, i2)
swap_col(self.linears[l].weight.data, i1, i2)
def auto_swap_l(self, l):
num = self.width[l]
for i in range(num):
ccs = []
for j in range(num):
self.swap(l,i,j)
self.get_act()
self.attribute()
cc = self.connection_cost.detach().clone()
ccs.append(cc)
self.swap(l,i,j)
j = torch.argmin(torch.tensor(ccs))
self.swap(l,i,j)
def auto_swap(self):
depth = self.depth
for l in range(1, depth):
self.auto_swap_l(l)
def tree(self, x=None, in_var=None, style='tree', sym_th=1e-3, sep_th=1e-1, skip_sep_test=False, verbose=False):
if x == None:
x = self.cache_data
plot_tree(self, x, in_var=in_var, style=style, sym_th=sym_th, sep_th=sep_th, skip_sep_test=skip_sep_test, verbose=verbose)
四、总结与思考
KAN神经网络通过融合数学定理与深度学习,为科学计算和可解释AI提供了新思路。尽管在高维应用中仍需突破,但其在低维复杂函数建模上的潜力值得关注。未来可能通过改进计算效率、扩展理论边界,成为MLP的重要补充。
1. KAN网络架构
-
关键设计:可学习的激活函数:每个网络连接的“权重”被替换为单变量函数(如样条、多项式),而非固定激活函数(如ReLU)。分层结构:输入层和隐藏层之间、隐藏层与输出层之间均通过单变量函数连接,形成多层叠加。参数效率:由于理论保证,KAN可能用更少的参数达到与MLP相当或更好的逼近效果。
-
示例结构:输入层 → 隐藏层:每个输入节点通过单变量函数
连接到隐藏节点。隐藏层 → 输出层:隐藏节点通过另一组单变量函数
组合得到输出。
2. 优势与特点
-
高逼近效率:基于数学定理,理论上能以更少参数逼近复杂函数;在低维科学计算任务(如微分方程求解)中表现优异。
-
可解释性:单变量函数可可视化,便于分析输入变量与输出的关系;网络结构直接对应函数分解过程,逻辑清晰。
-
灵活的函数学习:激活函数可自适应调整(如学习平滑或非平滑函数);支持符号公式提取(例如从数据中恢复物理定律)。
3. 挑战与局限
-
计算复杂度:单变量函数的学习(如样条参数化)可能增加训练时间和内存消耗。需要优化高阶连续函数,对硬件和算法提出更高要求。
-
泛化能力:在高维数据(如图像、文本)中的表现尚未充分验证,可能逊色于传统MLP。
-
训练难度:需设计新的优化策略,避免单变量函数的过拟合或欠拟合。
4. 应用场景
-
科学计算:求解微分方程、物理建模、化学模拟等需要高精度函数逼近的任务。
-
可解释性需求领域:医疗诊断、金融风控等需明确输入输出关系的场景。
-
符号回归:从数据中自动发现数学公式(如物理定律)。
5. 与传统MLP的对比
6. 研究进展
-
近期论文:2024年,MIT等团队提出KAN架构(如论文《KAN: Kolmogorov-Arnold Networks》),在低维任务中验证了其高效性和可解释性。
-
开源实现:已有PyTorch等框架的初步实现。
【作者声明】
本文分享的论文内容及观点均来源于《KAN: Kolmogorov-Arnold Networks》原文,旨在介绍和探讨该研究的创新成果和应用价值。作者尊重并遵循学术规范,确保内容的准确性和客观性。如有任何疑问或需要进一步的信息,请参考论文原文或联系相关作者。
【关注我们】
如果您对神经网络、群智能算法及人工智能技术感兴趣,请关注【灵犀拾荒者】,获取更多前沿技术文章、实战案例及技术分享!