一、引言
KAN神经网络(Kolmogorov–Arnold Networks)是一种基于Kolmogorov-Arnold表示定理的新型神经网络架构。该定理指出,任何多元连续函数都可以表示为有限个单变量函数的组合。与传统多层感知机(MLP)不同,KAN通过可学习的激活函数和结构化网络设计,在函数逼近效率和可解释性上展现出潜力。
二、技术与原理简介
1.Kolmogorov-Arnold 表示定理
Kolmogorov-Arnold 表示定理指出,如果 是有界域上的多元连续函数,那么它可以写为单个变量的连续函数的有限组合,以及加法的二进制运算。更具体地说,对于 光滑
其中 和 。从某种意义上说,他们表明唯一真正的多元函数是加法,因为所有其他函数都可以使用单变量函数和 sum 来编写。然而,这个 2 层宽度 - Kolmogorov-Arnold 表示可能不是平滑的由于其表达能力有限。我们通过以下方式增强它的表达能力将其推广到任意深度和宽度。,
2.Kolmogorov-Arnold 网络 (KAN)
Kolmogorov-Arnold 表示可以写成矩阵形式
其中
我们注意到 和 都是以下函数矩阵(包含输入和输出)的特例,我们称之为 Kolmogorov-Arnold 层:
其中。
定义层后,我们可以构造一个 Kolmogorov-Arnold 网络只需堆叠层!假设我们有层,层的形状为 。那么整个网络是
相反,多层感知器由线性层和非线错:
KAN 可以很容易地可视化。(1) KAN 只是 KAN 层的堆栈。(2) 每个 KAN 层都可以可视化为一个全连接层,每个边缘上都有一个1D 函数。
三、代码详解
experiment.py为KAN模型训练的子库代码,包括模型的训练、剪枝和优化过程,同时还包括一个计算帕累托前沿(Pareto Frontier)的辅助函数,下述为详细注解。
A. 代码详解
1. runner1 函数:模型训练与优化流程
def runner1(width, dataset, grids=[5,10,20], steps=20, lamb=0.001, prune_round=3, refine_round=3, edge_th=1e-2, node_th=1e-2, metrics=None, seed=1):
功能概述
runner1 是代码的核心函数,用于执行一个多阶段的模型训练流程,包括初始化、训练、剪枝(pruning)和精炼(refinement),并在每个阶段收集评估结果。它返回一个包含训练过程中各种指标的字典。
参数解释
- width: 定义 KAN 模型的宽度(可能是网络的层结构或神经元数量)。
- dataset: 训练和评估使用的数据集。
- grids: 一个列表(如 [5, 10, 20]),表示不同阶段使用的网格大小,用于模型精炼。
- steps: 每次训练的迭代步数,默认为 20。
- lamb: 正则化参数(默认 0.001),用于控制模型复杂度。
- prune_round: 剪枝循环的次数,默认为 3。
- refine_round: 精炼循环的次数,默认为 3。
- edge_th: 剪枝时边的阈值(默认 0.01),小于此值的边可能被移除。
- node_th: 剪枝时节点的阈值(默认 0.01),小于此值的节点可能被移除。
- metrics: 可选的额外评估指标(函数列表),用于计算自定义指标。
- seed: 随机种子(默认 1),用于确保实验可重复性。
2. 内部逻辑
(1) 结果存储初始化
result = {}
result['test_loss'] = []
result['c'] = []
result['G'] = []
result['id'] = []
if metrics != None:
for i in range(len(metrics)):
result[metrics[i].__name__] = []
- 创建一个字典 result 用于存储训练过程中的结果。
- 默认存储的指标包括:
- test_loss: 测试集损失。
- c: 边的数量(n_edge)。
- G: 网格数量(n_grid)。
- id: 模型状态的唯一标识符。
- 如果提供了额外的 metrics,则为每个指标创建一个空列表。
(2) 收集结果的辅助函数
def collect(evaluation):
result['test_loss'].append(evaluation['test_loss'])
result['c'].append(evaluation['n_edge'])
result['G'].append(evaluation['n_grid'])
result['id'].append(f'{model.round}.{model.state_id}')
if metrics != None:
for i in range(len(metrics)):
result[metrics[i].__name__].append(metrics[i](model, dataset).item())
- collect 函数从 evaluation 字典中提取评估结果并追加到 result 中。
- evaluation 是一个字典,包含模型评估后的指标(如 test_loss 和 n_edge)。
- 如果有额外指标,会调用传入的 metrics 函数计算并存储结果。
(3) 主循环:剪枝和精炼
for i in range(prune_round):
# train and prune
if i == 0:
model = KAN(width=width, grid=grids[0], seed=seed)
else:
model = model.rewind(f'{i-1}.{2*i}')
模型初始化或回溯:
- 如果是第一次循环(i == 0),用 width、第一个网格大小 grids[0] 和 seed 初始化一个新的 KAN 模型。
- 否则,使用 rewind 方法将模型恢复到之前的某个状态(可能是上一个剪枝阶段的检查点),状态标识为 f'{i-1}.{2*i}'。
model.fit(dataset, steps=steps, lamb=lamb)
model = model.prune(edge_th=edge_th, node_th=node_th)
evaluation = model.evaluate(dataset)
collect(evaluation)
训练和剪枝:
- 用 fit 方法训练模型,训练 steps 步,加入正则化参数 lamb。
- 用 prune 方法剪枝,根据 edge_th 和 node_th 移除不重要的边和节点。
- 调用 evaluate 评估模型,得到 evaluation 结果,并用 collect 存储。
for j in range(refine_round):
model = model.refine(grids[j])
model.fit(dataset, steps=steps)
evaluation = model.evaluate(dataset)
collect(evaluation)
精炼循环:
- 在每次剪枝后,运行 refine_round 次精炼。
- 用 refine 方法调整模型,采用 grids[j] 指定的网格大小。
- 再次训练模型(fit),然后评估并收集结果。
(4) 结果转换与返回
for key in list(result.keys()):
result[key] = np.array(result[key])
return result
- 将 result 字典中的所有列表转换为 NumPy 数组。
- 返回包含所有指标的 result 字典。
3. pareto_frontier 函数:计算帕累托前沿
def pareto_frontier(x, y):
pf_id = np.where(np.sum((x[:,None] <= x[None,:]) * (y[:,None] <= y[None,:]), axis=0) == 1)[0]
x_pf = x[pf_id]
y_pf = y[pf_id]
return x_pf, y_pf, pf_id
功能概述
pareto_frontier 计算输入的两个数组 x 和 y 的帕累托前沿,用于分析两个指标之间的权衡关系(例如模型复杂度和性能)。
参数解释
- x: 表示第一个指标的数组(如模型复杂度)。
- y: 表示第二个指标的数组(如测试损失)。
- 返回值:
- x_pf: 帕累托前沿上的 x 值。
- y_pf: 帕累托前沿上的 y 值。
- pf_id: 帕累托前沿点的索引。
4. 内部逻辑
(1) 帕累托前沿的定义
- 帕累托前沿是一组点,其中每个点在 x 和 y 上都不被其他点同时支配(即不存在另一个点在 x 和 y 上都优于它)。
(2) 计算过程
pf_id = np.where(np.sum((x[:,None] <= x[None,:]) * (y[:,None] <= y[None,:]), axis=0) == 1)[0]
x_pf = x[pf_id]
y_pf = y[pf_id]
- (x[:,None] <= x[None,:]): 比较所有 x 值,生成一个布尔矩阵,表示每个点是否在 x 维度上小于等于其他点。
- (y[:,None] <= y[None,:]): 类似地,比较所有 y 值。
- *: 将两个布尔矩阵相乘,表示在 x 和 y 上同时小于等于的情况。
- np.sum(..., axis=0): 对每一列求和,计算每个点被其他点支配的次数。
- == 1: 选择那些只被自身支配的点(每个点总是小于等于自己,因此和为 1 表示未被其他点支配)。
- np.where(...)[0]: 获取满足条件的点的索引。
(3) 提取结果
- 根据 pf_id 提取对应的 x 和 y 值,返回前沿点及其索引。
B. 完整代码
import torch
from .MultKAN import *
def runner1(width, dataset, grids=[5,10,20], steps=20, lamb=0.001, prune_round=3, refine_round=3, edge_th=1e-2, node_th=1e-2, metrics=None, seed=1):
result = {}
result['test_loss'] = []
result['c'] = []
result['G'] = []
result['id'] = []
if metrics != None:
for i in range(len(metrics)):
result[metrics[i].__name__] = []
def collect(evaluation):
result['test_loss'].append(evaluation['test_loss'])
result['c'].append(evaluation['n_edge'])
result['G'].append(evaluation['n_grid'])
result['id'].append(f'{model.round}.{model.state_id}')
if metrics != None:
for i in range(len(metrics)):
result[metrics[i].__name__].append(metrics[i](model, dataset).item())
for i in range(prune_round):
# train and prune
if i == 0:
model = KAN(width=width, grid=grids[0], seed=seed)
else:
model = model.rewind(f'{i-1}.{2*i}')
model.fit(dataset, steps=steps, lamb=lamb)
model = model.prune(edge_th=edge_th, node_th=node_th)
evaluation = model.evaluate(dataset)
collect(evaluation)
for j in range(refine_round):
model = model.refine(grids[j])
model.fit(dataset, steps=steps)
evaluation = model.evaluate(dataset)
collect(evaluation)
for key in list(result.keys()):
result[key] = np.array(result[key])
return result
def pareto_frontier(x, y):
# 找到Pareto前沿的点的索引
pf_id = np.where(np.sum((x[:,None] <= x[None,:]) * (y[:,None] <= y[None,:]), axis=0) == 1)[0]
x_pf = x[pf_id]
y_pf = y[pf_id]
return x_pf, y_pf, pf_id
四、总结与思考
KAN神经网络通过融合数学定理与深度学习,为科学计算和可解释AI提供了新思路。尽管在高维应用中仍需突破,但其在低维复杂函数建模上的潜力值得关注。未来可能通过改进计算效率、扩展理论边界,成为MLP的重要补充。
1. KAN网络架构
-
关键设计:可学习的激活函数:每个网络连接的“权重”被替换为单变量函数(如样条、多项式),而非固定激活函数(如ReLU)。分层结构:输入层和隐藏层之间、隐藏层与输出层之间均通过单变量函数连接,形成多层叠加。参数效率:由于理论保证,KAN可能用更少的参数达到与MLP相当或更好的逼近效果。
-
示例结构:输入层 → 隐藏层:每个输入节点通过单变量函数
连接到隐藏节点。隐藏层 → 输出层:隐藏节点通过另一组单变量函数
组合得到输出。
2. 优势与特点
-
高逼近效率:基于数学定理,理论上能以更少参数逼近复杂函数;在低维科学计算任务(如微分方程求解)中表现优异。
-
可解释性:单变量函数可可视化,便于分析输入变量与输出的关系;网络结构直接对应函数分解过程,逻辑清晰。
-
灵活的函数学习:激活函数可自适应调整(如学习平滑或非平滑函数);支持符号公式提取(例如从数据中恢复物理定律)。
3. 挑战与局限
-
计算复杂度:单变量函数的学习(如样条参数化)可能增加训练时间和内存消耗。需要优化高阶连续函数,对硬件和算法提出更高要求。
-
泛化能力:在高维数据(如图像、文本)中的表现尚未充分验证,可能逊色于传统MLP。
-
训练难度:需设计新的优化策略,避免单变量函数的过拟合或欠拟合。
4. 应用场景
-
科学计算:求解微分方程、物理建模、化学模拟等需要高精度函数逼近的任务。
-
可解释性需求领域:医疗诊断、金融风控等需明确输入输出关系的场景。
-
符号回归:从数据中自动发现数学公式(如物理定律)。
5. 与传统MLP的对比
6. 研究进展
-
近期论文:2024年,MIT等团队提出KAN架构(如论文《KAN: Kolmogorov-Arnold Networks》),在低维任务中验证了其高效性和可解释性。
-
开源实现:已有PyTorch等框架的初步实现。
【作者声明】
本文分享的论文内容及观点均来源于《KAN: Kolmogorov-Arnold Networks》原文,旨在介绍和探讨该研究的创新成果和应用价值。作者尊重并遵循学术规范,确保内容的准确性和客观性。如有任何疑问或需要进一步的信息,请参考论文原文或联系相关作者。
【关注我们】
如果您对神经网络、群智能算法及人工智能技术感兴趣,请关注【灵犀拾荒者】,获取更多前沿技术文章、实战案例及技术分享!