一、引言
KAN神经网络(Kolmogorov–Arnold Networks)是一种基于Kolmogorov-Arnold表示定理的新型神经网络架构。该定理指出,任何多元连续函数都可以表示为有限个单变量函数的组合。与传统多层感知机(MLP)不同,KAN通过可学习的激活函数和结构化网络设计,在函数逼近效率和可解释性上展现出潜力。
二、技术与原理简介
1.Kolmogorov-Arnold 表示定理
Kolmogorov-Arnold 表示定理指出,如果 是有界域上的多元连续函数,那么它可以写为单个变量的连续函数的有限组合,以及加法的二进制运算。更具体地说,对于 光滑
其中 和 。从某种意义上说,他们表明唯一真正的多元函数是加法,因为所有其他函数都可以使用单变量函数和 sum 来编写。然而,这个 2 层宽度 - Kolmogorov-Arnold 表示可能不是平滑的由于其表达能力有限。我们通过以下方式增强它的表达能力将其推广到任意深度和宽度。,
2.Kolmogorov-Arnold 网络 (KAN)
Kolmogorov-Arnold 表示可以写成矩阵形式
其中
我们注意到 和 都是以下函数矩阵(包含输入和输出)的特例,我们称之为 Kolmogorov-Arnold 层:
其中。
定义层后,我们可以构造一个 Kolmogorov-Arnold 网络只需堆叠层!假设我们有层,层的形状为 。那么整个网络是
相反,多层感知器由线性层和非线错:
KAN 可以很容易地可视化。(1) KAN 只是 KAN 层的堆栈。(2) 每个 KAN 层都可以可视化为一个全连接层,每个边缘上都有一个1D 函数。
三、代码详解
我们已经展示了如何使用 plot() 方法可视化 KAN。如果想保存 KAN 训练过程的动态变化,只需在 train() 方法中传入参数 save_video = True(并设置一些视频相关参数)
from kan import KAN, create_dataset
import torch
# create a KAN: 2D inputs, 1D output, and 5 hidden neurons. cubic spline (k=3), 5 grid intervals (grid=5).
model = KAN(width=[4,2,1,1], grid=3, k=3, seed=0)
f = lambda x: torch.exp((torch.sin(torch.pi*(x[:,[0]]**2+x[:,[1]]**2))+torch.sin(torch.pi*(x[:,[2]]**2+x[:,[3]]**2)))/2)
dataset = create_dataset(f, n_var=4, train_num=3000)
# train the model
#model.train(dataset, opt="LBFGS", steps=20, lamb=1e-3, lamb_entropy=2.);
model.train(dataset, opt="LBFGS", steps=50, lamb=5e-5, lamb_entropy=2., save_video=True, beta=10,
in_vars=[r'$x_1$', r'$x_2$', r'$x_3$', r'$x_4$'],
out_vars=[r'${\rm exp}({\rm sin}(x_1^2+x_2^2)+{\rm sin}(x_3^2+x_4^2))$'],
video_name='video', fps=5);
-
from kan import KAN, create_dataset
:从名为kan
的模块中导入KAN
类和create_dataset
函数。KAN
是一个用于建模和训练的类,而create_dataset
用于生成数据集。 -
import torch
:导入 PyTorch 模块,用于张量操作和深度学习相关的计算。
-
创建一个
KAN
模型实例。 -
参数说明:
-
width=[4,2,1,1]
:定义模型的网络结构。width
是一个列表,表示每一层的神经元数量。这里输入层有 4 个神经元,隐藏层有 2 个神经元,输出层有 1 个神经元,最后还有一个额外的输出层(可能用于某种特定的输出处理)。 -
grid=3
:定义网格的间隔数量。在KAN
模型中,网格用于离散化输入空间,grid=3
表示将输入空间划分为 3 个区间。 -
k=3
:定义使用的样条函数的阶数。k=3
表示使用三次样条函数(cubic spline)。 -
seed=0
:设置随机种子,用于确保结果的可重复性。
-
-
定义一个目标函数
f
,用于生成数据集的标签。 -
输入
x
是一个张量,形状为(n_samples, 4)
,表示有 4 个输入变量。 -
函数的计算过程:
-
x[:,[0]]**2 + x[:,[1]]**2
:计算输入变量x1
和x2
的平方和。 -
x[:,[2]]**2 + x[:,[3]]**2
:计算输入变量x3
和x4
的平方和。 -
torch.sin(torch.pi*(...))
:对平方和取正弦函数,并乘以 π。 -
将两个正弦结果相加后除以 2,再取指数函数
torch.exp
。
-
-
这个函数的输出是一个标量值,用于作为数据集的目标值。
-
使用
create_dataset
函数生成数据集。 -
参数说明:
-
f
:目标函数,用于生成数据集的标签。 -
n_var=4
:表示输入变量的数量为 4。 -
train_num=3000
:生成 3000 个训练样本。
-
-
调用
model.fit
方法对模型进行训练。 -
参数说明:
-
dataset
:输入的数据集。 -
opt="LBFGS"
:指定优化器为 L-BFGS(Limited-memory Broyden–Fletcher–Goldfarb–Shanno),一种高效的准牛顿优化方法。 -
steps=50
:训练的迭代步数为 50 步。 -
lamb=5e-5
:正则化参数,用于防止过拟合。 -
lamb_entropy=2.
:熵正则化参数,可能用于某种正则化机制。 -
save_video=True
:保存训练过程的视频。 -
beta=10
:一个超参数,可能用于某种特定的调整。 -
in_vars=[r'$x_1$', r'$x_2$', r'$x_3$', r'$x_4$']
:输入变量的名称,用于可视化或记录。 -
out_vars=[r'${\rm exp}({\rm sin}(x_1^2+x_2^2)+{\rm sin}(x_3^2+x_4^2))$']
:输出变量的名称,用于可视化或记录。 -
video_name='video'
:保存的视频文件名。 -
fps=5
:视频的帧率,每秒 5 帧。
-
四、总结与思考
KAN神经网络通过融合数学定理与深度学习,为科学计算和可解释AI提供了新思路。尽管在高维应用中仍需突破,但其在低维复杂函数建模上的潜力值得关注。未来可能通过改进计算效率、扩展理论边界,成为MLP的重要补充。
1. KAN网络架构
-
关键设计:可学习的激活函数:每个网络连接的“权重”被替换为单变量函数(如样条、多项式),而非固定激活函数(如ReLU)。分层结构:输入层和隐藏层之间、隐藏层与输出层之间均通过单变量函数连接,形成多层叠加。参数效率:由于理论保证,KAN可能用更少的参数达到与MLP相当或更好的逼近效果。
-
示例结构:输入层 → 隐藏层:每个输入节点通过单变量函数
连接到隐藏节点。隐藏层 → 输出层:隐藏节点通过另一组单变量函数
组合得到输出。
2. 优势与特点
-
高逼近效率:基于数学定理,理论上能以更少参数逼近复杂函数;在低维科学计算任务(如微分方程求解)中表现优异。
-
可解释性:单变量函数可可视化,便于分析输入变量与输出的关系;网络结构直接对应函数分解过程,逻辑清晰。
-
灵活的函数学习:激活函数可自适应调整(如学习平滑或非平滑函数);支持符号公式提取(例如从数据中恢复物理定律)。
3. 挑战与局限
-
计算复杂度:单变量函数的学习(如样条参数化)可能增加训练时间和内存消耗。需要优化高阶连续函数,对硬件和算法提出更高要求。
-
泛化能力:在高维数据(如图像、文本)中的表现尚未充分验证,可能逊色于传统MLP。
-
训练难度:需设计新的优化策略,避免单变量函数的过拟合或欠拟合。
4. 应用场景
-
科学计算:求解微分方程、物理建模、化学模拟等需要高精度函数逼近的任务。
-
可解释性需求领域:医疗诊断、金融风控等需明确输入输出关系的场景。
-
符号回归:从数据中自动发现数学公式(如物理定律)。
5. 与传统MLP的对比
6. 研究进展
-
近期论文:2024年,MIT等团队提出KAN架构(如论文《KAN: Kolmogorov-Arnold Networks》),在低维任务中验证了其高效性和可解释性。
-
开源实现:已有PyTorch等框架的初步实现。
【作者声明】
本文分享的论文内容及观点均来源于《KAN: Kolmogorov-Arnold Networks》原文,旨在介绍和探讨该研究的创新成果和应用价值。作者尊重并遵循学术规范,确保内容的准确性和客观性。如有任何疑问或需要进一步的信息,请参考论文原文或联系相关作者。
【关注我们】
如果您对神经网络、群智能算法及人工智能技术感兴趣,请关注【灵犀拾荒者】,获取更多前沿技术文章、实战案例及技术分享!