一、引言
KAN神经网络(Kolmogorov–Arnold Networks)是一种基于Kolmogorov-Arnold表示定理的新型神经网络架构。该定理指出,任何多元连续函数都可以表示为有限个单变量函数的组合。与传统多层感知机(MLP)不同,KAN通过可学习的激活函数和结构化网络设计,在函数逼近效率和可解释性上展现出潜力。
二、技术与原理简介
1.Kolmogorov-Arnold 表示定理
Kolmogorov-Arnold 表示定理指出,如果 是有界域上的多元连续函数,那么它可以写为单个变量的连续函数的有限组合,以及加法的二进制运算。更具体地说,对于 光滑
其中 和 。从某种意义上说,他们表明唯一真正的多元函数是加法,因为所有其他函数都可以使用单变量函数和 sum 来编写。然而,这个 2 层宽度 - Kolmogorov-Arnold 表示可能不是平滑的由于其表达能力有限。我们通过以下方式增强它的表达能力将其推广到任意深度和宽度。,
2.Kolmogorov-Arnold 网络 (KAN)
Kolmogorov-Arnold 表示可以写成矩阵形式
其中
我们注意到 和 都是以下函数矩阵(包含输入和输出)的特例,我们称之为 Kolmogorov-Arnold 层:
其中。
定义层后,我们可以构造一个 Kolmogorov-Arnold 网络只需堆叠层!假设我们有层,层的形状为 。那么整个网络是
相反,多层感知器由线性层和非线错:
KAN 可以很容易地可视化。(1) KAN 只是 KAN 层的堆栈。(2) 每个 KAN 层都可以可视化为一个全连接层,每个边缘上都有一个1D 函数。
三、代码详解
这段代码提供了一套完整的B样条处理工具,包括基函数的计算、曲线的生成与评估、以及从曲线恢复系数的功能,适用于需要进行平滑插值或曲线拟合的场景。
A. 代码详解
1. B_batch
函数
import torch
def B_batch(x, grid, k=0, extend=True, device='cpu'):
'''
evaludate x on B-spline bases
Args:
-----
x : 2D torch.tensor
inputs, shape (number of splines, number of samples)
grid : 2D torch.tensor
grids, shape (number of splines, number of grid points)
k : int
the piecewise polynomial order of splines.
extend : bool
If True, k points are extended on both ends. If False, no extension (zero boundary condition). Default: True
device : str
devicde
Returns:
--------
spline values : 3D torch.tensor
shape (batch, in_dim, G+k). G: the number of grid intervals, k: spline order.
Example
-------
>>> from kan.spline import B_batch
>>> x = torch.rand(100,2)
>>> grid = torch.linspace(-1,1,steps=11)[None, :].expand(2, 11)
>>> B_batch(x, grid, k=3).shape
'''
x = x.unsqueeze(dim=2)
grid = grid.unsqueeze(dim=0)
if k == 0:
value = (x >= grid[:, :, :-1]) * (x < grid[:, :, 1:])
else:
B_km1 = B_batch(x[:,:,0], grid=grid[0], k=k - 1)
value = (x - grid[:, :, :-(k + 1)]) / (grid[:, :, k:-1] - grid[:, :, :-(k + 1)]) * B_km1[:, :, :-1] + (
grid[:, :, k + 1:] - x) / (grid[:, :, k + 1:] - grid[:, :, 1:(-k)]) * B_km1[:, :, 1:]
# in case grid is degenerate
value = torch.nan_to_num(value)
return value
功能:计算给定输入在B样条基函数上的值。
参数:
x
:一个二维的torch张量,表示输入,形状为(样条数量,样本数量)。grid
:一个二维的torch张量,表示网格,形状为(样条数量,网格点数量)。k
:整数,表示样条的分段多项式阶数。extend
:布尔值,指示是否在两端扩展k个点。默认为True。device
:字符串,表示设备(如’cpu’或’cuda’)。
返回值:返回一个三维的torch张量,表示在B样条基函数上的值,形状为(批量,输入维度,G+k),其中G是网格区间的数量,k是样条阶数。
2. coef2curve
函数
def coef2curve(x_eval, grid, coef, k, device="cpu"):
'''
converting B-spline coefficients to B-spline curves. Evaluate x on B-spline curves (summing up B_batch results over B-spline basis).
Args:
-----
x_eval : 2D torch.tensor
shape (batch, in_dim)
grid : 2D torch.tensor
shape (in_dim, G+2k). G: the number of grid intervals; k: spline order.
coef : 3D torch.tensor
shape (in_dim, out_dim, G+k)
k : int
the piecewise polynomial order of splines.
device : str
devicde
Returns:
--------
y_eval : 3D torch.tensor
shape (number of samples, in_dim, out_dim)
'''
b_splines = B_batch(x_eval, grid, k=k)
y_eval = torch.einsum('ijk,jlk->ijl', b_splines, coef.to(b_splines.device))
return y_eval
功能:将B样条系数转换为B样条曲线,并在给定的输入上评估。
参数:
x_eval
:一个二维的torch张量,形状为(批量,输入维度)。grid
:一个二维的torch张量,形状为(输入维度,G+2k)。coef
:一个三维的torch张量,形状为(输入维度,输出维度,G+k)。k
:整数,表示样条的分段多项式阶数。device
:字符串,表示设备。
返回值:返回一个三维的torch张量,表示在样条曲线上的评估结果,形状为(样本数量,输入维度,输出维度)。
3. curve2coef
函数
def curve2coef(x_eval, y_eval, grid, k, lamb=1e-8):
'''
converting B-spline curves to B-spline coefficients using least squares.
Args:
-----
x_eval : 2D torch.tensor
shape (in_dim, out_dim, number of samples)
y_eval : 2D torch.tensor
shape (in_dim, out_dim, number of samples)
grid : 2D torch.tensor
shape (in_dim, grid+2*k)
k : int
spline order
lamb : float
regularized least square lambda
Returns:
--------
coef : 3D torch.tensor
shape (in_dim, out_dim, G+k)
'''
batch = x_eval.shape[0]
in_dim = x_eval.shape[1]
out_dim = y_eval.shape[2]
n_coef = grid.shape[1] - k - 1
mat = B_batch(x_eval, grid, k)
mat = mat.permute(1,0,2)[:,None,:,:].expand(in_dim, out_dim, batch, n_coef)
y_eval = y_eval.permute(1,2,0).unsqueeze(dim=3)
device = mat.device
#coef = torch.linalg.lstsq(mat, y_eval,
#driver='gelsy' if device == 'cpu' else 'gels').solution[:,:,:,0]
XtX = torch.einsum('ijmn,ijnp->ijmp', mat.permute(0,1,3,2), mat)
Xty = torch.einsum('ijmn,ijnp->ijmp', mat.permute(0,1,3,2), y_eval)
n1, n2, n = XtX.shape[0], XtX.shape[1], XtX.shape[2]
identity = torch.eye(n,n)[None, None, :, :].expand(n1, n2, n, n).to(device)
A = XtX + lamb * identity
B = Xty
coef = (A.pinverse() @ B)[:,:,:,0]
return coef
功能:通过最小二乘法将B样条曲线转换为B样条系数。
参数:
x_eval
:一个二维的torch张量,形状为(输入维度,输出维度,样本数量)。y_eval
:一个二维的torch张量,形状为(输入维度,输出维度,样本数量)。grid
:一个二维的torch张量,形状为(输入维度,网格+2*k)。k
:整数,表示样条阶数。lamb
:浮点数,正则化最小二乘法的λ值。
返回值:返回一个三维的torch张量,表示B样条系数,形状为(输入维度,输出维度,G+k)。
4. extend_grid
函数
def extend_grid(grid, k_extend=0):
'''
extend grid
'''
h = (grid[:, [-1]] - grid[:, [0]]) / (grid.shape[1] - 1)
for i in range(k_extend):
grid = torch.cat([grid[:, [0]] - h, grid], dim=1)
grid = torch.cat([grid, grid[:, [-1]] + h], dim=1)
return grid
功能:扩展网格。
参数:
grid
:一个二维的torch张量,表示原始网格。k_extend
:整数,表示要扩展的点数。
返回值:返回一个扩展后的网格。
B. 完整代码
import torch
def B_batch(x, grid, k=0, extend=True, device='cpu'):
'''
evaludate x on B-spline bases
Args:
-----
x : 2D torch.tensor
inputs, shape (number of splines, number of samples)
grid : 2D torch.tensor
grids, shape (number of splines, number of grid points)
k : int
the piecewise polynomial order of splines.
extend : bool
If True, k points are extended on both ends. If False, no extension (zero boundary condition). Default: True
device : str
devicde
Returns:
--------
spline values : 3D torch.tensor
shape (batch, in_dim, G+k). G: the number of grid intervals, k: spline order.
Example
-------
>>> from kan.spline import B_batch
>>> x = torch.rand(100,2)
>>> grid = torch.linspace(-1,1,steps=11)[None, :].expand(2, 11)
>>> B_batch(x, grid, k=3).shape
'''
x = x.unsqueeze(dim=2)
grid = grid.unsqueeze(dim=0)
if k == 0:
value = (x >= grid[:, :, :-1]) * (x < grid[:, :, 1:])
else:
B_km1 = B_batch(x[:,:,0], grid=grid[0], k=k - 1)
value = (x - grid[:, :, :-(k + 1)]) / (grid[:, :, k:-1] - grid[:, :, :-(k + 1)]) * B_km1[:, :, :-1] + (
grid[:, :, k + 1:] - x) / (grid[:, :, k + 1:] - grid[:, :, 1:(-k)]) * B_km1[:, :, 1:]
# in case grid is degenerate
value = torch.nan_to_num(value)
return value
def coef2curve(x_eval, grid, coef, k, device="cpu"):
'''
converting B-spline coefficients to B-spline curves. Evaluate x on B-spline curves (summing up B_batch results over B-spline basis).
Args:
-----
x_eval : 2D torch.tensor
shape (batch, in_dim)
grid : 2D torch.tensor
shape (in_dim, G+2k). G: the number of grid intervals; k: spline order.
coef : 3D torch.tensor
shape (in_dim, out_dim, G+k)
k : int
the piecewise polynomial order of splines.
device : str
devicde
Returns:
--------
y_eval : 3D torch.tensor
shape (number of samples, in_dim, out_dim)
'''
b_splines = B_batch(x_eval, grid, k=k)
y_eval = torch.einsum('ijk,jlk->ijl', b_splines, coef.to(b_splines.device))
return y_eval
def curve2coef(x_eval, y_eval, grid, k, lamb=1e-8):
'''
converting B-spline curves to B-spline coefficients using least squares.
Args:
-----
x_eval : 2D torch.tensor
shape (in_dim, out_dim, number of samples)
y_eval : 2D torch.tensor
shape (in_dim, out_dim, number of samples)
grid : 2D torch.tensor
shape (in_dim, grid+2*k)
k : int
spline order
lamb : float
regularized least square lambda
Returns:
--------
coef : 3D torch.tensor
shape (in_dim, out_dim, G+k)
'''
batch = x_eval.shape[0]
in_dim = x_eval.shape[1]
out_dim = y_eval.shape[2]
n_coef = grid.shape[1] - k - 1
mat = B_batch(x_eval, grid, k)
mat = mat.permute(1,0,2)[:,None,:,:].expand(in_dim, out_dim, batch, n_coef)
y_eval = y_eval.permute(1,2,0).unsqueeze(dim=3)
device = mat.device
#coef = torch.linalg.lstsq(mat, y_eval,
#driver='gelsy' if device == 'cpu' else 'gels').solution[:,:,:,0]
XtX = torch.einsum('ijmn,ijnp->ijmp', mat.permute(0,1,3,2), mat)
Xty = torch.einsum('ijmn,ijnp->ijmp', mat.permute(0,1,3,2), y_eval)
n1, n2, n = XtX.shape[0], XtX.shape[1], XtX.shape[2]
identity = torch.eye(n,n)[None, None, :, :].expand(n1, n2, n, n).to(device)
A = XtX + lamb * identity
B = Xty
coef = (A.pinverse() @ B)[:,:,:,0]
return coef
def extend_grid(grid, k_extend=0):
'''
extend grid
'''
h = (grid[:, [-1]] - grid[:, [0]]) / (grid.shape[1] - 1)
for i in range(k_extend):
grid = torch.cat([grid[:, [0]] - h, grid], dim=1)
grid = torch.cat([grid, grid[:, [-1]] + h], dim=1)
return grid
四、总结与思考
KAN神经网络通过融合数学定理与深度学习,为科学计算和可解释AI提供了新思路。尽管在高维应用中仍需突破,但其在低维复杂函数建模上的潜力值得关注。未来可能通过改进计算效率、扩展理论边界,成为MLP的重要补充。
1. KAN网络架构
-
关键设计:可学习的激活函数:每个网络连接的“权重”被替换为单变量函数(如样条、多项式),而非固定激活函数(如ReLU)。分层结构:输入层和隐藏层之间、隐藏层与输出层之间均通过单变量函数连接,形成多层叠加。参数效率:由于理论保证,KAN可能用更少的参数达到与MLP相当或更好的逼近效果。
-
示例结构:输入层 → 隐藏层:每个输入节点通过单变量函数
连接到隐藏节点。隐藏层 → 输出层:隐藏节点通过另一组单变量函数
组合得到输出。
2. 优势与特点
-
高逼近效率:基于数学定理,理论上能以更少参数逼近复杂函数;在低维科学计算任务(如微分方程求解)中表现优异。
-
可解释性:单变量函数可可视化,便于分析输入变量与输出的关系;网络结构直接对应函数分解过程,逻辑清晰。
-
灵活的函数学习:激活函数可自适应调整(如学习平滑或非平滑函数);支持符号公式提取(例如从数据中恢复物理定律)。
3. 挑战与局限
-
计算复杂度:单变量函数的学习(如样条参数化)可能增加训练时间和内存消耗。需要优化高阶连续函数,对硬件和算法提出更高要求。
-
泛化能力:在高维数据(如图像、文本)中的表现尚未充分验证,可能逊色于传统MLP。
-
训练难度:需设计新的优化策略,避免单变量函数的过拟合或欠拟合。
4. 应用场景
-
科学计算:求解微分方程、物理建模、化学模拟等需要高精度函数逼近的任务。
-
可解释性需求领域:医疗诊断、金融风控等需明确输入输出关系的场景。
-
符号回归:从数据中自动发现数学公式(如物理定律)。
5. 与传统MLP的对比
6. 研究进展
-
近期论文:2024年,MIT等团队提出KAN架构(如论文《KAN: Kolmogorov-Arnold Networks》),在低维任务中验证了其高效性和可解释性。
-
开源实现:已有PyTorch等框架的初步实现。
【作者声明】
本文分享的论文内容及观点均来源于《KAN: Kolmogorov-Arnold Networks》原文,旨在介绍和探讨该研究的创新成果和应用价值。作者尊重并遵循学术规范,确保内容的准确性和客观性。如有任何疑问或需要进一步的信息,请参考论文原文或联系相关作者。
【关注我们】
如果您对神经网络、群智能算法及人工智能技术感兴趣,请关注【灵犀拾荒者】,获取更多前沿技术文章、实战案例及技术分享!