【KAN】KAN神经网络学习训练营(7)——LBFGS.py

一、引言

        KAN神经网络(Kolmogorov–Arnold Networks)是一种基于Kolmogorov-Arnold表示定理的新型神经网络架构。该定理指出,任何多元连续函数都可以表示为有限个单变量函数的组合。与传统多层感知机(MLP)不同,KAN通过可学习的激活函数和结构化网络设计,在函数逼近效率和可解释性上展现出潜力。


二、技术与原理简介

        1.Kolmogorov-Arnold 表示定理

         Kolmogorov-Arnold 表示定理指出,如果 是有界域上的多元连续函数,那么它可以写为单个变量的连续函数的有限组合,以及加法的二进制运算。更具体地说,对于 光滑ff:[0,1]^{^{n}}\rightarrow \mathbb{R}

f \left( x \right)=f \left( x_{1}, \cdots,x_{n} \right)= \sum_{q=1}^{2n+1} \Phi_{q} \left( \sum_{p=1}^{n} \phi_{q,p} \left( x_{p} \right) \right)

        其中 和 。从某种意义上说,他们表明唯一真正的多元函数是加法,因为所有其他函数都可以使用单变量函数和 sum 来编写。然而,这个 2 层宽度 - Kolmogorov-Arnold 表示可能不是平滑的由于其表达能力有限。我们通过以下方式增强它的表达能力将其推广到任意深度和宽度。\boldsymbol{\phi_{q,p}:[0,1]\to\mathbb{R}}\boldsymbol{\Phi_{q}:\mathbb{R}\to\mathbb{R}(2n+1)}

        2.Kolmogorov-Arnold 网络 (KAN)

        Kolmogorov-Arnold 表示可以写成矩阵形式

f(x)=\mathbf{\Phi_{out}}\mathsf{o}\mathbf{\Phi_{in}}\mathsf{o}{}x

其中

\mathbf{\Phi}_{\mathrm{in}}=\begin{pmatrix}\phi_{1,1}(\cdot)&\cdots&\phi_{1,n }(\cdot)\\ \vdots&&\vdots\\ \phi_{2n+1,1}(\cdot)&\cdots&\phi_{2n+1,n}(\cdot)\end{pmatrix}

\quad\mathbf{ \Phi}_{\mathrm{out}}=\left(\Phi_{1}(\cdot)\quad\cdots\quad\Phi_{2n+1}(\cdot)\right)

        我们注意到 和 都是以下函数矩阵(包含输入和输出)的特例,我们称之为 Kolmogorov-Arnold 层:\mathbf{\Phi_{in}} \mathbf{\Phi_{out}} \mathbf{\Phi_{n_{in}n_{out}}}

其中\boldsymbol{n_{\text{in}}=n,n_{\text{out}}=2n+1\Phi_{\text{out}}n_{\text{in}}=2n+1,n_{\text{out}}=1}

        定义层后,我们可以构造一个 Kolmogorov-Arnold 网络只需堆叠层!假设我们有层,层的形状为 。那么整个网络是Ll^{th} \Phi_{l} \left( n_{l+1},n_{l} \right)

\mathbf{KAN(x)}=\mathbf{\Phi_{L-1}}\circ\cdots\circ\mathbf{\Phi_{1}}\circ \mathbf{\Phi_{0}}\circ\mathbf{x}

        相反,多层感知器由线性层和非线错:\mathbf{W}_{l^{\sigma}}

\text{MLP}(\mathbf{x})=\mathbf{W}_{\textit{L-1}}\circ\sigma\circ\cdots\circ \mathbf{W}_{1}\circ\sigma\circ\mathbf{W}_{0}\circ\mathbf{x}

        KAN 可以很容易地可视化。(1) KAN 只是 KAN 层的堆栈。(2) 每个 KAN 层都可以可视化为一个全连接层,每个边缘上都有一个1D 函数。


三、代码详解

        该代码实现了一个基于 LBFGS 算法的优化器,其中包含了:利用三次插值计算步长的辅助函数 _cubic_interpolate;实现强 Wolfe 线搜索策略的函数 _strong_wolfe;LBFGS 优化器类LBFGS,该类继承自 PyTorch 的 Optimizer,实现了对参数梯度的聚合、历史信息的存储与更新以及基于二阶信息(有限记忆近似 Hessian)的下降方向计算。整体思路是先利用历史信息来近似计算逆 Hessian,从而获取一个较优的下降方向,再通过线搜索调整步长,确保在每次参数更新时能够充分降低目标函数值,同时满足 Wolfe 条件。

        A. 代码详解

        1. 函数 _cubic_interpolate

def _cubic_interpolate(x1, f1, g1, x2, f2, g2, bounds=None):
    # ported from https://github.com/torch/optim/blob/master/polyinterp.lua
  • 作用:利用两点的函数值与梯度值进行三次插值,计算满足一定条件的新的步长(step size)。这是线搜索中常用的方法,用于估计沿着下降方向(d)应采取的步长,从而使得目标函数下降并满足 Wolfe 条件。
  • 主要步骤
    • 确定插值区域边界:若给定 bounds,则使用 bounds,否则根据 x1 与 x2 的大小确定插值区间。
    • 计算插值参数
      • 计算中间量 d1,结合两个点的梯度和函数值差异;
      • 判断 d1² - g1*g2 是否大于等于零,若满足则计算 d2(即平方根项),然后根据 x1 与 x2 的顺序计算出新的最优位置 min_pos
      • 最后,将该值限制在边界范围内返回;
      • 若平方根项为负,则返回区间中点。

        2. 函数 _strong_wolfe

def _strong_wolfe(obj_func,
                  x,
                  t,
                  d,
                  f,
                  g,
                  gtd,
                  c1=1e-4,
                  c2=0.9,
                  tolerance_change=1e-9,
                  max_ls=25):
  • 作用:在沿着给定方向 d 上执行强 Wolfe 线搜索。其目标是找到步长 t,使得更新后的位置满足 Armijo 条件(充分下降条件)和曲率条件。
  • 主要步骤
    • 初始计算
      • 计算方向上新的函数值 f_new 与梯度 g_new
      • 计算新的方向导数 gtd_new
    • 搜索过程
      • 进入循环不断检查:
        • 若目标函数值未达到 Armijo 条件,或者出现“反弹”(即当前函数值大于前一步或上一步的函数值),则确定一个区间(bracket),其中包含满足条件的步长;
        • 如果当前方向导数已经满足曲率条件,则直接返回当前步长;
        • 否则根据当前区间利用 _cubic_interpolate 进行插值,得到新的步长,再更新区间内的参数。
    • Zoom 阶段
      • 当找到一个区间后,通过插值不断缩小区间,直至找到满足强 Wolfe 条件的精确步长,或当步长变化极小(小于容差)时退出。
    • 返回:最终返回更新后的函数值 f_new、梯度 g_new、步长 t 和线搜索中所用的函数评估次数。

        3. 类 LBFGS

class LBFGS(Optimizer):
  • 作用:实现了 L-BFGS(Limited-memory BFGS)优化算法,属于二阶优化方法。该实现借鉴了 minFunc 的思想,通过保存过去若干次迭代的梯度和步长信息(历史记忆)来近似计算逆 Hessian,从而确定下降方向。
  • 主要特点
    • 全局状态存储:由于 L-BFGS 需要保存历史信息,所以其状态存储在 optimizer 的 state 中,仅支持单参数组。
    • 内存需求较高:存储每个参数的历史信息,因此对于大模型可能需要较多内存。
    • 仅支持单设备:所有参数必须在同一设备上。

        4. 类内部辅助函数

  • _numel

    • 计算所有参数的总元素个数,用于后续将梯度或参数展平为一维向量。
  • _gather_flat_grad

    • 遍历所有参数,将每个参数的梯度拉平后拼接为一个一维张量。如果某个参数没有梯度,则用零向量代替。
  • _add_grad

    • 将给定的更新向量按步长 step_size 加到所有参数上。该函数遍历所有参数并利用索引操作实现更新。
  • _clone_param_set_param

    • 前者用于复制当前参数状态,后者用于将参数恢复为先前保存的状态。这在进行方向评估和线搜索时尤为重要。
  • _directional_evaluate

    • 在当前参数基础上沿下降方向 d 移动步长 t,计算目标函数及其梯度,然后将参数恢复为原状态。用于在行搜索中评估新步长的效果。

        5. 核心方法 step

@torch.no_grad()
def step(self, closure):
    """Perform a single optimization step.
  • 作用:执行一次 L-BFGS 优化更新。整个 step 方法的流程如下:

    1. 初始化
      • 设定随机种子(确保可复现性);
      • 调用 closure 计算当前损失和梯度,并进行初步的最优性判断(若梯度足够小,则直接返回)。
    2. 计算下降方向
      • 第一次迭代时直接使用负梯度作为下降方向;
      • 从第二次迭代开始,通过计算当前梯度与上一次梯度之差(y = grad - prev_grad)和前一次的更新步长(s = d * t),更新存储的历史信息;
      • 根据历史信息通过两阶段循环(先从后向前,再从前向后)计算近似逆 Hessian 乘以梯度,从而得到新的下降方向。
    3. 确定步长
      • 第一次迭代时给定初始步长(与梯度大小有关),后续则通常设定为固定的学习率;
      • 计算方向导数 gtd = grad.dot(d),如果方向导数接近于 0,则认为已经收敛;
      • 若设置了线搜索函数,则调用 _strong_wolfe 进行行搜索,获得满足 Wolfe 条件的步长,并利用 _add_grad 更新参数;
      • 否则直接按固定步长更新参数,然后重新计算损失和梯度。
    4. 检查终止条件
      • 包括达到最大迭代次数、函数评估次数、梯度变化和损失变化都小于设定容差时退出。
    5. 保存状态
      • 更新状态字典,保存历史方向、步长、Hessian 对角线近似等信息,以便下次更新时使用。
    6. 返回:最终返回初始的损失值。
  • 注意

    • 该实现中使用了 torch.no_grad() 装饰器确保在参数更新过程中不计算梯度;
    • closure 必须确保在调用时开启梯度计算(因此使用了 torch.enable_grad() 包装 closure)。

        B. 完整代码

import torch
from functools import reduce
from torch.optim import Optimizer

__all__ = ['LBFGS']

def _cubic_interpolate(x1, f1, g1, x2, f2, g2, bounds=None):
    # ported from https://github.com/torch/optim/blob/master/polyinterp.lua
    # Compute bounds of interpolation area
    if bounds is not None:
        xmin_bound, xmax_bound = bounds
    else:
        xmin_bound, xmax_bound = (x1, x2) if x1 <= x2 else (x2, x1)

    # Code for most common case: cubic interpolation of 2 points
    #   w/ function and derivative values for both
    # Solution in this case (where x2 is the farthest point):
    #   d1 = g1 + g2 - 3*(f1-f2)/(x1-x2);
    #   d2 = sqrt(d1^2 - g1*g2);
    #   min_pos = x2 - (x2 - x1)*((g2 + d2 - d1)/(g2 - g1 + 2*d2));
    #   t_new = min(max(min_pos,xmin_bound),xmax_bound);
    d1 = g1 + g2 - 3 * (f1 - f2) / (x1 - x2)
    d2_square = d1**2 - g1 * g2
    if d2_square >= 0:
        d2 = d2_square.sqrt()
        if x1 <= x2:
            min_pos = x2 - (x2 - x1) * ((g2 + d2 - d1) / (g2 - g1 + 2 * d2))
        else:
            min_pos = x1 - (x1 - x2) * ((g1 + d2 - d1) / (g1 - g2 + 2 * d2))
        return min(max(min_pos, xmin_bound), xmax_bound)
    else:
        return (xmin_bound + xmax_bound) / 2.


def _strong_wolfe(obj_func,
                  x,
                  t,
                  d,
                  f,
                  g,
                  gtd,
                  c1=1e-4,
                  c2=0.9,
                  tolerance_change=1e-9,
                  max_ls=25):
    # ported from https://github.com/torch/optim/blob/master/lswolfe.lua
    d_norm = d.abs().max()
    g = g.clone(memory_format=torch.contiguous_format)
    # evaluate objective and gradient using initial step
    f_new, g_new = obj_func(x, t, d)
    ls_func_evals = 1
    gtd_new = g_new.dot(d)

    # bracket an interval containing a point satisfying the Wolfe criteria
    t_prev, f_prev, g_prev, gtd_prev = 0, f, g, gtd
    done = False
    ls_iter = 0
    while ls_iter < max_ls:
        # check conditions
        #print(f_prev, f_new, g_new)
        if f_new > (f + c1 * t * gtd) or (ls_iter > 1 and f_new >= f_prev):
            bracket = [t_prev, t]
            bracket_f = [f_prev, f_new]
            bracket_g = [g_prev, g_new.clone(memory_format=torch.contiguous_format)]
            bracket_gtd = [gtd_prev, gtd_new]
            break

        if abs(gtd_new) <= -c2 * gtd:
            bracket = [t]
            bracket_f = [f_new]
            bracket_g = [g_new]
            done = True
            break

        if gtd_new >= 0:
            bracket = [t_prev, t]
            bracket_f = [f_prev, f_new]
            bracket_g = [g_prev, g_new.clone(memory_format=torch.contiguous_format)]
            bracket_gtd = [gtd_prev, gtd_new]
            break

        # interpolate
        min_step = t + 0.01 * (t - t_prev)
        max_step = t * 10
        tmp = t
        t = _cubic_interpolate(
            t_prev,
            f_prev,
            gtd_prev,
            t,
            f_new,
            gtd_new,
            bounds=(min_step, max_step))

        # next step
        t_prev = tmp
        f_prev = f_new
        g_prev = g_new.clone(memory_format=torch.contiguous_format)
        gtd_prev = gtd_new
        f_new, g_new = obj_func(x, t, d)
        ls_func_evals += 1
        gtd_new = g_new.dot(d)
        ls_iter += 1
        

    # reached max number of iterations?
    if ls_iter == max_ls:
        bracket = [0, t]
        bracket_f = [f, f_new]
        bracket_g = [g, g_new]

    # zoom phase: we now have a point satisfying the criteria, or
    # a bracket around it. We refine the bracket until we find the
    # exact point satisfying the criteria
    insuf_progress = False
    # find high and low points in bracket
    low_pos, high_pos = (0, 1) if bracket_f[0] <= bracket_f[-1] else (1, 0)
    while not done and ls_iter < max_ls:
        # line-search bracket is so small
        if abs(bracket[1] - bracket[0]) * d_norm < tolerance_change:
            break

        # compute new trial value
        t = _cubic_interpolate(bracket[0], bracket_f[0], bracket_gtd[0],
                               bracket[1], bracket_f[1], bracket_gtd[1])

        # test that we are making sufficient progress:
        # in case `t` is so close to boundary, we mark that we are making
        # insufficient progress, and if
        #   + we have made insufficient progress in the last step, or
        #   + `t` is at one of the boundary,
        # we will move `t` to a position which is `0.1 * len(bracket)`
        # away from the nearest boundary point.
        eps = 0.1 * (max(bracket) - min(bracket))
        if min(max(bracket) - t, t - min(bracket)) < eps:
            # interpolation close to boundary
            if insuf_progress or t >= max(bracket) or t <= min(bracket):
                # evaluate at 0.1 away from boundary
                if abs(t - max(bracket)) < abs(t - min(bracket)):
                    t = max(bracket) - eps
                else:
                    t = min(bracket) + eps
                insuf_progress = False
            else:
                insuf_progress = True
        else:
            insuf_progress = False

        # Evaluate new point
        f_new, g_new = obj_func(x, t, d)
        ls_func_evals += 1
        gtd_new = g_new.dot(d)
        ls_iter += 1

        if f_new > (f + c1 * t * gtd) or f_new >= bracket_f[low_pos]:
            # Armijo condition not satisfied or not lower than lowest point
            bracket[high_pos] = t
            bracket_f[high_pos] = f_new
            bracket_g[high_pos] = g_new.clone(memory_format=torch.contiguous_format)
            bracket_gtd[high_pos] = gtd_new
            low_pos, high_pos = (0, 1) if bracket_f[0] <= bracket_f[1] else (1, 0)
        else:
            if abs(gtd_new) <= -c2 * gtd:
                # Wolfe conditions satisfied
                done = True
            elif gtd_new * (bracket[high_pos] - bracket[low_pos]) >= 0:
                # old low becomes new high
                bracket[high_pos] = bracket[low_pos]
                bracket_f[high_pos] = bracket_f[low_pos]
                bracket_g[high_pos] = bracket_g[low_pos]
                bracket_gtd[high_pos] = bracket_gtd[low_pos]

            # new point becomes new low
            bracket[low_pos] = t
            bracket_f[low_pos] = f_new
            bracket_g[low_pos] = g_new.clone(memory_format=torch.contiguous_format)
            bracket_gtd[low_pos] = gtd_new

    #print(bracket)
    if len(bracket) == 1:
        t = bracket[0]
        f_new = bracket_f[0]
        g_new = bracket_g[0]
    else:
        t = bracket[low_pos]
        f_new = bracket_f[low_pos]
        g_new = bracket_g[low_pos]
    return f_new, g_new, t, ls_func_evals



class LBFGS(Optimizer):
    """Implements L-BFGS algorithm.

    Heavily inspired by `minFunc
    <https://www.cs.ubc.ca/~schmidtm/Software/minFunc.html>`_.

    .. warning::
        This optimizer doesn't support per-parameter options and parameter
        groups (there can be only one).

    .. warning::
        Right now all parameters have to be on a single device. This will be
        improved in the future.

    .. note::
        This is a very memory intensive optimizer (it requires additional
        ``param_bytes * (history_size + 1)`` bytes). If it doesn't fit in memory
        try reducing the history size, or use a different algorithm.

    Args:
        lr (float): learning rate (default: 1)
        max_iter (int): maximal number of iterations per optimization step
            (default: 20)
        max_eval (int): maximal number of function evaluations per optimization
            step (default: max_iter * 1.25).
        tolerance_grad (float): termination tolerance on first order optimality
            (default: 1e-7).
        tolerance_change (float): termination tolerance on function
            value/parameter changes (default: 1e-9).
        history_size (int): update history size (default: 100).
        line_search_fn (str): either 'strong_wolfe' or None (default: None).
    """

    def __init__(self,
                 params,
                 lr=1,
                 max_iter=20,
                 max_eval=None,
                 tolerance_grad=1e-7,
                 tolerance_change=1e-9,
                 tolerance_ys=1e-32,
                 history_size=100,
                 line_search_fn=None):
        if max_eval is None:
            max_eval = max_iter * 5 // 4
        defaults = dict(
            lr=lr,
            max_iter=max_iter,
            max_eval=max_eval,
            tolerance_grad=tolerance_grad,
            tolerance_change=tolerance_change,
            tolerance_ys=tolerance_ys,
            history_size=history_size,
            line_search_fn=line_search_fn)
        super().__init__(params, defaults)

        if len(self.param_groups) != 1:
            raise ValueError("LBFGS doesn't support per-parameter options "
                             "(parameter groups)")

        self._params = self.param_groups[0]['params']
        self._numel_cache = None

    def _numel(self):
        if self._numel_cache is None:
            self._numel_cache = reduce(lambda total, p: total + p.numel(), self._params, 0)
        return self._numel_cache

    def _gather_flat_grad(self):
        views = []
        for p in self._params:
            if p.grad is None:
                view = p.new(p.numel()).zero_()
            elif p.grad.is_sparse:
                view = p.grad.to_dense().view(-1)
            else:
                view = p.grad.view(-1)
            views.append(view)
        device = views[0].device
        return torch.cat(views, dim=0)

    def _add_grad(self, step_size, update):
        offset = 0
        for p in self._params:
            numel = p.numel()
            # view as to avoid deprecated pointwise semantics
            p.add_(update[offset:offset + numel].view_as(p), alpha=step_size)
            offset += numel
        assert offset == self._numel()

    def _clone_param(self):
        return [p.clone(memory_format=torch.contiguous_format) for p in self._params]

    def _set_param(self, params_data):
        for p, pdata in zip(self._params, params_data):
            p.copy_(pdata)

    def _directional_evaluate(self, closure, x, t, d):
        self._add_grad(t, d)
        loss = float(closure())
        flat_grad = self._gather_flat_grad()
        self._set_param(x)
        return loss, flat_grad


    @torch.no_grad()
    def step(self, closure):
        """Perform a single optimization step.

        Args:
            closure (Callable): A closure that reevaluates the model
                and returns the loss.
        """
        
        torch.manual_seed(0)
        
        assert len(self.param_groups) == 1

        # Make sure the closure is always called with grad enabled
        closure = torch.enable_grad()(closure)

        group = self.param_groups[0]
        lr = group['lr']
        max_iter = group['max_iter']
        max_eval = group['max_eval']
        tolerance_grad = group['tolerance_grad']
        tolerance_change = group['tolerance_change']
        tolerance_ys = group['tolerance_ys']
        line_search_fn = group['line_search_fn']
        history_size = group['history_size']

        # NOTE: LBFGS has only global state, but we register it as state for
        # the first param, because this helps with casting in load_state_dict
        state = self.state[self._params[0]]
        state.setdefault('func_evals', 0)
        state.setdefault('n_iter', 0)

        # evaluate initial f(x) and df/dx
        orig_loss = closure()
        loss = float(orig_loss)
        current_evals = 1
        state['func_evals'] += 1

        flat_grad = self._gather_flat_grad()
        opt_cond = flat_grad.abs().max() <= tolerance_grad

        # optimal condition
        if opt_cond:
            return orig_loss

        # tensors cached in state (for tracing)
        d = state.get('d')
        t = state.get('t')
        old_dirs = state.get('old_dirs')
        old_stps = state.get('old_stps')
        ro = state.get('ro')
        H_diag = state.get('H_diag')
        prev_flat_grad = state.get('prev_flat_grad')
        prev_loss = state.get('prev_loss')

        n_iter = 0
        # optimize for a max of max_iter iterations
        while n_iter < max_iter:
            # keep track of nb of iterations
            n_iter += 1
            state['n_iter'] += 1

            ############################################################
            # compute gradient descent direction
            ############################################################
            if state['n_iter'] == 1:
                d = flat_grad.neg()
                old_dirs = []
                old_stps = []
                ro = []
                H_diag = 1
            else:
                # do lbfgs update (update memory)
                y = flat_grad.sub(prev_flat_grad)
                s = d.mul(t)
                ys = y.dot(s)  # y*s
                if ys > tolerance_ys:
                    # updating memory
                    if len(old_dirs) == history_size:
                        # shift history by one (limited-memory)
                        old_dirs.pop(0)
                        old_stps.pop(0)
                        ro.pop(0)

                    # store new direction/step
                    old_dirs.append(y)
                    old_stps.append(s)
                    ro.append(1. / ys)

                    # update scale of initial Hessian approximation
                    H_diag = ys / y.dot(y)  # (y*y)

                # compute the approximate (L-BFGS) inverse Hessian
                # multiplied by the gradient
                num_old = len(old_dirs)

                if 'al' not in state:
                    state['al'] = [None] * history_size
                al = state['al']

                # iteration in L-BFGS loop collapsed to use just one buffer
                q = flat_grad.neg()
                for i in range(num_old - 1, -1, -1):
                    al[i] = old_stps[i].dot(q) * ro[i]
                    q.add_(old_dirs[i], alpha=-al[i])

                # multiply by initial Hessian
                # r/d is the final direction
                d = r = torch.mul(q, H_diag)
                for i in range(num_old):
                    be_i = old_dirs[i].dot(r) * ro[i]
                    r.add_(old_stps[i], alpha=al[i] - be_i)

            if prev_flat_grad is None:
                prev_flat_grad = flat_grad.clone(memory_format=torch.contiguous_format)
            else:
                prev_flat_grad.copy_(flat_grad)
            prev_loss = loss

            ############################################################
            # compute step length
            ############################################################
            # reset initial guess for step size
            if state['n_iter'] == 1:
                t = min(1., 1. / flat_grad.abs().sum()) * lr
            else:
                t = lr

            # directional derivative
            gtd = flat_grad.dot(d)  # g * d

            # directional derivative is below tolerance
            if gtd > -tolerance_change:
                break

            # optional line search: user function
            ls_func_evals = 0
            if line_search_fn is not None:
                # perform line search, using user function
                if line_search_fn != "strong_wolfe":
                    raise RuntimeError("only 'strong_wolfe' is supported")
                else:
                    x_init = self._clone_param()

                    def obj_func(x, t, d):
                        return self._directional_evaluate(closure, x, t, d)
                    loss, flat_grad, t, ls_func_evals = _strong_wolfe(
                        obj_func, x_init, t, d, loss, flat_grad, gtd)
                self._add_grad(t, d)
                opt_cond = flat_grad.abs().max() <= tolerance_grad
            else:
                # no line search, simply move with fixed-step
                self._add_grad(t, d)
                if n_iter != max_iter:
                    # re-evaluate function only if not in last iteration
                    # the reason we do this: in a stochastic setting,
                    # no use to re-evaluate that function here
                    with torch.enable_grad():
                        loss = float(closure())
                    flat_grad = self._gather_flat_grad()
                    opt_cond = flat_grad.abs().max() <= tolerance_grad
                    ls_func_evals = 1

            # update func eval
            current_evals += ls_func_evals
            state['func_evals'] += ls_func_evals

            ############################################################
            # check conditions
            ############################################################
            if n_iter == max_iter:
                break

            if current_evals >= max_eval:
                break

            # optimal condition
            if opt_cond:
                break

            # lack of progress
            if d.mul(t).abs().max() <= tolerance_change:
                break

            if abs(loss - prev_loss) < tolerance_change:
                break

        state['d'] = d
        state['t'] = t
        state['old_dirs'] = old_dirs
        state['old_stps'] = old_stps
        state['ro'] = ro
        state['H_diag'] = H_diag
        state['prev_flat_grad'] = prev_flat_grad
        state['prev_loss'] = prev_loss

        return orig_loss


四、总结与思考

        KAN神经网络通过融合数学定理与深度学习,为科学计算和可解释AI提供了新思路。尽管在高维应用中仍需突破,但其在低维复杂函数建模上的潜力值得关注。未来可能通过改进计算效率、扩展理论边界,成为MLP的重要补充。

        1. KAN网络架构

  • 关键设计可学习的激活函数:每个网络连接的“权重”被替换为单变量函数(如样条、多项式),而非固定激活函数(如ReLU)。分层结构:输入层和隐藏层之间、隐藏层与输出层之间均通过单变量函数连接,形成多层叠加。参数效率:由于理论保证,KAN可能用更少的参数达到与MLP相当或更好的逼近效果。

  • 示例结构输入层 → 隐藏层:每个输入节点通过单变量函数\phi_{q,i} \left( x_{i} \right) 连接到隐藏节点。隐藏层 → 输出层:隐藏节点通过另一组单变量函数\psi_{q}组合得到输出。

        2. 优势与特点

  • 高逼近效率:基于数学定理,理论上能以更少参数逼近复杂函数;在低维科学计算任务(如微分方程求解)中表现优异。

  • 可解释性:单变量函数可可视化,便于分析输入变量与输出的关系;网络结构直接对应函数分解过程,逻辑清晰。

  • 灵活的函数学习:激活函数可自适应调整(如学习平滑或非平滑函数);支持符号公式提取(例如从数据中恢复物理定律)。

        3. 挑战与局限

  • 计算复杂度:单变量函数的学习(如样条参数化)可能增加训练时间和内存消耗。需要优化高阶连续函数,对硬件和算法提出更高要求。

  • 泛化能力:在高维数据(如图像、文本)中的表现尚未充分验证,可能逊色于传统MLP。

  • 训练难度:需设计新的优化策略,避免单变量函数的过拟合或欠拟合。

        4. 应用场景

  • 科学计算:求解微分方程、物理建模、化学模拟等需要高精度函数逼近的任务。

  • 可解释性需求领域:医疗诊断、金融风控等需明确输入输出关系的场景。

  • 符号回归:从数据中自动发现数学公式(如物理定律)。

        5. 与传统MLP的对比

        6. 研究进展

  • 近期论文:2024年,MIT等团队提出KAN架构(如论文《KAN: Kolmogorov-Arnold Networks》),在低维任务中验证了其高效性和可解释性。

  • 开源实现:已有PyTorch等框架的初步实现。


【作者声明】

        本文分享的论文内容及观点均来源于《KAN: Kolmogorov-Arnold Networks》原文,旨在介绍和探讨该研究的创新成果和应用价值。作者尊重并遵循学术规范,确保内容的准确性和客观性。如有任何疑问或需要进一步的信息,请参考论文原文或联系相关作者。


 【关注我们】

        如果您对神经网络、群智能算法及人工智能技术感兴趣,请关注【灵犀拾荒者】,获取更多前沿技术文章、实战案例及技术分享!

猜你喜欢

转载自blog.csdn.net/2303_77200324/article/details/146424685
今日推荐