【KAN】KAN Neural Network Training Camp (6): KANLayer.py

I. Introduction

KAN neural networks (Kolmogorov-Arnold Networks) are a new neural network architecture based on the Kolmogorov-Arnold representation theorem. The theorem states that any multivariate continuous function can be expressed as a finite composition of univariate functions. Unlike the traditional multilayer perceptron (MLP), a KAN uses learnable activation functions and a structured network design, showing promise in both function-approximation efficiency and interpretability.


II. Background: Technique and Principles

1. The Kolmogorov-Arnold Representation Theorem

The Kolmogorov-Arnold representation theorem states that if f is a multivariate continuous function on a bounded domain, then it can be written as a finite composition of continuous univariate functions and the binary operation of addition. More specifically, for a smooth f:[0,1]^{n}\rightarrow \mathbb{R},

f \left( x \right)=f \left( x_{1}, \cdots,x_{n} \right)= \sum_{q=1}^{2n+1} \Phi_{q} \left( \sum_{p=1}^{n} \phi_{q,p} \left( x_{p} \right) \right)

where \phi_{q,p}:[0,1]\to\mathbb{R} and \Phi_{q}:\mathbb{R}\to\mathbb{R}. In a sense, this shows that the only truly multivariate function is addition, since every other function can be written using univariate functions and sums. However, this two-layer, width-(2n+1) Kolmogorov-Arnold representation may not be smooth, which limits its expressive power. We strengthen its expressive power by generalizing it to arbitrary depths and widths.
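To make the nested-sum structure concrete, here is a minimal sketch that evaluates a two-layer Kolmogorov-Arnold representation for n = 2 inputs. The particular inner functions phi_{q,p} and outer functions Phi_q below are arbitrary illustrative choices, not functions from the paper.

import torch

n = 2                 # number of input variables
Q = 2 * n + 1         # 2n+1 outer terms required by the theorem

# illustrative univariate inner functions phi_{q,p}: [0,1] -> R
def phi(q, p, t):
    return torch.sin((q + 1) * t + p)

# illustrative univariate outer functions Phi_q: R -> R
def Phi(q, s):
    return torch.tanh(s) / (q + 1)

def ka_representation(x):
    # f(x) = sum_q Phi_q( sum_p phi_{q,p}(x_p) ), evaluated for a batch of points
    inner = torch.stack([sum(phi(q, p, x[:, p]) for p in range(n))
                         for q in range(Q)], dim=1)        # (batch, 2n+1)
    return sum(Phi(q, inner[:, q]) for q in range(Q))      # (batch,)

x = torch.rand(4, n)  # four sample points in [0,1]^2
print(ka_representation(x).shape)   # torch.Size([4])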

2. Kolmogorov-Arnold Networks (KAN)

The Kolmogorov-Arnold representation can be written in matrix form:

f(\mathbf{x})=\mathbf{\Phi}_{\mathrm{out}}\circ\mathbf{\Phi}_{\mathrm{in}}\circ\mathbf{x}

where

\mathbf{\Phi}_{\mathrm{in}}=\begin{pmatrix}\phi_{1,1}(\cdot)&\cdots&\phi_{1,n }(\cdot)\\ \vdots&&\vdots\\ \phi_{2n+1,1}(\cdot)&\cdots&\phi_{2n+1,n}(\cdot)\end{pmatrix}

\quad\mathbf{ \Phi}_{\mathrm{out}}=\left(\Phi_{1}(\cdot)\quad\cdots\quad\Phi_{2n+1}(\cdot)\right)

We note that both \mathbf{\Phi}_{\mathrm{in}} and \mathbf{\Phi}_{\mathrm{out}} are special cases of the following matrix of functions with n_{\mathrm{in}} inputs and n_{\mathrm{out}} outputs, which we call a Kolmogorov-Arnold layer:

\mathbf{\Phi}=\begin{pmatrix}\phi_{1,1}(\cdot)&\cdots&\phi_{1,n_{\mathrm{in}}}(\cdot)\\ \vdots&&\vdots\\ \phi_{n_{\mathrm{out}},1}(\cdot)&\cdots&\phi_{n_{\mathrm{out}},n_{\mathrm{in}}}(\cdot)\end{pmatrix}

where \mathbf{\Phi}_{\mathrm{in}} corresponds to n_{\mathrm{in}}=n,\; n_{\mathrm{out}}=2n+1, and \mathbf{\Phi}_{\mathrm{out}} to n_{\mathrm{in}}=2n+1,\; n_{\mathrm{out}}=1.

Once the layer is defined, we can construct a Kolmogorov-Arnold network simply by stacking layers. Suppose we have L layers, where the l-th layer \Phi_{l} has shape \left( n_{l+1},n_{l} \right). Then the whole network is

\mathrm{KAN}(\mathbf{x})=\mathbf{\Phi}_{L-1}\circ\cdots\circ\mathbf{\Phi}_{1}\circ\mathbf{\Phi}_{0}\circ\mathbf{x}

In contrast, a multilayer perceptron (MLP) interleaves linear layers \mathbf{W}_{l} with fixed nonlinearities \sigma:

\mathrm{MLP}(\mathbf{x})=\mathbf{W}_{L-1}\circ\sigma\circ\cdots\circ\mathbf{W}_{1}\circ\sigma\circ\mathbf{W}_{0}\circ\mathbf{x}

A KAN is easy to visualize: (1) a KAN is simply a stack of KAN layers; (2) each KAN layer can be visualized as a fully connected layer with a 1D function placed on every edge.
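The "matrix of 1D functions" view can be sketched directly in code. The toy layers below place fixed elementary functions on each edge instead of learnable splines, purely to show how stacking KAN layers composes the network.

import torch

def kan_layer(funcs):
    # funcs[j][i] is the 1D function on the edge from input i to output j;
    # the returned callable maps (batch, n_in) -> (batch, n_out)
    def apply(x):
        outs = [sum(f(x[:, i]) for i, f in enumerate(row)) for row in funcs]
        return torch.stack(outs, dim=1)
    return apply

# a 2 -> 3 layer followed by a 3 -> 1 layer (edge functions chosen arbitrarily)
layer0 = kan_layer([[torch.sin, torch.cos],
                    [torch.tanh, torch.relu],
                    [torch.exp, torch.sigmoid]])
layer1 = kan_layer([[torch.tanh, torch.sin, torch.cos]])

x = torch.randn(5, 2)
y = layer1(layer0(x))   # KAN(x) = Phi_1 o Phi_0 o x
print(y.shape)          # torch.Size([5, 1])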


III. Code Walkthrough

The code implements an adaptive nonlinear layer based on spline functions. By dynamically updating the grid and the spline coefficients, it fits each input variable with local polynomials and adds a residual function to strengthen the model's expressive power. The auxiliary methods (grid updating, subset extraction, and neuron swapping) support the model's flexibility and interpretability.

A. Code Walkthrough

1. Class Definition and Attributes

class KANLayer(nn.Module):
    """
    The KANLayer class implements a layer of B-spline-based mappings. The layer passes each input variable through spline functions for a nonlinear transformation and adds a residual function to increase expressive power.
    
    Attribute description:
        - in_dim: input dimension
        - out_dim: output dimension
        - num: number of grid intervals, which determines the number of spline knots
        - k: order of the spline polynomials
        - noise_scale: scale of the noise injected at initialization
        - coef: coefficients of the B-spline basis functions, computed from the initial noise and the grid
        - scale_base: scale of the residual function b(x); its initial values follow a normal distribution
        - scale_sp: scale of the spline function spline(x)
        - base_fun: activation function used to compute the residual function b(x) (e.g. SiLU)
        - mask: mask over the activation functions; setting entries to zero disables the corresponding activations (used for sparse initialization or pruning)
        - grid_eps: hyperparameter for adaptive grid updates, controlling the interpolation between a uniform grid and a grid based on sample percentiles
        - device: computation device
    """
  • The class inherits from nn.Module and builds a spline-based nonlinear layer. Its core is transforming the inputs through the grid (grid) and the spline coefficients (coef), with a residual branch added for extra flexibility.

2. Constructor __init__

def __init__(self, in_dim=3, out_dim=2, num=5, k=3, noise_scale=0.5, scale_base_mu=0.0, scale_base_sigma=1.0, scale_sp=1.0, base_fun=torch.nn.SiLU(), grid_eps=0.02, grid_range=[-1, 1], sp_trainable=True, sb_trainable=True, save_plot_data = True, device='cpu', sparse_init=False):
  • Parameters: set the input/output dimensions, the number of grid intervals, the spline order, the noise scale, and the other hyperparameters so that the spline and residual branches start at an appropriate scale.
  • Grid initialization: use torch.linspace to build an initial grid over the given range, then extend it with extend_grid (from the .spline module) so that splines can be evaluated near the boundaries.
  • Coefficient computation: combine the grid with the initial noise via curve2coef to obtain the B-spline coefficients.
  • Sparse initialization: depending on sparse_init, either use a preset sparse mask (via sparse_mask) or activate all edges.
  • Residual and spline scales: define the scale_base and scale_sp parameters and set whether they are trainable (sb_trainable and sp_trainable).
  • Device: call self.to(device) to move all parameters to the specified device. (A usage sketch follows below.)
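As a rough usage sketch (the import path follows the docstring examples in the full code below; the exact shape of coef is determined by curve2coef in the .spline module):

from kan.KANLayer import KANLayer

layer = KANLayer(in_dim=3, out_dim=5, num=5, k=3, noise_scale=0.5, device='cpu')

# the grid is extended by k knots on each side, giving num + 1 + 2*k points per input
print(layer.grid.shape)        # torch.Size([3, 12]) for num=5, k=3
# one (scale_base, scale_sp, mask) entry per (input, output) edge
print(layer.scale_base.shape)  # torch.Size([3, 5])
print(layer.scale_sp.shape)    # torch.Size([3, 5])
print(layer.mask.shape)        # torch.Size([3, 5])
# spline coefficients, one coefficient vector per edge
print(layer.coef.shape)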

3. Overriding the to Method

def to(self, device):
    super(KANLayer, self).to(device)
    self.device = device    
    return self
  • Extends the to method of nn.Module so that, when the module is moved to a device, the internally stored device attribute is updated as well for later use.

4. Forward Pass forward

def forward(self, x):
    batch = x.shape[0]
    preacts = x[:,None,:].clone().expand(batch, self.out_dim, self.in_dim)
        
    base = self.base_fun(x) # (batch, in_dim)
    y = coef2curve(x_eval=x, grid=self.grid, coef=self.coef, k=self.k)
        
    postspline = y.clone().permute(0,2,1)
        
    y = self.scale_base[None,:,:] * base[:,:,None] + self.scale_sp[None,:,:] * y
    y = self.mask[None,:,:] * y
        
    postacts = y.clone().permute(0,2,1)
        
    y = torch.sum(y, dim=1)
    return y, preacts, postacts, postspline
  • preacts: the input x is expanded along the output dimension, giving the pre-activation value for every (output, input) pair.
  • Residual branch: base_fun applies a simple nonlinear transform to the input to form the residual part.
  • Spline branch: coef2curve evaluates the spline functions from the current grid and spline coefficients, giving the locally transformed result.
  • Combination and weighting: scale_base and scale_sp weight the residual and spline parts respectively, and the result is multiplied by the mask to switch off selected edges.
  • Output aggregation: the weighted result is summed over the input dimension to produce the final output; the intermediate tensors (preacts, postacts, postspline) are also returned for debugging and visualization. (See the sketch below.)
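Conceptually, each edge (i, j) computes mask[i,j] * (scale_base[i,j] * b(x_i) + scale_sp[i,j] * spline_{i,j}(x_i)), and the layer output sums over the incoming edges. The following minimal sketch reproduces that combination with a random stand-in for the spline evaluation, since coef2curve lives in the .spline module:

import torch

batch, in_dim, out_dim = 8, 3, 5
x = torch.randn(batch, in_dim)

base = torch.nn.SiLU()(x)                     # residual term b(x), shape (batch, in_dim)
spline = torch.randn(batch, in_dim, out_dim)  # stand-in for coef2curve(x, grid, coef, k)
scale_base = torch.randn(in_dim, out_dim)
scale_sp = torch.ones(in_dim, out_dim)
mask = torch.ones(in_dim, out_dim)

# per-edge activation, then sum over the input dimension -> (batch, out_dim)
edge = mask[None] * (scale_base[None] * base[:, :, None] + scale_sp[None] * spline)
y = edge.sum(dim=1)
print(y.shape)    # torch.Size([8, 5])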

5. Grid Update: update_grid_from_samples

def update_grid_from_samples(self, x, mode='sample'):
    batch = x.shape[0]
    x_pos = torch.sort(x, dim=0)[0]
    y_eval = coef2curve(x_pos, self.grid, self.coef, self.k)
    num_interval = self.grid.shape[1] - 1 - 2*self.k
    
    def get_grid(num_interval):
        ids = [int(batch / num_interval * i) for i in range(num_interval)] + [-1]
        grid_adaptive = x_pos[ids, :].permute(1,0)
        h = (grid_adaptive[:,[-1]] - grid_adaptive[:,[0]])/num_interval
        grid_uniform = grid_adaptive[:,[0]] + h * torch.arange(num_interval+1,)[None, :].to(x.device)
        grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive
        return grid
    
    grid = get_grid(num_interval)
    
    if mode == 'grid':
        sample_grid = get_grid(2*num_interval)
        x_pos = sample_grid.permute(1,0)
        y_eval = coef2curve(x_pos, self.grid, self.coef, self.k)
    
    self.grid.data = extend_grid(grid, k_extend=self.k)
    self.coef.data = curve2coef(x_pos, y_eval, self.grid, self.k)
  • Goal: dynamically update the grid from the distribution of the input samples so that the splines adapt better to the local structure of the data.
  • Adaptive grid: the inputs are sorted, an adaptive grid is built from the sample quantiles, and it is linearly interpolated with a uniform grid to obtain the final grid.
  • Parameter update: the spline coefficients are refitted using the new grid and the corresponding function values, updating the layer's grid and coef parameters. (A standalone sketch of the blending step follows.)
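The key step is the blend between a uniform grid and a quantile-based adaptive grid, controlled by grid_eps. The following standalone sketch reproduces that blending for a single input dimension:

import torch

x = torch.sort(torch.randn(100, 1), dim=0)[0]   # sorted samples for one input dimension
num_interval, grid_eps = 5, 0.02

ids = [int(100 / num_interval * i) for i in range(num_interval)] + [-1]
grid_adaptive = x[ids, :].permute(1, 0)          # knots at sample quantiles, shape (1, 6)
h = (grid_adaptive[:, [-1]] - grid_adaptive[:, [0]]) / num_interval
grid_uniform = grid_adaptive[:, [0]] + h * torch.arange(num_interval + 1)[None, :]

# grid_eps = 1 gives a purely uniform grid; grid_eps = 0 keeps the quantile grid
grid = grid_eps * grid_uniform + (1 - grid_eps) * grid_adaptive
print(grid)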

6. Initializing the Grid from a Parent Layer: initialize_grid_from_parent

def initialize_grid_from_parent(self, parent, x, mode='sample'):
    batch = x.shape[0]
    x_pos = torch.sort(x, dim=0)[0]
    y_eval = coef2curve(x_pos, parent.grid, parent.coef, parent.k)
    num_interval = self.grid.shape[1] - 1 - 2*self.k
    
    def get_grid(num_interval):
        ids = [int(batch / num_interval * i) for i in range(num_interval)] + [-1]
        grid_adaptive = x_pos[ids, :].permute(1,0)
        h = (grid_adaptive[:,[-1]] - grid_adaptive[:,[0]])/num_interval
        grid_uniform = grid_adaptive[:,[0]] + h * torch.arange(num_interval+1,)[None, :].to(x.device)
        grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive
        return grid
    
    grid = get_grid(num_interval)
    
    if mode == 'grid':
        sample_grid = get_grid(2*num_interval)
        x_pos = sample_grid.permute(1,0)
        y_eval = coef2curve(x_pos, parent.grid, parent.coef, parent.k)
    
    grid = extend_grid(grid, k_extend=self.k)
    self.grid.data = grid
    self.coef.data = curve2coef(x_pos, y_eval, self.grid, self.k)
  • Purpose: use the parameters of a parent layer (which usually has a coarser grid) together with sample data to generate a finer grid and the corresponding coefficients for the current layer; commonly used when initializing refined layers in a multi-layer setup.
  • Implementation: similar to update_grid_from_samples, but the evaluation values are computed from the parent's grid and coef, giving a more sensible starting point. (A usage sketch follows.)
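A typical use is grid refinement: train a layer with a coarse grid, then hand its learned splines over to a finer layer. The sketch below follows the docstring example in the full code:

import torch
from kan.KANLayer import KANLayer

coarse = KANLayer(in_dim=1, out_dim=1, num=5, k=3)    # parent layer with a coarse grid
fine = KANLayer(in_dim=1, out_dim=1, num=10, k=3)     # child layer with twice as many intervals

x = torch.normal(0, 1, size=(100, 1))
fine.initialize_grid_from_parent(coarse, x)
# fine.grid now follows the distribution of x, and fine.coef is refitted so that the
# child's splines reproduce the parent's curves on these samples
print(fine.grid.data)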

7. Extracting a Subset: get_subset

def get_subset(self, in_id, out_id):
    spb = KANLayer(len(in_id), len(out_id), self.num, self.k, base_fun=self.base_fun)
    spb.grid.data = self.grid[in_id]
    spb.coef.data = self.coef[in_id][:,out_id]
    spb.scale_base.data = self.scale_base[in_id][:,out_id]
    spb.scale_sp.data = self.scale_sp[in_id][:,out_id]
    spb.mask.data = self.mask[in_id][:,out_id]

    spb.in_dim = len(in_id)
    spb.out_dim = len(out_id)
    return spb
  • Function: select a subset of input and output neurons from a larger layer and build a smaller KANLayer from them; commonly used for pruning or for analyzing substructures.
  • Details: indexing extracts the corresponding parts of the original layer's grid, coef, scale parameters, and mask, and a new KANLayer object is constructed from them. (See the example below.)
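For example, keeping only two inputs and three outputs of a 10x10 layer (mirroring the docstring example in the full code):

from kan.KANLayer import KANLayer

large = KANLayer(in_dim=10, out_dim=10, num=5, k=3)
small = large.get_subset(in_id=[0, 9], out_id=[1, 2, 3])
print(small.in_dim, small.out_dim)    # 2 3
# small.grid, small.coef, small.scale_base, small.scale_sp and small.mask are the
# corresponding slices of the original layer's parameters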

8. Swapping Neurons: swap

def swap(self, i1, i2, mode='in'):
    with torch.no_grad():
        def swap_(data, i1, i2, mode='in'):
            if mode == 'in':
                data[i1], data[i2] = data[i2].clone(), data[i1].clone()
            elif mode == 'out':
                data[:,i1], data[:,i2] = data[:,i2].clone(), data[:,i1].clone()

        if mode == 'in':
            swap_(self.grid.data, i1, i2, mode='in')
        swap_(self.coef.data, i1, i2, mode=mode)
        swap_(self.scale_base.data, i1, i2, mode=mode)
        swap_(self.scale_sp.data, i1, i2, mode=mode)
        swap_(self.mask.data, i1, i2, mode=mode)
  • Purpose: swap the order of two input or output neurons; useful when reordering parameters within a layer or applying symmetry operations.
  • Implementation: the internal helper swap_ exchanges the corresponding slices of grid, coef, scale_base, scale_sp, and mask. The whole operation runs under torch.no_grad() so that it does not participate in gradient computation. (See the example below.)
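For instance, swapping the first two input neurons reorders the corresponding slices of every per-edge parameter. The check below is a small sketch, not part of the original code:

import torch
from kan.KANLayer import KANLayer

model = KANLayer(in_dim=2, out_dim=2, num=5, k=3)
before = model.coef.data.clone()
model.swap(0, 1, mode='in')           # swap input neurons 0 and 1
after = model.coef.data
# row 0 and row 1 of the coefficient tensor have exchanged places
print(torch.allclose(before[0], after[1]), torch.allclose(before[1], after[0]))   # True True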

B. Complete Code

import torch
import torch.nn as nn
import numpy as np
from .spline import *
from .utils import sparse_mask


class KANLayer(nn.Module):
    """
    KANLayer class
    

    Attributes:
    -----------
        in_dim: int
            input dimension
        out_dim: int
            output dimension
        num: int
            the number of grid intervals
        k: int
            the piecewise polynomial order of splines
        noise_scale: float
            spline scale at initialization
        coef: 3D torch.tensor
            coefficients of B-spline bases
        scale_base_mu: float
            magnitude of the residual function b(x) is drawn from N(mu, sigma^2), mu = scale_base_mu
        scale_base_sigma: float
            magnitude of the residual function b(x) is drawn from N(mu, sigma^2), sigma = scale_base_sigma
        scale_sp: float
            magnitude of the spline function spline(x)
        base_fun: fun
            residual function b(x)
        mask: 2D torch.float
            mask of spline functions. setting some element of the mask to zero means setting the corresponding activation to zero function.
        grid_eps: float in [0,1]
            a hyperparameter used in update_grid_from_samples. When grid_eps = 1, the grid is uniform; when grid_eps = 0, the grid is partitioned using percentiles of samples. 0 < grid_eps < 1 interpolates between the two extremes.
        device: str
            device
    """

    def __init__(self, in_dim=3, out_dim=2, num=5, k=3, noise_scale=0.5, scale_base_mu=0.0, scale_base_sigma=1.0, scale_sp=1.0, base_fun=torch.nn.SiLU(), grid_eps=0.02, grid_range=[-1, 1], sp_trainable=True, sb_trainable=True, save_plot_data = True, device='cpu', sparse_init=False):
        '''
        initialize a KANLayer
        
        Args:
        -----
            in_dim : int
                input dimension. Default: 3.
            out_dim : int
                output dimension. Default: 2.
            num : int
                the number of grid intervals = G. Default: 5.
            k : int
                the order of piecewise polynomial. Default: 3.
            noise_scale : float
                the scale of noise injected at initialization. Default: 0.5.
            scale_base_mu : float
                the scale of the residual function b(x) is initialized to be N(scale_base_mu, scale_base_sigma^2).
            scale_base_sigma : float
                the scale of the residual function b(x) is initialized to be N(scale_base_mu, scale_base_sigma^2).
            scale_sp : float
                the scale of the spline function spline(x).
            base_fun : function
                residual function b(x). Default: torch.nn.SiLU()
            grid_eps : float
                When grid_eps = 1, the grid is uniform; when grid_eps = 0, the grid is partitioned using percentiles of samples. 0 < grid_eps < 1 interpolates between the two extremes.
            grid_range : list/np.array of shape (2,)
                setting the range of grids. Default: [-1,1].
            sp_trainable : bool
                If true, scale_sp is trainable
            sb_trainable : bool
                If true, scale_base is trainable
            device : str
                device
            sparse_init : bool
                if sparse_init = True, sparse initialization is applied.
            
        Returns:
        --------
            self
            
        Example
        -------
        >>> from kan.KANLayer import *
        >>> model = KANLayer(in_dim=3, out_dim=5)
        >>> (model.in_dim, model.out_dim)
        '''
        super(KANLayer, self).__init__()
        # size 
        self.out_dim = out_dim
        self.in_dim = in_dim
        self.num = num
        self.k = k

        grid = torch.linspace(grid_range[0], grid_range[1], steps=num + 1)[None,:].expand(self.in_dim, num+1)
        grid = extend_grid(grid, k_extend=k)
        self.grid = torch.nn.Parameter(grid).requires_grad_(False)
        noises = (torch.rand(self.num+1, self.in_dim, self.out_dim) - 1/2) * noise_scale / num

        self.coef = torch.nn.Parameter(curve2coef(self.grid[:,k:-k].permute(1,0), noises, self.grid, k))
        
        if sparse_init:
            self.mask = torch.nn.Parameter(sparse_mask(in_dim, out_dim)).requires_grad_(False)
        else:
            self.mask = torch.nn.Parameter(torch.ones(in_dim, out_dim)).requires_grad_(False)
        
        self.scale_base = torch.nn.Parameter(scale_base_mu * 1 / np.sqrt(in_dim) + \
                         scale_base_sigma * (torch.rand(in_dim, out_dim)*2-1) * 1/np.sqrt(in_dim)).requires_grad_(sb_trainable)
        self.scale_sp = torch.nn.Parameter(torch.ones(in_dim, out_dim) * scale_sp * self.mask).requires_grad_(sp_trainable)  # make scale trainable
        self.base_fun = base_fun

        
        self.grid_eps = grid_eps
        
        self.to(device)
        
    def to(self, device):
        super(KANLayer, self).to(device)
        self.device = device    
        return self

    def forward(self, x):
        '''
        KANLayer forward given input x
        
        Args:
        -----
            x : 2D torch.float
                inputs, shape (number of samples, input dimension)
            
        Returns:
        --------
            y : 2D torch.float
                outputs, shape (number of samples, output dimension)
            preacts : 3D torch.float
                fan out x into activations, shape (number of samples, output dimension, input dimension)
            postacts : 3D torch.float
                the outputs of activation functions with preacts as inputs
            postspline : 3D torch.float
                the outputs of spline functions with preacts as inputs
        
        Example
        -------
        >>> from kan.KANLayer import *
        >>> model = KANLayer(in_dim=3, out_dim=5)
        >>> x = torch.normal(0,1,size=(100,3))
        >>> y, preacts, postacts, postspline = model(x)
        >>> y.shape, preacts.shape, postacts.shape, postspline.shape
        '''
        batch = x.shape[0]
        preacts = x[:,None,:].clone().expand(batch, self.out_dim, self.in_dim)
            
        base = self.base_fun(x) # (batch, in_dim)
        y = coef2curve(x_eval=x, grid=self.grid, coef=self.coef, k=self.k)
        
        postspline = y.clone().permute(0,2,1)
            
        y = self.scale_base[None,:,:] * base[:,:,None] + self.scale_sp[None,:,:] * y
        y = self.mask[None,:,:] * y
        
        postacts = y.clone().permute(0,2,1)
            
        y = torch.sum(y, dim=1)
        return y, preacts, postacts, postspline

    def update_grid_from_samples(self, x, mode='sample'):
        '''
        update grid from samples
        
        Args:
        -----
            x : 2D torch.float
                inputs, shape (number of samples, input dimension)
            
        Returns:
        --------
            None
        
        Example
        -------
        >>> model = KANLayer(in_dim=1, out_dim=1, num=5, k=3)
        >>> print(model.grid.data)
        >>> x = torch.linspace(-3,3,steps=100)[:,None]
        >>> model.update_grid_from_samples(x)
        >>> print(model.grid.data)
        '''
        
        batch = x.shape[0]
        #x = torch.einsum('ij,k->ikj', x, torch.ones(self.out_dim, ).to(self.device)).reshape(batch, self.size).permute(1, 0)
        x_pos = torch.sort(x, dim=0)[0]
        y_eval = coef2curve(x_pos, self.grid, self.coef, self.k)
        num_interval = self.grid.shape[1] - 1 - 2*self.k
        
        def get_grid(num_interval):
            ids = [int(batch / num_interval * i) for i in range(num_interval)] + [-1]
            grid_adaptive = x_pos[ids, :].permute(1,0)
            h = (grid_adaptive[:,[-1]] - grid_adaptive[:,[0]])/num_interval
            grid_uniform = grid_adaptive[:,[0]] + h * torch.arange(num_interval+1,)[None, :].to(x.device)
            grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive
            return grid
        
        grid = get_grid(num_interval)
        
        if mode == 'grid':
            sample_grid = get_grid(2*num_interval)
            x_pos = sample_grid.permute(1,0)
            y_eval = coef2curve(x_pos, self.grid, self.coef, self.k)
        
        self.grid.data = extend_grid(grid, k_extend=self.k)
        self.coef.data = curve2coef(x_pos, y_eval, self.grid, self.k)

    def initialize_grid_from_parent(self, parent, x, mode='sample'):
        '''
        update grid from a parent KANLayer & samples
        
        Args:
        -----
            parent : KANLayer
                a parent KANLayer (whose grid is usually coarser than the current model)
            x : 2D torch.float
                inputs, shape (number of samples, input dimension)
            
        Returns:
        --------
            None
          
        Example
        -------
        >>> batch = 100
        >>> parent_model = KANLayer(in_dim=1, out_dim=1, num=5, k=3)
        >>> print(parent_model.grid.data)
        >>> model = KANLayer(in_dim=1, out_dim=1, num=10, k=3)
        >>> x = torch.normal(0,1,size=(batch, 1))
        >>> model.initialize_grid_from_parent(parent_model, x)
        >>> print(model.grid.data)
        '''
        
        batch = x.shape[0]
        
        x_pos = torch.sort(x, dim=0)[0]
        y_eval = coef2curve(x_pos, parent.grid, parent.coef, parent.k)
        num_interval = self.grid.shape[1] - 1 - 2*self.k
        
        def get_grid(num_interval):
            ids = [int(batch / num_interval * i) for i in range(num_interval)] + [-1]
            grid_adaptive = x_pos[ids, :].permute(1,0)
            h = (grid_adaptive[:,[-1]] - grid_adaptive[:,[0]])/num_interval
            grid_uniform = grid_adaptive[:,[0]] + h * torch.arange(num_interval+1,)[None, :].to(x.device)
            grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive
            return grid
        
        grid = get_grid(num_interval)
        
        if mode == 'grid':
            sample_grid = get_grid(2*num_interval)
            x_pos = sample_grid.permute(1,0)
            y_eval = coef2curve(x_pos, parent.grid, parent.coef, parent.k)
        
        grid = extend_grid(grid, k_extend=self.k)
        self.grid.data = grid
        self.coef.data = curve2coef(x_pos, y_eval, self.grid, self.k)

    def get_subset(self, in_id, out_id):
        '''
        get a smaller KANLayer from a larger KANLayer (used for pruning)
        
        Args:
        -----
            in_id : list
                id of selected input neurons
            out_id : list
                id of selected output neurons
            
        Returns:
        --------
            spb : KANLayer
            
        Example
        -------
        >>> kanlayer_large = KANLayer(in_dim=10, out_dim=10, num=5, k=3)
        >>> kanlayer_small = kanlayer_large.get_subset([0,9],[1,2,3])
        >>> kanlayer_small.in_dim, kanlayer_small.out_dim
        (2, 3)
        '''
        spb = KANLayer(len(in_id), len(out_id), self.num, self.k, base_fun=self.base_fun)
        spb.grid.data = self.grid[in_id]
        spb.coef.data = self.coef[in_id][:,out_id]
        spb.scale_base.data = self.scale_base[in_id][:,out_id]
        spb.scale_sp.data = self.scale_sp[in_id][:,out_id]
        spb.mask.data = self.mask[in_id][:,out_id]

        spb.in_dim = len(in_id)
        spb.out_dim = len(out_id)
        return spb
    
    
    def swap(self, i1, i2, mode='in'):
        '''
        swap the i1 neuron with the i2 neuron in input (if mode == 'in') or output (if mode == 'out') 
        
        Args:
        -----
            i1 : int
            i2 : int
            mode : str
                mode = 'in' or 'out'
            
        Returns:
        --------
            None
            
        Example
        -------
        >>> from kan.KANLayer import *
        >>> model = KANLayer(in_dim=2, out_dim=2, num=5, k=3)
        >>> print(model.coef)
        >>> model.swap(0,1,mode='in')
        >>> print(model.coef)
        '''
        with torch.no_grad():
            def swap_(data, i1, i2, mode='in'):
                if mode == 'in':
                    data[i1], data[i2] = data[i2].clone(), data[i1].clone()
                elif mode == 'out':
                    data[:,i1], data[:,i2] = data[:,i2].clone(), data[:,i1].clone()

            if mode == 'in':
                swap_(self.grid.data, i1, i2, mode='in')
            swap_(self.coef.data, i1, i2, mode=mode)
            swap_(self.scale_base.data, i1, i2, mode=mode)
            swap_(self.scale_sp.data, i1, i2, mode=mode)
            swap_(self.mask.data, i1, i2, mode=mode)


IV. Summary and Reflections

By combining a mathematical theorem with deep learning, KAN networks offer a new approach to scientific computing and interpretable AI. Although breakthroughs are still needed for high-dimensional applications, their potential for modeling complex low-dimensional functions deserves attention. With better computational efficiency and extended theory, they may become an important complement to MLPs.

1. KAN Network Architecture

  • Key design. Learnable activation functions: the "weight" on every connection is replaced by a univariate function (e.g. a spline or polynomial) rather than a fixed activation such as ReLU. Layered structure: the input layer connects to the hidden layer, and the hidden layer to the output layer, through univariate functions, which stack into multiple layers. Parameter efficiency: thanks to the theoretical guarantee, a KAN may match or exceed an MLP's approximation quality with fewer parameters.

  • Example structure. Input layer → hidden layer: each input node connects to a hidden node through a univariate function \phi_{q,i} \left( x_{i} \right). Hidden layer → output layer: the hidden nodes are combined into the output through another set of univariate functions \psi_{q}.

2. Advantages and Characteristics

  • High approximation efficiency: backed by the mathematical theorem, KANs can in principle approximate complex functions with fewer parameters, and they perform well on low-dimensional scientific computing tasks (such as solving differential equations).

  • Interpretability: the univariate functions can be visualized, which makes it easier to analyze how inputs relate to outputs; the network structure directly mirrors the functional decomposition, so the logic is clear.

  • Flexible function learning: the activation functions adapt during training (e.g. learning smooth or non-smooth functions), and symbolic formula extraction is supported (for example, recovering physical laws from data).

3. Challenges and Limitations

  • Computational complexity: learning the univariate functions (e.g. their spline parameterization) can increase training time and memory consumption, and optimizing high-order continuous functions places higher demands on hardware and algorithms.

  • Generalization: performance on high-dimensional data (such as images and text) has not been thoroughly validated and may lag behind traditional MLPs.

  • Training difficulty: new optimization strategies are needed to avoid overfitting or underfitting of the univariate functions.

4. Application Scenarios

  • Scientific computing: solving differential equations, physical modeling, chemical simulation, and other tasks requiring high-precision function approximation.

  • Interpretability-critical domains: medical diagnosis, financial risk control, and other settings where the input-output relationship must be explicit.

  • Symbolic regression: automatically discovering mathematical formulas (such as physical laws) from data.

5. Comparison with Traditional MLPs

6. Research Progress

  • Recent work: in 2024, a team from MIT and collaborators proposed the KAN architecture (in the paper "KAN: Kolmogorov-Arnold Networks"), validating its efficiency and interpretability on low-dimensional tasks.

  • Open-source implementations: preliminary implementations in PyTorch and other frameworks are already available.


【Author's Statement】

The paper content and viewpoints shared in this post come from the original paper "KAN: Kolmogorov-Arnold Networks" and are intended to introduce and discuss its innovations and practical value. The author respects and follows academic norms and strives for accuracy and objectivity. For any questions or further information, please refer to the original paper or contact its authors.


【Follow Us】

If you are interested in neural networks, swarm intelligence algorithms, and AI techniques, follow 【灵犀拾荒者】 for more articles on cutting-edge techniques, hands-on case studies, and technical insights!


Reprinted from blog.csdn.net/2303_77200324/article/details/146424099