【KAN】KAN神经网络学习训练营(12)——utils.py

一、引言

        KAN神经网络(Kolmogorov–Arnold Networks)是一种基于Kolmogorov-Arnold表示定理的新型神经网络架构。该定理指出,任何多元连续函数都可以表示为有限个单变量函数的组合。与传统多层感知机(MLP)不同,KAN通过可学习的激活函数和结构化网络设计,在函数逼近效率和可解释性上展现出潜力。


二、技术与原理简介

        1.Kolmogorov-Arnold 表示定理

         Kolmogorov-Arnold 表示定理指出,如果 是有界域上的多元连续函数,那么它可以写为单个变量的连续函数的有限组合,以及加法的二进制运算。更具体地说,对于 光滑ff:[0,1]^{^{n}}\rightarrow \mathbb{R}

f \left( x \right)=f \left( x_{1}, \cdots,x_{n} \right)= \sum_{q=1}^{2n+1} \Phi_{q} \left( \sum_{p=1}^{n} \phi_{q,p} \left( x_{p} \right) \right)

        其中 和 。从某种意义上说,他们表明唯一真正的多元函数是加法,因为所有其他函数都可以使用单变量函数和 sum 来编写。然而,这个 2 层宽度 - Kolmogorov-Arnold 表示可能不是平滑的由于其表达能力有限。我们通过以下方式增强它的表达能力将其推广到任意深度和宽度。\boldsymbol{\phi_{q,p}:[0,1]\to\mathbb{R}}\boldsymbol{\Phi_{q}:\mathbb{R}\to\mathbb{R}(2n+1)}

        2.Kolmogorov-Arnold 网络 (KAN)

        Kolmogorov-Arnold 表示可以写成矩阵形式

f(x)=\mathbf{\Phi_{out}}\mathsf{o}\mathbf{\Phi_{in}}\mathsf{o}{}x

其中

\mathbf{\Phi}_{\mathrm{in}}=\begin{pmatrix}\phi_{1,1}(\cdot)&\cdots&\phi_{1,n }(\cdot)\\ \vdots&&\vdots\\ \phi_{2n+1,1}(\cdot)&\cdots&\phi_{2n+1,n}(\cdot)\end{pmatrix}

\quad\mathbf{ \Phi}_{\mathrm{out}}=\left(\Phi_{1}(\cdot)\quad\cdots\quad\Phi_{2n+1}(\cdot)\right)

        我们注意到 和 都是以下函数矩阵(包含输入和输出)的特例,我们称之为 Kolmogorov-Arnold 层:\mathbf{\Phi_{in}} \mathbf{\Phi_{out}} \mathbf{\Phi_{n_{in}n_{out}}}

其中\boldsymbol{n_{\text{in}}=n,n_{\text{out}}=2n+1\Phi_{\text{out}}n_{\text{in}}=2n+1,n_{\text{out}}=1}

        定义层后,我们可以构造一个 Kolmogorov-Arnold 网络只需堆叠层!假设我们有层,层的形状为 。那么整个网络是Ll^{th} \Phi_{l} \left( n_{l+1},n_{l} \right)

\mathbf{KAN(x)}=\mathbf{\Phi_{L-1}}\circ\cdots\circ\mathbf{\Phi_{1}}\circ \mathbf{\Phi_{0}}\circ\mathbf{x}

        相反,多层感知器由线性层和非线错:\mathbf{W}_{l^{\sigma}}

\text{MLP}(\mathbf{x})=\mathbf{W}_{\textit{L-1}}\circ\sigma\circ\cdots\circ \mathbf{W}_{1}\circ\sigma\circ\mathbf{W}_{0}\circ\mathbf{x}

        KAN 可以很容易地可视化。(1) KAN 只是 KAN 层的堆栈。(2) 每个 KAN 层都可以可视化为一个全连接层,每个边缘上都有一个1D 函数。


三、代码详解

        整体来看,这部分代码实现了以下功能:

  • 符号函数库的构建与扩展:不仅定义了基本的符号函数,还提供了奇异性保护机制,并允许用户动态添加新的符号函数。

  • 数据集生成与预处理:既支持基于符号公式生成合成数据,也能根据现有数据进行训练/测试划分与归一化。

  • 参数拟合:通过网格搜索和线性回归方法拟合非线性函数中的仿射参数。

  • 自动求导工具:实现了批量计算雅可比和 Hessian 的方法,为模型敏感性分析和二阶优化提供支持。

  • 模型参数管理:提供了将模型参数展平和动态加载的工具,方便在求导时更新模型状态。

        这些功能为基于符号激活函数构建的神经网络模型提供了完整的数据、拟合、优化和解释工具,极大地增强了模型的灵活性和可解释性。

        A. 代码详解

1. 模块导入

import numpy as np
import torch
from sklearn.linear_model import LinearRegression
import sympy
import yaml
from sympy.utilities.lambdify import lambdify
import re
  • numpytorch:用于数值计算和张量操作,其中 torch 用于构建深度学习模型。

  • LinearRegression:来自 scikit-learn,用于后续在参数拟合中求解线性回归问题(拟合仿射系数)。

  • sympylambdify:用于符号数学和将符号表达式转换为可执行的数值函数。

  • yaml:可能用于配置文件读写。

  • re:正则表达式模块,用于字符串处理,在解析模型参数名称时会用到。


2. 奇异性保护函数

定义了一系列 lambda 表达式,旨在在计算过程中避免因除零或指数过大/过小而引发数值不稳定(奇异性问题)。

例如:

  • f_inv:对 1/x 进行保护。利用阈值 x_th = 1/y_th,在 |x| 小于阈值时返回一个平滑的近似值,否则采用 1/x。其返回值是一个元组,第一个元素为 x_th,第二个为计算结果。

  • f_inv2, f_inv3, f_inv4, f_inv5:类似 f_inv,但分别用于 1/x², 1/x³, 1/x⁴ 和 1/x⁵,对应不同的幂次处理方式。

  • f_sqrt:计算平方根,当输入较小时(小于阈值)返回线性近似,避免直接开根号可能出现数值问题。

  • f_power1d5f_invsqrtf_logf_tanf_arctanhf_arcsinf_arccosf_exp:分别对应 1.5 次方、1/√x、对数、正切、反双曲正切、反正弦、反余弦、指数函数,都采用了类似的策略,根据输入值大小在不同区域采用不同的计算方法或平滑近似,确保数值稳定。

这种设计方式使得在网络中使用这些函数时,可以避免在训练过程中由于输入值接近奇异点(例如 x 过小或过大)而导致的数值溢出或 NaN 值。


3. SYMBOLIC_LIB 符号函数库

SYMBOLIC_LIB = {
    'x': (lambda x: x, lambda x: x, 1, lambda x, y_th: ((), x)),
    'x^2': (lambda x: x**2, lambda x: x**2, 2, lambda x, y_th: ((), x**2)),
    ...
    'sin': (lambda x: torch.sin(x), lambda x: sympy.sin(x), 2, lambda x, y_th: ((), torch.sin(x))),
    'cos': (lambda x: torch.cos(x), lambda x: sympy.cos(x), 2, lambda x, y_th: ((), torch.cos(x))),
    ...
    'gaussian': (lambda x: torch.exp(-x**2), lambda x: sympy.exp(-x**2), 3, lambda x, y_th: ((), torch.exp(-x**2))),
    ...
}
  • 该字典将每个符号函数(例如 'x', 'x^2', '1/x', 'sin' 等)映射到一个四元组,其中包含:

    • torch 版本:用于数值计算(前向传播等)。

    • sympy 版本:用于符号运算和表达式分析。

    • 一个常数(例如 1, 2, 3 等):可能用于指示该函数的某种属性(例如复杂度或用于拟合时的权重)。

    • 奇异性保护版本:当需要避免奇异性时调用的函数(如 f_inv、f_log 等)。

这种设计允许在模型中同时使用数值计算和符号推理,便于解释和调试。


4. create_dataset 函数

def create_dataset(f, 
                   n_var=2, 
                   f_mode = 'col',
                   ranges = [-1,1],
                   train_num=1000, 
                   test_num=1000,
                   normalize_input=False,
                   normalize_label=False,
                   device='cpu',
                   seed=0):

作用与流程

  • 目的:利用给定的符号公式 f(可以是任意函数)生成合成数据集。数据集包括训练集和测试集。

  • 参数说明

    • n_var:输入变量的个数。

    • ranges:每个输入变量的取值范围(如果给定单一区间,则所有变量使用相同区间)。

    • train_numtest_num:分别为训练和测试样本数量。

    • normalize_inputnormalize_label:是否对输入和输出进行归一化处理。

    • f_mode:指明函数 f 的输入格式,'col' 表示数据按列输入,'row' 表示转置后输入。

  • 过程

    1. 根据随机种子生成训练和测试输入数据,每个输入按均匀分布从指定区间中采样。

    2. 根据 f_mode 计算标签。

    3. 如果标签或输入数据为一维,则增加维度。

    4. 可选的归一化处理。

    5. 最后将数据移动到指定的设备上并以字典形式返回。


5. fit_params 函数

def fit_params(x, y, fun, a_range=(-10,10), b_range=(-10,10), grid_number=101, iteration=3, verbose=True, device='cpu'):

作用与流程

  • 目的:给定一组输入 x 与目标输出 y,拟合一个形式为

    y≈c⋅f(a⋅x+b)+dy \approx c \cdot f(a \cdot x + b) + dy≈c⋅f(a⋅x+b)+d

    的模型,从而寻找最优的仿射参数 [a, b, c, d]。其中 fun 为符号函数(例如 torch.sin)。

  • 过程

    1. 利用网格搜索在 a 和 b 的范围内寻找使得决定系数 r2r^2r2 最高的组合。通过多次迭代(zoom in)来缩小搜索范围。

    2. 对每个网格点计算经过函数转换后的输出,并利用公式计算 r2r^2r2 值。

    3. 根据最佳 a、b 值,再利用线性回归(LinearRegression)拟合 c 与 d。

    4. 返回最佳参数向量 [a, b, c, d] 和对应的 r2r^2r2 值。

这种方法结合网格搜索和线性回归,既能全局搜索非线性部分的参数,又能利用线性回归精细调整后两项的权重。


6. sparse_mask 函数

def sparse_mask(in_dim, out_dim):

作用

  • 构造一个用于网络连接的稀疏掩码矩阵。

  • 过程

    • 为输入和输出节点分配“坐标”(按均匀分布确定中心位置)。

    • 计算输入与输出节点坐标之间的距离矩阵。

    • 分别为每个输入和输出找到最近的对方节点,并将这些连接记录下来。

    • 根据连接记录构造一个二值矩阵,表示哪些输入输出节点之间存在连接。

这种稀疏连接方式有助于设计轻量级网络结构或在剪枝时保持某种拓扑结构。


7. add_symbolic 函数

def add_symbolic(name, fun, c=1, fun_singularity=None):

作用

  • 动态地将新的符号函数加入 SYMBOLIC_LIB 库中。

  • 过程

    • 使用 exec 动态创建一个 sympy 函数,名称为给定的 name

    • 如果没有提供奇异性保护版本,则默认使用 fun

    • 更新 SYMBOLIC_LIB 字典,将新函数与其 torch 实现、sympy 实现、复杂度因子 c 及奇异性函数记录下来。

这种设计允许用户扩展符号库,以便在构建模型时使用更多自定义的函数。


8. ex_round 函数

def ex_round(ex1, n_digit):

作用

  • 对给定的 sympy 表达式 ex1 中的浮点数进行四舍五入处理,使得表达式中所有数字均保留 n_digit 位小数。

  • 过程

    • 遍历表达式的所有子表达式,对于 sympy.Float 类型进行 round 替换。

有助于在输出符号表达式时保持表达式的简洁性和可读性。


9. augment_input 函数

def augment_input(orig_vars, aux_vars, x):

作用

  • 对原始输入数据 x 增加辅助特征。辅助特征通过符号表达式 aux_vars 计算得到。

  • 过程

    • 判断输入 x 是否为 torch.Tensor(或者为包含训练/测试数据的字典)。

    • 利用 lambdify 将 sympy 表达式转换为数值函数,然后对每个辅助变量计算对应的数值,并将结果拼接到原始输入的后面。

这种输入扩增方法可用于提高模型表达能力,尤其在符号回归中,通过辅助变量的引入有助于捕捉更复杂的关系。


10. batch_jacobian 和 batch_hessian 函数

batch_jacobian
def batch_jacobian(func, x, create_graph=False, mode='scalar'):
  • 作用:利用 PyTorch 的自动求导功能,计算给定函数 func 对输入 x 的雅可比矩阵。

  • 实现

    • 先将函数求和(或按要求变换),再调用 torch.autograd.functional.jacobian 计算雅可比。

    • 参数 mode 决定输出形式是标量模式还是向量模式。

batch_hessian
def batch_hessian(model, x, create_graph=False):
  • 作用:计算函数关于输入的 Hessian 矩阵(即二阶导数)。

  • 实现

    • 先计算雅可比,再对雅可比求导得到 Hessian,最终调整维度返回。

这两个函数在模型敏感性分析、解释性方法和二阶优化中都有重要应用。


11. create_dataset_from_data 函数

def create_dataset_from_data(inputs, labels, train_ratio=0.8, device='cpu'):

作用

  • 从已有数据(输入和标签)中,根据给定的训练比例划分出训练集和测试集,并将数据移动到指定设备上。

  • 实现

    • 利用 numpy 随机采样选取训练样本,其余作为测试样本。

适用于当用户已有数据而非使用合成公式生成数据时的场景。


12. get_derivative 函数

def get_derivative(model, inputs, labels, derivative='hessian', loss_mode='pred', reg_metric='w', lamb=0., lamb_l1=1., lamb_entropy=0.):

作用

  • 计算模型在给定数据下的损失函数对模型参数的导数(雅可比或 Hessian),用于分析模型对参数变化的敏感度或用于二阶优化。

关键步骤

  1. 构建映射(get_mapping)

    • 利用正则表达式解析模型状态字典中的参数名称,并构造字符串形式的访问路径,用于后续动态地将参数加载到模型中。

  2. 模型参数展平

    • 通过 model2param 函数将所有参数拼接成一个一维张量,方便求导。

  3. 构造可微分的损失函数

    • 根据 loss_mode(预测误差、正则项或二者的组合)构造一个函数 param2loss_fun,该函数输入为展平参数向量,然后利用动态加载方式将其写回模型,再计算损失值。

    • 这里使用 exec 动态更新模型参数,注意该步骤可能存在非微分性,但目的是为了在外部求导时将参数变化与损失函数关联起来。

  4. 求导

    • 根据用户选择的 derivative('jacobian' 或 'hessian'),调用前面定义的 batch_jacobian 或 batch_hessian 进行计算,并返回结果。

这种方法能够得到损失函数对参数的敏感性信息,对于模型解释、调优以及正则化分析具有重要意义。


13. model2param 函数

def model2param(model):
  • 作用:将模型中所有参数展平(flatten)成一个一维向量,便于在 get_derivative 中进行梯度计算或二阶导数计算。

  • 实现

    • 遍历模型所有参数,并用 torch.cat 拼接成一个大张量。

        B. 完整代码

import numpy as np
import torch
from sklearn.linear_model import LinearRegression
import sympy
import yaml
from sympy.utilities.lambdify import lambdify
import re

# sigmoid = sympy.Function('sigmoid')
# name: (torch implementation, sympy implementation)

# singularity protection functions
f_inv = lambda x, y_th: ((x_th := 1/y_th), y_th/x_th*x * (torch.abs(x) < x_th) + torch.nan_to_num(1/x) * (torch.abs(x) >= x_th))
f_inv2 = lambda x, y_th: ((x_th := 1/y_th**(1/2)), y_th * (torch.abs(x) < x_th) + torch.nan_to_num(1/x**2) * (torch.abs(x) >= x_th))
f_inv3 = lambda x, y_th: ((x_th := 1/y_th**(1/3)), y_th/x_th*x * (torch.abs(x) < x_th) + torch.nan_to_num(1/x**3) * (torch.abs(x) >= x_th))
f_inv4 = lambda x, y_th: ((x_th := 1/y_th**(1/4)), y_th * (torch.abs(x) < x_th) + torch.nan_to_num(1/x**4) * (torch.abs(x) >= x_th))
f_inv5 = lambda x, y_th: ((x_th := 1/y_th**(1/5)), y_th/x_th*x * (torch.abs(x) < x_th) + torch.nan_to_num(1/x**5) * (torch.abs(x) >= x_th))
f_sqrt = lambda x, y_th: ((x_th := 1/y_th**2), x_th/y_th*x * (torch.abs(x) < x_th) + torch.nan_to_num(torch.sqrt(torch.abs(x))*torch.sign(x)) * (torch.abs(x) >= x_th))
f_power1d5 = lambda x, y_th: torch.abs(x)**1.5
f_invsqrt = lambda x, y_th: ((x_th := 1/y_th**2), y_th * (torch.abs(x) < x_th) + torch.nan_to_num(1/torch.sqrt(torch.abs(x))) * (torch.abs(x) >= x_th))
f_log = lambda x, y_th: ((x_th := torch.e**(-y_th)), - y_th * (torch.abs(x) < x_th) + torch.nan_to_num(torch.log(torch.abs(x))) * (torch.abs(x) >= x_th))
f_tan = lambda x, y_th: ((clip := x % torch.pi), (delta := torch.pi/2-torch.arctan(y_th)), - y_th/delta * (clip - torch.pi/2) * (torch.abs(clip - torch.pi/2) < delta) + torch.nan_to_num(torch.tan(clip)) * (torch.abs(clip - torch.pi/2) >= delta))
f_arctanh = lambda x, y_th: ((delta := 1-torch.tanh(y_th) + 1e-4), y_th * torch.sign(x) * (torch.abs(x) > 1 - delta) + torch.nan_to_num(torch.arctanh(x)) * (torch.abs(x) <= 1 - delta))
f_arcsin = lambda x, y_th: ((), torch.pi/2 * torch.sign(x) * (torch.abs(x) > 1) + torch.nan_to_num(torch.arcsin(x)) * (torch.abs(x) <= 1))
f_arccos = lambda x, y_th: ((), torch.pi/2 * (1-torch.sign(x)) * (torch.abs(x) > 1) + torch.nan_to_num(torch.arccos(x)) * (torch.abs(x) <= 1))
f_exp = lambda x, y_th: ((x_th := torch.log(y_th)), y_th * (x > x_th) + torch.exp(x) * (x <= x_th))

SYMBOLIC_LIB = {'x': (lambda x: x, lambda x: x, 1, lambda x, y_th: ((), x)),
                 'x^2': (lambda x: x**2, lambda x: x**2, 2, lambda x, y_th: ((), x**2)),
                 'x^3': (lambda x: x**3, lambda x: x**3, 3, lambda x, y_th: ((), x**3)),
                 'x^4': (lambda x: x**4, lambda x: x**4, 3, lambda x, y_th: ((), x**4)),
                 'x^5': (lambda x: x**5, lambda x: x**5, 3, lambda x, y_th: ((), x**5)),
                 '1/x': (lambda x: 1/x, lambda x: 1/x, 2, f_inv),
                 '1/x^2': (lambda x: 1/x**2, lambda x: 1/x**2, 2, f_inv2),
                 '1/x^3': (lambda x: 1/x**3, lambda x: 1/x**3, 3, f_inv3),
                 '1/x^4': (lambda x: 1/x**4, lambda x: 1/x**4, 4, f_inv4),
                 '1/x^5': (lambda x: 1/x**5, lambda x: 1/x**5, 5, f_inv5),
                 'sqrt': (lambda x: torch.sqrt(x), lambda x: sympy.sqrt(x), 2, f_sqrt),
                 'x^0.5': (lambda x: torch.sqrt(x), lambda x: sympy.sqrt(x), 2, f_sqrt),
                 'x^1.5': (lambda x: torch.sqrt(x)**3, lambda x: sympy.sqrt(x)**3, 4, f_power1d5),
                 '1/sqrt(x)': (lambda x: 1/torch.sqrt(x), lambda x: 1/sympy.sqrt(x), 2, f_invsqrt),
                 '1/x^0.5': (lambda x: 1/torch.sqrt(x), lambda x: 1/sympy.sqrt(x), 2, f_invsqrt),
                 'exp': (lambda x: torch.exp(x), lambda x: sympy.exp(x), 2, f_exp),
                 'log': (lambda x: torch.log(x), lambda x: sympy.log(x), 2, f_log),
                 'abs': (lambda x: torch.abs(x), lambda x: sympy.Abs(x), 3, lambda x, y_th: ((), torch.abs(x))),
                 'sin': (lambda x: torch.sin(x), lambda x: sympy.sin(x), 2, lambda x, y_th: ((), torch.sin(x))),
                 'cos': (lambda x: torch.cos(x), lambda x: sympy.cos(x), 2, lambda x, y_th: ((), torch.cos(x))),
                 'tan': (lambda x: torch.tan(x), lambda x: sympy.tan(x), 3, f_tan),
                 'tanh': (lambda x: torch.tanh(x), lambda x: sympy.tanh(x), 3, lambda x, y_th: ((), torch.tanh(x))),
                 'sgn': (lambda x: torch.sign(x), lambda x: sympy.sign(x), 3, lambda x, y_th: ((), torch.sign(x))),
                 'arcsin': (lambda x: torch.arcsin(x), lambda x: sympy.asin(x), 4, f_arcsin),
                 'arccos': (lambda x: torch.arccos(x), lambda x: sympy.acos(x), 4, f_arccos),
                 'arctan': (lambda x: torch.arctan(x), lambda x: sympy.atan(x), 4, lambda x, y_th: ((), torch.arctan(x))),
                 'arctanh': (lambda x: torch.arctanh(x), lambda x: sympy.atanh(x), 4, f_arctanh),
                 '0': (lambda x: x*0, lambda x: x*0, 0, lambda x, y_th: ((), x*0)),
                 'gaussian': (lambda x: torch.exp(-x**2), lambda x: sympy.exp(-x**2), 3, lambda x, y_th: ((), torch.exp(-x**2))),
                 #'cosh': (lambda x: torch.cosh(x), lambda x: sympy.cosh(x), 5),
                 #'sigmoid': (lambda x: torch.sigmoid(x), sympy.Function('sigmoid'), 4),
                 #'relu': (lambda x: torch.relu(x), relu),
}

def create_dataset(f, 
                   n_var=2, 
                   f_mode = 'col',
                   ranges = [-1,1],
                   train_num=1000, 
                   test_num=1000,
                   normalize_input=False,
                   normalize_label=False,
                   device='cpu',
                   seed=0):
    '''
    create dataset
    
    Args:
    -----
        f : function
            the symbolic formula used to create the synthetic dataset
        ranges : list or np.array; shape (2,) or (n_var, 2)
            the range of input variables. Default: [-1,1].
        train_num : int
            the number of training samples. Default: 1000.
        test_num : int
            the number of test samples. Default: 1000.
        normalize_input : bool
            If True, apply normalization to inputs. Default: False.
        normalize_label : bool
            If True, apply normalization to labels. Default: False.
        device : str
            device. Default: 'cpu'.
        seed : int
            random seed. Default: 0.
        
    Returns:
    --------
        dataset : dic
            Train/test inputs/labels are dataset['train_input'], dataset['train_label'],
                        dataset['test_input'], dataset['test_label']
         
    Example
    -------
    >>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
    >>> dataset = create_dataset(f, n_var=2, train_num=100)
    >>> dataset['train_input'].shape
    torch.Size([100, 2])
    '''

    np.random.seed(seed)
    torch.manual_seed(seed)

    if len(np.array(ranges).shape) == 1:
        ranges = np.array(ranges * n_var).reshape(n_var,2)
    else:
        ranges = np.array(ranges)
        
    
    train_input = torch.zeros(train_num, n_var)
    test_input = torch.zeros(test_num, n_var)
    for i in range(n_var):
        train_input[:,i] = torch.rand(train_num,)*(ranges[i,1]-ranges[i,0])+ranges[i,0]
        test_input[:,i] = torch.rand(test_num,)*(ranges[i,1]-ranges[i,0])+ranges[i,0]
                
    if f_mode == 'col':
        train_label = f(train_input)
        test_label = f(test_input)
    elif f_mode == 'row':
        train_label = f(train_input.T)
        test_label = f(test_input.T)
    else:
        print(f'f_mode {f_mode} not recognized')
        
    # if has only 1 dimension
    if len(train_label.shape) == 1:
        train_label = train_label.unsqueeze(dim=1)
        test_label = test_label.unsqueeze(dim=1)
        
    def normalize(data, mean, std):
            return (data-mean)/std
            
    if normalize_input == True:
        mean_input = torch.mean(train_input, dim=0, keepdim=True)
        std_input = torch.std(train_input, dim=0, keepdim=True)
        train_input = normalize(train_input, mean_input, std_input)
        test_input = normalize(test_input, mean_input, std_input)
        
    if normalize_label == True:
        mean_label = torch.mean(train_label, dim=0, keepdim=True)
        std_label = torch.std(train_label, dim=0, keepdim=True)
        train_label = normalize(train_label, mean_label, std_label)
        test_label = normalize(test_label, mean_label, std_label)

    dataset = {}
    dataset['train_input'] = train_input.to(device)
    dataset['test_input'] = test_input.to(device)

    dataset['train_label'] = train_label.to(device)
    dataset['test_label'] = test_label.to(device)

    return dataset



def fit_params(x, y, fun, a_range=(-10,10), b_range=(-10,10), grid_number=101, iteration=3, verbose=True, device='cpu'):
    '''
    fit a, b, c, d such that
    
    .. math::
        |y-(cf(ax+b)+d)|^2
        
    is minimized. Both x and y are 1D array. Sweep a and b, find the best fitted model.
    
    Args:
    -----
        x : 1D array
            x values
        y : 1D array
            y values
        fun : function
            symbolic function
        a_range : tuple
            sweeping range of a
        b_range : tuple
            sweeping range of b
        grid_num : int
            number of steps along a and b
        iteration : int
            number of zooming in
        verbose : bool
            print extra information if True
        device : str
            device
        
    Returns:
    --------
        a_best : float
            best fitted a
        b_best : float
            best fitted b
        c_best : float
            best fitted c
        d_best : float
            best fitted d
        r2_best : float
            best r2 (coefficient of determination)
    
    Example
    -------
    >>> num = 100
    >>> x = torch.linspace(-1,1,steps=num)
    >>> noises = torch.normal(0,1,(num,)) * 0.02
    >>> y = 5.0*torch.sin(3.0*x + 2.0) + 0.7 + noises
    >>> fit_params(x, y, torch.sin)
    r2 is 0.9999727010726929
    (tensor([2.9982, 1.9996, 5.0053, 0.7011]), tensor(1.0000))
    '''
    # fit a, b, c, d such that y=c*fun(a*x+b)+d; both x and y are 1D array.
    # sweep a and b, choose the best fitted model   
    for _ in range(iteration):
        a_ = torch.linspace(a_range[0], a_range[1], steps=grid_number, device=device)
        b_ = torch.linspace(b_range[0], b_range[1], steps=grid_number, device=device)
        a_grid, b_grid = torch.meshgrid(a_, b_, indexing='ij')
        post_fun = fun(a_grid[None,:,:] * x[:,None,None] + b_grid[None,:,:])
        x_mean = torch.mean(post_fun, dim=[0], keepdim=True)
        y_mean = torch.mean(y, dim=[0], keepdim=True)
        numerator = torch.sum((post_fun - x_mean)*(y-y_mean)[:,None,None], dim=0)**2
        denominator = torch.sum((post_fun - x_mean)**2, dim=0)*torch.sum((y - y_mean)[:,None,None]**2, dim=0)
        r2 = numerator/(denominator+1e-4)
        r2 = torch.nan_to_num(r2)
        
        
        best_id = torch.argmax(r2)
        a_id, b_id = torch.div(best_id, grid_number, rounding_mode='floor'), best_id % grid_number
        
        
        if a_id == 0 or a_id == grid_number - 1 or b_id == 0 or b_id == grid_number - 1:
            if _ == 0 and verbose==True:
                print('Best value at boundary.')
            if a_id == 0:
                a_range = [a_[0], a_[1]]
            if a_id == grid_number - 1:
                a_range = [a_[-2], a_[-1]]
            if b_id == 0:
                b_range = [b_[0], b_[1]]
            if b_id == grid_number - 1:
                b_range = [b_[-2], b_[-1]]
            
        else:
            a_range = [a_[a_id-1], a_[a_id+1]]
            b_range = [b_[b_id-1], b_[b_id+1]]
            
    a_best = a_[a_id]
    b_best = b_[b_id]
    post_fun = fun(a_best * x + b_best)
    r2_best = r2[a_id, b_id]
    
    if verbose == True:
        print(f"r2 is {r2_best}")
        if r2_best < 0.9:
            print(f'r2 is not very high, please double check if you are choosing the correct symbolic function.')

    post_fun = torch.nan_to_num(post_fun)
    reg = LinearRegression().fit(post_fun[:,None].detach().cpu().numpy(), y.detach().cpu().numpy())
    c_best = torch.from_numpy(reg.coef_)[0].to(device)
    d_best = torch.from_numpy(np.array(reg.intercept_)).to(device)
    return torch.stack([a_best, b_best, c_best, d_best]), r2_best


def sparse_mask(in_dim, out_dim):
    '''
    get sparse mask
    '''
    in_coord = torch.arange(in_dim) * 1/in_dim + 1/(2*in_dim)
    out_coord = torch.arange(out_dim) * 1/out_dim + 1/(2*out_dim)

    dist_mat = torch.abs(out_coord[:,None] - in_coord[None,:])
    in_nearest = torch.argmin(dist_mat, dim=0)
    in_connection = torch.stack([torch.arange(in_dim), in_nearest]).permute(1,0)
    out_nearest = torch.argmin(dist_mat, dim=1)
    out_connection = torch.stack([out_nearest, torch.arange(out_dim)]).permute(1,0)
    all_connection = torch.cat([in_connection, out_connection], dim=0)
    mask = torch.zeros(in_dim, out_dim)
    mask[all_connection[:,0], all_connection[:,1]] = 1.
    
    return mask


def add_symbolic(name, fun, c=1, fun_singularity=None):
    '''
    add a symbolic function to library
    
    Args:
    -----
        name : str
            name of the function
        fun : fun
            torch function or lambda function
    
    Returns:
    --------
        None
    
    Example
    -------
    >>> print(SYMBOLIC_LIB['Bessel'])
    KeyError: 'Bessel'
    >>> add_symbolic('Bessel', torch.special.bessel_j0)
    >>> print(SYMBOLIC_LIB['Bessel'])
    (<built-in function special_bessel_j0>, Bessel)
    '''
    exec(f"globals()['{name}'] = sympy.Function('{name}')")
    if fun_singularity==None:
        fun_singularity = fun
    SYMBOLIC_LIB[name] = (fun, globals()[name], c, fun_singularity)
    
  
def ex_round(ex1, n_digit):
    '''
    rounding the numbers in an expression to certain floating points
    
    Args:
    -----
        ex1 : sympy expression
        n_digit : int
        
    Returns:
    --------
        ex2 : sympy expression
    
    Example
    -------
    >>> from kan.utils import *
    >>> from sympy import *
    >>> input_vars = a, b = symbols('a b')
    >>> expression = 3.14534242 * exp(sin(pi*a) + b**2) - 2.32345402
    >>> ex_round(expression, 2)
    '''
    ex2 = ex1
    for a in sympy.preorder_traversal(ex1):
        if isinstance(a, sympy.Float):
            ex2 = ex2.subs(a, round(a, n_digit))
    return ex2


def augment_input(orig_vars, aux_vars, x):
    '''
    augment inputs
    
    Args:
    -----
        orig_vars : list of sympy symbols
        aux_vars : list of auxiliary symbols
        x : inputs
        
    Returns:
    --------
        augmented inputs
    
    Example
    -------
    >>> from kan.utils import *
    >>> from sympy import *
    >>> orig_vars = a, b = symbols('a b')
    >>> aux_vars = [a + b, a * b]
    >>> x = torch.rand(100, 2)
    >>> augment_input(orig_vars, aux_vars, x).shape
    '''
    # if x is a tensor
    if isinstance(x, torch.Tensor):
        
        aux_values = torch.tensor([]).to(x.device)

        for aux_var in aux_vars:
            func = lambdify(orig_vars, aux_var,'numpy') # returns a numpy-ready function
            aux_value = torch.from_numpy(func(*[x[:,[i]].numpy() for i in range(len(orig_vars))]))
            aux_values = torch.cat([aux_values, aux_value], dim=1)
            
        x = torch.cat([aux_values, x], dim=1)

    # if x is a dataset
    elif isinstance(x, dict):
        x['train_input'] = augment_input(orig_vars, aux_vars, x['train_input'])
        x['test_input'] = augment_input(orig_vars, aux_vars, x['test_input'])
        
    return x


def batch_jacobian(func, x, create_graph=False, mode='scalar'):
    '''
    jacobian
    
    Args:
    -----
        func : function or model
        x : inputs
        create_graph : bool
        
    Returns:
    --------
        jacobian
    
    Example
    -------
    >>> from kan.utils import batch_jacobian
    >>> x = torch.normal(0,1,size=(100,2))
    >>> model = lambda x: x[:,[0]] + x[:,[1]]
    >>> batch_jacobian(model, x)
    '''
    # x in shape (Batch, Length)
    def _func_sum(x):
        return func(x).sum(dim=0)
    if mode == 'scalar':
        return torch.autograd.functional.jacobian(_func_sum, x, create_graph=create_graph)[0]
    elif mode == 'vector':
        return torch.autograd.functional.jacobian(_func_sum, x, create_graph=create_graph).permute(1,0,2)

def batch_hessian(model, x, create_graph=False):
    '''
    hessian
    
    Args:
    -----
        func : function or model
        x : inputs
        create_graph : bool
        
    Returns:
    --------
        jacobian
    
    Example
    -------
    >>> from kan.utils import batch_hessian
    >>> x = torch.normal(0,1,size=(100,2))
    >>> model = lambda x: x[:,[0]]**2 + x[:,[1]]**2
    >>> batch_hessian(model, x)
    '''
    # x in shape (Batch, Length)
    jac = lambda x: batch_jacobian(model, x, create_graph=True)
    def _jac_sum(x):
        return jac(x).sum(dim=0)
    return torch.autograd.functional.jacobian(_jac_sum, x, create_graph=create_graph).permute(1,0,2)


def create_dataset_from_data(inputs, labels, train_ratio=0.8, device='cpu'):
    '''
    create dataset from data
    
    Args:
    -----
        inputs : 2D torch.float
        labels : 2D torch.float
        train_ratio : float
            the ratio of training fraction
        device : str
        
    Returns:
    --------
        dataset (dictionary)
    
    Example
    -------
    >>> from kan.utils import create_dataset_from_data
    >>> x = torch.normal(0,1,size=(100,2))
    >>> y = torch.normal(0,1,size=(100,1))
    >>> dataset = create_dataset_from_data(x, y)
    >>> dataset['train_input'].shape
    '''
    num = inputs.shape[0]
    train_id = np.random.choice(num, int(num*train_ratio), replace=False)
    test_id = list(set(np.arange(num)) - set(train_id))
    dataset = {}
    dataset['train_input'] = inputs[train_id].detach().to(device)
    dataset['test_input'] = inputs[test_id].detach().to(device)
    dataset['train_label'] = labels[train_id].detach().to(device)
    dataset['test_label'] = labels[test_id].detach().to(device)
    
    return dataset


def get_derivative(model, inputs, labels, derivative='hessian', loss_mode='pred', reg_metric='w', lamb=0., lamb_l1=1., lamb_entropy=0.):
    '''
    compute the jacobian/hessian of loss wrt to model parameters
    
    Args:
    -----
        inputs : 2D torch.float
        labels : 2D torch.float
        derivative : str
            'jacobian' or 'hessian'
        device : str
        
    Returns:
    --------
        jacobian or hessian
    '''
    def get_mapping(model):

        mapping = {}
        name = 'model1'

        keys = list(model.state_dict().keys())
        for key in keys:

            y = re.findall(".[0-9]+", key)
            if len(y) > 0:
                y = y[0][1:]
                x = re.split(".[0-9]+", key)
                mapping[key] = name + '.' + x[0] + '[' + y + ']' + x[1]


            y = re.findall("_[0-9]+", key)
            if len(y) > 0:
                y = y[0][1:]
                x = re.split(".[0-9]+", key)
                mapping[key] = name + '.' + x[0] + '[' + y + ']'

        return mapping

    
    #model1 = copy.deepcopy(model)
    model1 = model.copy()
    mapping = get_mapping(model)
   
    # collect keys and shapes
    keys = list(model.state_dict().keys())
    shapes = []

    for params in model.parameters():
        shapes.append(params.shape)


    # turn a flattened vector to model params
    def param2statedict(p, keys, shapes):

        new_state_dict = {}

        start = 0
        n_group = len(keys)
        for i in range(n_group):
            shape = shapes[i]
            n_params = torch.prod(torch.tensor(shape))
            new_state_dict[keys[i]] = p[start:start+n_params].reshape(shape)
            start += n_params

        return new_state_dict
    
    def differentiable_load_state_dict(mapping, state_dict, model1):

        for key in keys:
            if mapping[key][-1] != ']':
                exec(f"del {mapping[key]}")
            exec(f"{mapping[key]} = state_dict[key]")
            

    # input: p, output: output
    def get_param2loss_fun(inputs, labels):

        def param2loss_fun(p):

            p = p[0]
            state_dict = param2statedict(p, keys, shapes)
            # this step is non-differentiable
            #model.load_state_dict(state_dict)
            differentiable_load_state_dict(mapping, state_dict, model1)
            if loss_mode == 'pred':
                pred_loss = torch.mean((model1(inputs) - labels)**2, dim=(0,1), keepdim=True)
                loss = pred_loss
            elif loss_mode == 'reg':
                reg_loss = model1.get_reg(reg_metric=reg_metric, lamb_l1=lamb_l1, lamb_entropy=lamb_entropy) * torch.ones(1,1)
                loss = reg_loss
            elif loss_mode == 'all':
                pred_loss = torch.mean((model1(inputs) - labels)**2, dim=(0,1), keepdim=True)
                reg_loss = model1.get_reg(reg_metric=reg_metric, lamb_l1=lamb_l1, lamb_entropy=lamb_entropy) * torch.ones(1,1)
                loss = pred_loss + lamb * reg_loss
            return loss

        return param2loss_fun
    
    fun = get_param2loss_fun(inputs, labels)    
    p = model2param(model)[None,:]
    if derivative == 'hessian':
        result = batch_hessian(fun, p)
    elif derivative == 'jacobian':
        result = batch_jacobian(fun, p)
    return result

def model2param(model):
    '''
    turn model parameters into a flattened vector
    '''
    p = torch.tensor([]).to(model.device)
    for params in model.parameters():
        p = torch.cat([p, params.reshape(-1,)], dim=0)
    return p


四、总结与思考

        KAN神经网络通过融合数学定理与深度学习,为科学计算和可解释AI提供了新思路。尽管在高维应用中仍需突破,但其在低维复杂函数建模上的潜力值得关注。未来可能通过改进计算效率、扩展理论边界,成为MLP的重要补充。

        1. KAN网络架构

  • 关键设计可学习的激活函数:每个网络连接的“权重”被替换为单变量函数(如样条、多项式),而非固定激活函数(如ReLU)。分层结构:输入层和隐藏层之间、隐藏层与输出层之间均通过单变量函数连接,形成多层叠加。参数效率:由于理论保证,KAN可能用更少的参数达到与MLP相当或更好的逼近效果。

  • 示例结构输入层 → 隐藏层:每个输入节点通过单变量函数\phi_{q,i} \left( x_{i} \right) 连接到隐藏节点。隐藏层 → 输出层:隐藏节点通过另一组单变量函数\psi_{q}组合得到输出。

        2. 优势与特点

  • 高逼近效率:基于数学定理,理论上能以更少参数逼近复杂函数;在低维科学计算任务(如微分方程求解)中表现优异。

  • 可解释性:单变量函数可可视化,便于分析输入变量与输出的关系;网络结构直接对应函数分解过程,逻辑清晰。

  • 灵活的函数学习:激活函数可自适应调整(如学习平滑或非平滑函数);支持符号公式提取(例如从数据中恢复物理定律)。

        3. 挑战与局限

  • 计算复杂度:单变量函数的学习(如样条参数化)可能增加训练时间和内存消耗。需要优化高阶连续函数,对硬件和算法提出更高要求。

  • 泛化能力:在高维数据(如图像、文本)中的表现尚未充分验证,可能逊色于传统MLP。

  • 训练难度:需设计新的优化策略,避免单变量函数的过拟合或欠拟合。

        4. 应用场景

  • 科学计算:求解微分方程、物理建模、化学模拟等需要高精度函数逼近的任务。

  • 可解释性需求领域:医疗诊断、金融风控等需明确输入输出关系的场景。

  • 符号回归:从数据中自动发现数学公式(如物理定律)。

        5. 与传统MLP的对比

        6. 研究进展

  • 近期论文:2024年,MIT等团队提出KAN架构(如论文《KAN: Kolmogorov-Arnold Networks》),在低维任务中验证了其高效性和可解释性。

  • 开源实现:已有PyTorch等框架的初步实现。


【作者声明】

        本文分享的论文内容及观点均来源于《KAN: Kolmogorov-Arnold Networks》原文,旨在介绍和探讨该研究的创新成果和应用价值。作者尊重并遵循学术规范,确保内容的准确性和客观性。如有任何疑问或需要进一步的信息,请参考论文原文或联系相关作者。


 【关注我们】

        如果您对神经网络、群智能算法及人工智能技术感兴趣,请关注【灵犀拾荒者】,获取更多前沿技术文章、实战案例及技术分享!

猜你喜欢

转载自blog.csdn.net/2303_77200324/article/details/146524771
今日推荐