【KAN】KAN神经网络学习训练营(4)——feynman.py

一、引言

        KAN神经网络(Kolmogorov–Arnold Networks)是一种基于Kolmogorov-Arnold表示定理的新型神经网络架构。该定理指出,任何多元连续函数都可以表示为有限个单变量函数的组合。与传统多层感知机(MLP)不同,KAN通过可学习的激活函数和结构化网络设计,在函数逼近效率和可解释性上展现出潜力。


二、技术与原理简介

        1.Kolmogorov-Arnold 表示定理

         Kolmogorov-Arnold 表示定理指出,如果 是有界域上的多元连续函数,那么它可以写为单个变量的连续函数的有限组合,以及加法的二进制运算。更具体地说,对于 光滑ff:[0,1]^{^{n}}\rightarrow \mathbb{R}

f \left( x \right)=f \left( x_{1}, \cdots,x_{n} \right)= \sum_{q=1}^{2n+1} \Phi_{q} \left( \sum_{p=1}^{n} \phi_{q,p} \left( x_{p} \right) \right)

        其中 和 。从某种意义上说,他们表明唯一真正的多元函数是加法,因为所有其他函数都可以使用单变量函数和 sum 来编写。然而,这个 2 层宽度 - Kolmogorov-Arnold 表示可能不是平滑的由于其表达能力有限。我们通过以下方式增强它的表达能力将其推广到任意深度和宽度。\boldsymbol{\phi_{q,p}:[0,1]\to\mathbb{R}}\boldsymbol{\Phi_{q}:\mathbb{R}\to\mathbb{R}(2n+1)}

        2.Kolmogorov-Arnold 网络 (KAN)

        Kolmogorov-Arnold 表示可以写成矩阵形式

f(x)=\mathbf{\Phi_{out}}\mathsf{o}\mathbf{\Phi_{in}}\mathsf{o}{}x

其中

\mathbf{\Phi}_{\mathrm{in}}=\begin{pmatrix}\phi_{1,1}(\cdot)&\cdots&\phi_{1,n }(\cdot)\\ \vdots&&\vdots\\ \phi_{2n+1,1}(\cdot)&\cdots&\phi_{2n+1,n}(\cdot)\end{pmatrix}

\quad\mathbf{ \Phi}_{\mathrm{out}}=\left(\Phi_{1}(\cdot)\quad\cdots\quad\Phi_{2n+1}(\cdot)\right)

        我们注意到 和 都是以下函数矩阵(包含输入和输出)的特例,我们称之为 Kolmogorov-Arnold 层:\mathbf{\Phi_{in}} \mathbf{\Phi_{out}} \mathbf{\Phi_{n_{in}n_{out}}}

其中\boldsymbol{n_{\text{in}}=n,n_{\text{out}}=2n+1\Phi_{\text{out}}n_{\text{in}}=2n+1,n_{\text{out}}=1}

        定义层后,我们可以构造一个 Kolmogorov-Arnold 网络只需堆叠层!假设我们有层,层的形状为 。那么整个网络是Ll^{th} \Phi_{l} \left( n_{l+1},n_{l} \right)

\mathbf{KAN(x)}=\mathbf{\Phi_{L-1}}\circ\cdots\circ\mathbf{\Phi_{1}}\circ \mathbf{\Phi_{0}}\circ\mathbf{x}

        相反,多层感知器由线性层和非线错:\mathbf{W}_{l^{\sigma}}

\text{MLP}(\mathbf{x})=\mathbf{W}_{\textit{L-1}}\circ\sigma\circ\cdots\circ \mathbf{W}_{1}\circ\sigma\circ\mathbf{W}_{0}\circ\mathbf{x}

        KAN 可以很容易地可视化。(1) KAN 只是 KAN 层的堆栈。(2) 每个 KAN 层都可以可视化为一个全连接层,每个边缘上都有一个1D 函数。


三、代码详解

        该代码用于生成不同费曼物理方程对应的符号表达式、PyTorch计算函数及变量取值范围,主要用于符号回归(Symbolic Regression)任务的基准测试。

        A.代码详解

        1. 全局引入与基础设置

from sympy import *
import torch

def get_feynman_dataset(name):
    global symbols  # 声明使用全局符号变量(需在外部定义)
    tpi = torch.tensor(torch.pi)  # 预计算π的Tensor常量
  • 功能:引入SymPy符号计算库和PyTorch张量库,定义全局符号变量并预存π值。
  • 说明tpi用于后续所有需要π的PyTorch计算,避免重复转换。

        2. 测试用例

    if name == 'test':
        symbol = x, y = symbols('x, y')
        expr = (x+y) * sin(exp(2*y))
        f = lambda x: (x[:,[0]] + x[:,[1]])*torch.sin(torch.exp(2*x[:,[1]]))
        ranges = [-1,1]
  • 公式:测试公式 (x+y) * sin(e^{2y})
  • 符号变量x, y
  • PyTorch实现:张量切片取各变量,计算表达式。
  • 变量范围:所有变量在[-1, 1]区间均匀采样。

        3. 高斯分布相关公式

        I.6.20a(单变量正态分布)
    if name == 'I.6.20a' or name == 1:
        symbol = theta = symbols('theta')
        symbol = [symbol]
        expr = exp(-theta**2/2)/sqrt(2*pi)
        f = lambda x: torch.exp(-x[:,[0]]**2/2)/torch.sqrt(2*tpi)
        ranges = [[-3,3]]
  • 公式:标准正态分布概率密度函数 P(θ) = e^{-θ²/2}/√(2π)
  • 变量θ(范围[-3, 3])。
        I.6.20(带标准差的正态分布)
    if name == 'I.6.20' or name == 2:
        symbol = theta, sigma = symbols('theta sigma')
        expr = exp(-theta**2/(2*sigma**2))/sqrt(2*pi*sigma**2)
        f = lambda x: torch.exp(-x[:,[0]]**2/(2*x[:,[1]]**2))/torch.sqrt(2*tpi*x[:,[1]]**2)
        ranges = [[-1,1],[0.5,2]]
  • 公式:一般正态分布 P(θ,σ) = e^{-θ²/(2σ²)}/√(2πσ²)
  • 变量θ([-1,1]),σ([0.5,2],避免接近零导致除零错误)。
        I.6.20b(带均值与标准差的正态分布)
    if name == 'I.6.20b' or name == 3:
        symbol = theta, theta1, sigma = symbols('theta theta1 sigma')
        expr = exp(-(theta-theta1)**2/(2*sigma**2))/sqrt(2*pi*sigma^2)
        f = lambda x: torch.exp(-(x[:,[0]]-x[:,[1]])**2/(2*x[:,[2]]**2))/torch.sqrt(2*tpi*x[:,[2]]**2)
        ranges = [[-1.5,1.5],[-1.5,1.5],[0.5,2]]
  • 公式:含均值θ₁的正态分布 P(θ,θ₁,σ) = e^{-(θ-θ₁)²/(2σ²)}/√(2πσ²)
  • 变量θθ₁([-1.5,1.5]),σ([0.5,2])。

        4. 几何与力学公式

        I.8.4(两点间距离)
    if name == 'I.8.4' or name == 4:
        symbol = x1, x2, y1, y2 = symbols('x1 x2 y1 y2')
        expr = sqrt((x2-x1)**2+(y2-y1)**2)
        f = lambda x: torch.sqrt((x[:,[1]]-x[:,[0]])**2+(x[:,[3]]-x[:,[2]])**2)
        ranges = [[-1,1],[-1,1],[-1,1],[-1,1]]
  • 公式:二维空间两点距离公式 d = √[(x2-x1)² + (y2-y1)²]
  • 变量:四点坐标均在[-1,1]区间。
        I.9.18(万有引力公式)
    if name == 'I.9.18' or name == 5:
        symbol = G, m1, m2, x1, x2, y1, y2, z1, z2 = symbols('G m1 m2 x1 x2 y1 y2 z1 z2')
        expr = G*m1*m2/((x2-x1)**2+(y2-y1)^2+(z2-z1)^2)
        f = lambda x: x[:,[0]]*x[:,[1]]*x[:,[2]]/((x[:,[3]]-x[:,[4]])**2+(x[:,[5]]-x[:,[6]])**2+(x[:,[7]]-x[:,[8]])**2)
        ranges = [[-1,1],[-1,1],[-1,1],[-1,-0.5],[0.5,1],[-1,-0.5],[0.5,1],[-1,-0.5],[0.5,1]]
  • 公式:三维万有引力公式 F = Gm₁m₂/[(Δx)² + (Δy)² + (Δz)²]
  • 变量Gm₁m₂([-1,1]),坐标范围错开以避免分母过小。

        5. 相对论与电磁学

        I.10.7(相对论质量公式)
    if name == 'I.10.7' or name == 6:
        symbol = m0, v, c = symbols('m0 v c')
        expr = m0/sqrt(1-v**2/c**2)
        f = lambda x: x[:,[0]]/torch.sqrt(1-x[:,[1]]**2/x[:,[2]]^2)
        ranges = [[0,1],[0,1],[1,2]]
  • 公式:相对论质量公式 m = m₀/√(1 - v²/c²)
  • 变量m₀(静止质量,[0,1]),v([0,1]),c([1,2]确保v < c)。
        I.12.2(库仑定律)
    if name == 'I.12.2' or name == 9:
        symbol = q1, q2, eps, r = symbols('q1 q2 epsilon r')
        expr = q1*q2/(4*pi*eps*r**2)
        f = lambda x: x[:,[0]]*x[:,[1]]/(4*tpi*x[:,[2]]*x[:,[3]]**2)
        ranges = [[-1,1],[-1,1],[0.5,2],[0.5,2]]
  • 公式:库仑力公式 F = q₁q₂/(4πεr²)
  • 变量q₁q₂(可正负),εr避免零值。

        B.完整代码

from sympy import *
import torch


def get_feynman_dataset(name):
    
    global symbols
    
    tpi = torch.tensor(torch.pi)
    
    if name == 'test':
        symbol = x, y = symbols('x, y')
        expr = (x+y) * sin(exp(2*y))
        f = lambda x: (x[:,[0]] + x[:,[1]])*torch.sin(torch.exp(2*x[:,[1]]))
        ranges = [-1,1]
    
    if name == 'I.6.20a' or name == 1:
        symbol = theta = symbols('theta')
        symbol = [symbol]
        expr = exp(-theta**2/2)/sqrt(2*pi)
        f = lambda x: torch.exp(-x[:,[0]]**2/2)/torch.sqrt(2*tpi)
        ranges = [[-3,3]]
    
    if name == 'I.6.20' or name == 2:
        symbol = theta, sigma = symbols('theta sigma')
        expr = exp(-theta**2/(2*sigma**2))/sqrt(2*pi*sigma**2)
        f = lambda x: torch.exp(-x[:,[0]]**2/(2*x[:,[1]]**2))/torch.sqrt(2*tpi*x[:,[1]]**2)
        ranges = [[-1,1],[0.5,2]]
        
    if name == 'I.6.20b' or name == 3:
        symbol = theta, theta1, sigma = symbols('theta theta1 sigma')
        expr = exp(-(theta-theta1)**2/(2*sigma**2))/sqrt(2*pi*sigma**2)
        f = lambda x: torch.exp(-(x[:,[0]]-x[:,[1]])**2/(2*x[:,[2]]**2))/torch.sqrt(2*tpi*x[:,[2]]**2)
        ranges = [[-1.5,1.5],[-1.5,1.5],[0.5,2]]
    
    if name == 'I.8.4' or name == 4:
        symbol = x1, x2, y1, y2 = symbols('x1 x2 y1 y2')
        expr = sqrt((x2-x1)**2+(y2-y1)**2)
        f = lambda x: torch.sqrt((x[:,[1]]-x[:,[0]])**2+(x[:,[3]]-x[:,[2]])**2)
        ranges = [[-1,1],[-1,1],[-1,1],[-1,1]]
        
    if name == 'I.9.18' or name == 5:
        symbol = G, m1, m2, x1, x2, y1, y2, z1, z2 = symbols('G m1 m2 x1 x2 y1 y2 z1 z2')
        expr = G*m1*m2/((x2-x1)**2+(y2-y1)**2+(z2-z1)**2)
        f = lambda x: x[:,[0]]*x[:,[1]]*x[:,[2]]/((x[:,[3]]-x[:,[4]])**2+(x[:,[5]]-x[:,[6]])**2+(x[:,[7]]-x[:,[8]])**2)
        ranges = [[-1,1],[-1,1],[-1,1],[-1,-0.5],[0.5,1],[-1,-0.5],[0.5,1],[-1,-0.5],[0.5,1]]
        
    if name == 'I.10.7' or name == 6:
        symbol = m0, v, c = symbols('m0 v c')
        expr = m0/sqrt(1-v**2/c**2)
        f = lambda x: x[:,[0]]/torch.sqrt(1-x[:,[1]]**2/x[:,[2]]**2)
        ranges = [[0,1],[0,1],[1,2]]
        
    if name == 'I.11.19' or name == 7:
        symbol = x1, y1, x2, y2, x3, y3 = symbols('x1 y1 x2 y2 x3 y3')
        expr = x1*y1 + x2*y2 + x3*y3
        f = lambda x: x[:,[0]]*x[:,[1]] + x[:,[2]]*x[:,[3]] + x[:,[4]]*x[:,[5]]
        ranges = [-1,1]
    
    if name == 'I.12.1' or name == 8:
        symbol = mu, Nn = symbols('mu N_n')
        expr = mu * Nn
        f = lambda x: x[:,[0]]*x[:,[1]]
        ranges = [-1,1]
        
    if name == 'I.12.2' or name == 9:
        symbol = q1, q2, eps, r = symbols('q1 q2 epsilon r')
        expr = q1*q2/(4*pi*eps*r**2)
        f = lambda x: x[:,[0]]*x[:,[1]]/(4*tpi*x[:,[2]]*x[:,[3]]**2)
        ranges = [[-1,1],[-1,1],[0.5,2],[0.5,2]]
        
    if name == 'I.12.4' or name == 10:
        symbol = q1, eps, r = symbols('q1 epsilon r')
        expr = q1/(4*pi*eps*r**2)
        f = lambda x: x[:,[0]]/(4*tpi*x[:,[1]]*x[:,[2]]**2)
        ranges = [[-1,1],[0.5,2],[0.5,2]]
        
    if name == 'I.12.5' or name == 11:
        symbol = q2, Ef = symbols('q2, E_f')
        expr = q2*Ef
        f = lambda x: x[:,[0]]*x[:,[1]]
        ranges = [-1,1]
        
    if name == 'I.12.11' or name == 12:
        symbol = q, Ef, B, v, theta = symbols('q E_f B v theta')
        expr = q*(Ef + B*v*sin(theta))
        f = lambda x: x[:,[0]]*(x[:,[1]]+x[:,[2]]*x[:,[3]]*torch.sin(x[:,[4]]))
        ranges = [[-1,1],[-1,1],[-1,1],[-1,1],[0,2*tpi]]
        
    if name == 'I.13.4' or name == 13:
        symbol = m, v, u, w = symbols('m u v w')
        expr = 1/2*m*(v**2+u**2+w**2)
        f = lambda x: 1/2*x[:,[0]]*(x[:,[1]]**2+x[:,[2]]**2+x[:,[3]]**2)
        ranges = [[-1,1],[-1,1],[-1,1],[-1,1]]
        
    if name == 'I.13.12' or name == 14:
        symbol = G, m1, m2, r1, r2 = symbols('G m1 m2 r1 r2')
        expr = G*m1*m2*(1/r2-1/r1)
        f = lambda x: x[:,[0]]*x[:,[1]]*x[:,[2]]*(1/x[:,[4]]-1/x[:,[3]])
        ranges = [[0,1],[0,1],[0,1],[0.5,2],[0.5,2]]
        
    if name == 'I.14.3' or name == 15:
        symbol = m, g, z = symbols('m g z')
        expr = m*g*z
        f = lambda x: x[:,[0]]*x[:,[1]]*x[:,[2]]
        ranges = [[0,1],[0,1],[-1,1]]
        
    if name == 'I.14.4' or name == 16:
        symbol = ks, x = symbols('k_s x')
        expr = 1/2*ks*x**2
        f = lambda x: 1/2*x[:,[0]]*x[:,[1]]**2
        ranges = [[0,1],[-1,1]]
        
    if name == 'I.15.3x' or name == 17:
        symbol = x, u, t, c = symbols('x u t c')
        expr = (x-u*t)/sqrt(1-u**2/c**2)
        f = lambda x: (x[:,[0]] - x[:,[1]]*x[:,[2]])/torch.sqrt(1-x[:,[1]]**2/x[:,[3]]**2)
        ranges = [[-1,1],[-1,1],[-1,1],[1,2]]
        
    if name == 'I.15.3t' or name == 18:
        symbol = t, u, x, c = symbols('t u x c')
        expr = (t-u*x/c**2)/sqrt(1-u**2/c**2)
        f = lambda x: (x[:,[0]] - x[:,[1]]*x[:,[2]]/x[:,[3]]**2)/torch.sqrt(1-x[:,[1]]**2/x[:,[3]]**2)
        ranges = [[-1,1],[-1,1],[-1,1],[1,2]]
        
    if name == 'I.15.10' or name == 19:
        symbol = m0, v, c = symbols('m0 v c')
        expr = m0*v/sqrt(1-v**2/c**2)
        f = lambda x: x[:,[0]]*x[:,[1]]/torch.sqrt(1-x[:,[1]]**2/x[:,[2]]**2)
        ranges = [[-1,1],[-0.9,0.9],[1.1,2]]
        
    if name == 'I.16.6' or name == 20:
        symbol = u, v, c = symbols('u v c')
        expr = (u+v)/(1+u*v/c**2)
        f = lambda x: x[:,[0]]*x[:,[1]]/(1+x[:,[0]]*x[:,[1]]/x[:,[2]]**2)
        ranges = [[-0.8,0.8],[-0.8,0.8],[1,2]]
        
    if name == 'I.18.4' or name == 21:
        symbol = m1, r1, m2, r2 = symbols('m1 r1 m2 r2')
        expr = (m1*r1+m2*r2)/(m1+m2)
        f = lambda x: (x[:,[0]]*x[:,[1]]+x[:,[2]]*x[:,[3]])/(x[:,[0]]+x[:,[2]])
        ranges = [[0.5,1],[-1,1],[0.5,1],[-1,1]]
        
    if name == 'I.18.4' or name == 22:
        symbol = r, F, theta = symbols('r F theta')
        expr = r*F*sin(theta)
        f = lambda x: x[:,[0]]*x[:,[1]]*torch.sin(x[:,[2]])
        ranges = [[-1,1],[-1,1],[0,2*tpi]]
        
    if name == 'I.18.16' or name == 23:
        symbol = m, r, v, theta = symbols('m r v theta')
        expr = m*r*v*sin(theta)
        f = lambda x: x[:,[0]]*x[:,[1]]*x[:,[2]]*torch.sin(x[:,[3]])
        ranges = [[-1,1],[-1,1],[-1,1],[0,2*tpi]]
        
    if name == 'I.24.6' or name == 24:
        symbol = m, omega, omega0, x = symbols('m omega omega_0 x')
        expr = 1/4*m*(omega**2+omega0**2)*x**2
        f = lambda x: 1/4*x[:,[0]]*(x[:,[1]]**2+x[:,[2]]**2)*x[:,[3]]**2
        ranges = [[0,1],[-1,1],[-1,1],[-1,1]]
        
    if name == 'I.25.13' or name == 25:
        symbol = q, C = symbols('q C')
        expr = q/C
        f = lambda x: x[:,[0]]/x[:,[1]]
        ranges = [[-1,1],[0.5,2]]
        
    if name == 'I.26.2' or name == 26:
        symbol = n, theta2 = symbols('n theta2')
        expr = asin(n*sin(theta2))
        f = lambda x: torch.arcsin(x[:,[0]]*torch.sin(x[:,[1]]))
        ranges = [[0,0.99],[0,2*tpi]]
        
    if name == 'I.27.6' or name == 27:
        symbol = d1, d2, n = symbols('d1 d2 n')
        expr = 1/(1/d1+n/d2)
        f = lambda x: 1/(1/x[:,[0]]+x[:,[2]]/x[:,[1]])
        ranges = [[0.5,2],[1,2],[0.5,2]]
    
    if name == 'I.29.4' or name == 28:
        symbol = omega, c = symbols('omega c')
        expr = omega/c
        f = lambda x: x[:,[0]]/x[:,[1]]
        ranges = [[0,1],[0.5,2]]
        
    if name == 'I.29.16' or name == 29:
        symbol = x1, x2, theta1, theta2 = symbols('x1 x2 theta1 theta2')
        expr = sqrt(x1**2+x2**2-2*x1*x2*cos(theta1-theta2))
        f = lambda x: torch.sqrt(x[:,[0]]**2+x[:,[1]]**2-2*x[:,[0]]*x[:,[1]]*torch.cos(x[:,[2]]-x[:,[3]]))
        ranges = [[-1,1],[-1,1],[0,2*tpi],[0,2*tpi]]
        
    if name == 'I.30.3' or name == 30:
        symbol = I0, n, theta = symbols('I_0 n theta')
        expr = I0 * sin(n*theta/2)**2 / sin(theta/2) ** 2
        f = lambda x: x[:,[0]] * torch.sin(x[:,[1]]*x[:,[2]]/2)**2 / torch.sin(x[:,[2]]/2)**2
        ranges = [[0,1],[0,4],[0.4*tpi,1.6*tpi]]
        
    if name == 'I.30.5' or name == 31:
        symbol = lamb, n, d = symbols('lambda n d')
        expr = asin(lamb/(n*d))
        f = lambda x: torch.arcsin(x[:,[0]]/(x[:,[1]]*x[:,[2]]))
        ranges = [[-1,1],[1,1.5],[1,1.5]]
        
    if name == 'I.32.5' or name == 32:
        symbol = q, a, eps, c = symbols('q a epsilon c')
        expr = q**2*a**2/(eps*c**3)
        f = lambda x: x[:,[0]]**2*x[:,[1]]**2/(x[:,[2]]*x[:,[3]]**3)
        ranges = [[-1,1],[-1,1],[0.5,2],[0.5,2]]
        
    if name == 'I.32.17' or name == 33:
        symbol = eps, c, Ef, r, omega, omega0 = symbols('epsilon c E_f r omega omega_0')
        expr = nsimplify((1/2*eps*c*Ef**2)*(8*pi*r**2/3)*(omega**4/(omega**2-omega0**2)**2))
        f = lambda x: (1/2*x[:,[0]]*x[:,[1]]*x[:,[2]]**2)*(8*tpi*x[:,[3]]**2/3)*(x[:,[4]]**4/(x[:,[4]]**2-x[:,[5]]**2)**2)
        ranges = [[0,1],[0,1],[-1,1],[0,1],[0,1],[1,2]]
        
    if name == 'I.34.8' or name == 34:
        symbol = q, V, B, p = symbols('q V B p')
        expr = q*V*B/p
        f = lambda x: x[:,[0]]*x[:,[1]]*x[:,[2]]/x[:,[3]]
        ranges = [[-1,1],[-1,1],[-1,1],[0.5,2]]
        
    if name == 'I.34.10' or name == 35:
        symbol = omega0, v, c = symbols('omega_0 v c')
        expr = omega0/(1-v/c)
        f = lambda x: x[:,[0]]/(1-x[:,[1]]/x[:,[2]])
        ranges = [[0,1],[0,0.9],[1.1,2]]
        
    if name == 'I.34.14' or name == 36:
        symbol = omega0, v, c = symbols('omega_0 v c')
        expr = omega0 * (1+v/c)/sqrt(1-v**2/c**2)
        f = lambda x: x[:,[0]]*(1+x[:,[1]]/x[:,[2]])/torch.sqrt(1-x[:,[1]]**2/x[:,[2]]**2)
        ranges = [[0,1],[-0.9,0.9],[1.1,2]]
        
    if name == 'I.34.27' or name == 37:
        symbol = hbar, omega = symbols('hbar omega')
        expr = hbar * omega
        f = lambda x: x[:,[0]]*x[:,[1]]
        ranges = [[-1,1],[-1,1]]
        
    if name == 'I.37.4' or name == 38:
        symbol = I1, I2, delta = symbols('I_1 I_2 delta')
        expr = I1 + I2 + 2*sqrt(I1*I2)*cos(delta)
        f = lambda x: x[:,[0]] + x[:,[1]] + 2*torch.sqrt(x[:,[0]]*x[:,[1]])*torch.cos(x[:,[2]])
        ranges = [[0.1,1],[0.1,1],[0,2*tpi]]
        
    if name == 'I.38.12' or name == 39:
        symbol = eps, hbar, m, q = symbols('epsilon hbar m q')
        expr = 4*pi*eps*hbar**2/(m*q**2)
        f = lambda x: 4*tpi*x[:,[0]]*x[:,[1]]**2/(x[:,[2]]*x[:,[3]]**2)
        ranges = [[0,1],[0,1],[0.5,2],[0.5,2]]
        
    if name == 'I.39.10' or name == 40:
        symbol = pF, V = symbols('p_F V')
        expr = 3/2 * pF * V
        f = lambda x: 3/2 * x[:,[0]] * x[:,[1]]
        ranges = [[0,1],[0,1]]
        
    if name == 'I.39.11' or name == 41:
        symbol = gamma, pF, V = symbols('gamma p_F V')
        expr = pF * V/(gamma - 1)
        f = lambda x: 1/(x[:,[0]]-1) * x[:,[1]] * x[:,[2]]
        ranges = [[1.5,3],[0,1],[0,1]]
        
    if name == 'I.39.22' or name == 42:
        symbol = n, kb, T, V = symbols('n k_b T V')
        expr = n*kb*T/V
        f = lambda x: x[:,[0]]*x[:,[1]]*x[:,[2]]/x[:,[3]]
        ranges = [[0,1],[0,1],[0,1],[0.5,2]]
        
    if name == 'I.40.1' or name == 43:
        symbol = n0, m, g, x, kb, T = symbols('n_0 m g x k_b T')
        expr = n0 * exp(-m*g*x/(kb*T))
        f = lambda x: x[:,[0]] * torch.exp(-x[:,[1]]*x[:,[2]]*x[:,[3]]/(x[:,[4]]*x[:,[5]]))
        ranges = [[0,1],[-1,1],[-1,1],[-1,1],[1,2],[1,2]]
        
    if name == 'I.41.16' or name == 44:
        symbol = hbar, omega, c, kb, T = symbols('hbar omega c k_b T')
        expr = hbar * omega**3/(pi**2*c**2*(exp(hbar*omega/(kb*T))-1))
        f = lambda x: x[:,[0]]*x[:,[1]]**3/(tpi**2*x[:,[2]]**2*(torch.exp(x[:,[0]]*x[:,[1]]/(x[:,[3]]*x[:,[4]]))-1))
        ranges = [[0.5,1],[0.5,1],[0.5,2],[0.5,2],[0.5,2]]
        
    if name == 'I.43.16' or name == 45:
        symbol = mu, q, Ve, d = symbols('mu q V_e d')
        expr = mu*q*Ve/d
        f = lambda x: x[:,[0]]*x[:,[1]]*x[:,[2]]/x[:,[3]]
        ranges = [[0,1],[0,1],[0,1],[0.5,2]]
        
    if name == 'I.43.31' or name == 46:
        symbol = mu, kb, T = symbols('mu k_b T')
        expr = mu*kb*T
        f = lambda x: x[:,[0]]*x[:,[1]]*x[:,[2]]
        ranges = [[0,1],[0,1],[0,1]]
    
    if name == 'I.43.43' or name == 47:
        symbol = gamma, kb, v, A = symbols('gamma k_b v A')
        expr = kb*v/A/(gamma-1)
        f = lambda x: 1/(x[:,[0]]-1)*x[:,[1]]*x[:,[2]]/x[:,[3]]
        ranges = [[1.5,3],[0,1],[0,1],[0.5,2]]
        
    if name == 'I.44.4' or name == 48:
        symbol = n, kb, T, V1, V2 = symbols('n k_b T V_1 V_2')
        expr = n*kb*T*log(V2/V1)
        f = lambda x: x[:,[0]]*x[:,[1]]*x[:,[2]]*torch.log(x[:,[4]]/x[:,[3]])
        ranges = [[0,1],[0,1],[0,1],[0.5,2],[0.5,2]]
        
    if name == 'I.47.23' or name == 49:
        symbol = gamma, p, rho = symbols('gamma p rho')
        expr = sqrt(gamma*p/rho)
        f = lambda x: torch.sqrt(x[:,[0]]*x[:,[1]]/x[:,[2]])
        ranges = [[0.1,1],[0.1,1],[0.5,2]]
        
    if name == 'I.48.20' or name == 50:
        symbol = m, v, c = symbols('m v c')
        expr = m*c**2/sqrt(1-v**2/c**2)
        f = lambda x: x[:,[0]]*x[:,[2]]**2/torch.sqrt(1-x[:,[1]]**2/x[:,[2]]**2)
        ranges = [[0,1],[-0.9,0.9],[1.1,2]]
        
    if name == 'I.50.26' or name == 51:
        symbol = x1, alpha, omega, t = symbols('x_1 alpha omega t')
        expr = x1*(cos(omega*t)+alpha*cos(omega*t)**2)
        f = lambda x: x[:,[0]]*(torch.cos(x[:,[2]]*x[:,[3]])+x[:,[1]]*torch.cos(x[:,[2]]*x[:,[3]])**2)
        ranges = [[0,1],[0,1],[0,2*tpi],[0,1]]
        
    if name == 'II.2.42' or name == 52:
        symbol = kappa, T1, T2, A, d = symbols('kappa T_1 T_2 A d')
        expr = kappa*(T2-T1)*A/d
        f = lambda x: x[:,[0]]*(x[:,[2]]-x[:,[1]])*x[:,[3]]/x[:,[4]]
        ranges = [[0,1],[0,1],[0,1],[0,1],[0.5,2]]
        
    if name == 'II.3.24' or name == 53:
        symbol = P, r = symbols('P r')
        expr = P/(4*pi*r**2)
        f = lambda x: x[:,[0]]/(4*tpi*x[:,[1]]**2)
        ranges = [[0,1],[0.5,2]]
        
    if name == 'II.4.23' or name == 54:
        symbol = q, eps, r = symbols('q epsilon r')
        expr = q/(4*pi*eps*r)
        f = lambda x: x[:,[0]]/(4*tpi*x[:,[1]]*x[:,[2]])
        ranges = [[0,1],[0.5,2],[0.5,2]]
        
    if name == 'II.6.11' or name == 55:
        symbol = eps, pd, theta, r = symbols('epsilon p_d theta r')
        expr = 1/(4*pi*eps)*pd*cos(theta)/r**2
        f = lambda x: 1/(4*tpi*x[:,[0]])*x[:,[1]]*torch.cos(x[:,[2]])/x[:,[3]]**2
        ranges = [[0.5,2],[0,1],[0,2*tpi],[0.5,2]]
        
    if name == 'II.6.15a' or name == 56:
        symbol = eps, pd, z, x, y, r = symbols('epsilon p_d z x y r')
        expr = 3/(4*pi*eps)*pd*z/r**5*sqrt(x**2+y**2)
        f = lambda x: 3/(4*tpi*x[:,[0]])*x[:,[1]]*x[:,[2]]/x[:,[5]]**5*torch.sqrt(x[:,[3]]**2+x[:,[4]]**2)
        ranges = [[0.5,2],[0,1],[0,1],[0,1],[0,1],[0.5,2]]
    
    if name == 'II.6.15b' or name == 57:
        symbol = eps, pd, r, theta = symbols('epsilon p_d r theta')
        expr = 3/(4*pi*eps)*pd/r**3*cos(theta)*sin(theta)
        f = lambda x: 3/(4*tpi*x[:,[0]])*x[:,[1]]/x[:,[2]]**3*torch.cos(x[:,[3]])*torch.sin(x[:,[3]])
        ranges = [[0.5,2],[0,1],[0.5,2],[0,2*tpi]]
        
    if name == 'II.8.7' or name == 58:
        symbol = q, eps, d = symbols('q epsilon d')
        expr = 3/5*q**2/(4*pi*eps*d)
        f = lambda x: 3/5*x[:,[0]]**2/(4*tpi*x[:,[1]]*x[:,[2]])
        ranges = [[0,1],[0.5,2],[0.5,2]]
        
    if name == 'II.8.31' or name == 59:
        symbol = eps, Ef = symbols('epsilon E_f')
        expr = 1/2*eps*Ef**2
        f = lambda x: 1/2*x[:,[0]]*x[:,[1]]**2
        ranges = [[0,1],[0,1]]
        
    if name == 'I.10.9' or name == 60:
        symbol = sigma, eps, chi = symbols('sigma epsilon chi')
        expr = sigma/eps/(1+chi)
        f = lambda x: x[:,[0]]/x[:,[1]]/(1+x[:,[2]])
        ranges = [[0,1],[0.5,2],[0,1]]
        
    if name == 'II.11.3' or name == 61:
        symbol = q, Ef, m, omega0, omega = symbols('q E_f m omega_o omega')
        expr = q*Ef/(m*(omega0**2-omega**2))
        f = lambda x: x[:,[0]]*x[:,[1]]/(x[:,[2]]*(x[:,[3]]**2-x[:,[4]]**2))
        ranges = [[0,1],[0,1],[0.5,2],[1.5,3],[0,1]]
        
    if name == 'II.11.17' or name == 62:
        symbol = n0, pd, Ef, theta, kb, T = symbols('n_0 p_d E_f theta k_b T')
        expr = n0*(1+pd*Ef*cos(theta)/(kb*T))
        f = lambda x: x[:,[0]]*(1+x[:,[1]]*x[:,[2]]*torch.cos(x[:,[3]])/(x[:,[4]]*x[:,[5]]))
        ranges = [[0,1],[-1,1],[-1,1],[0,2*tpi],[0.5,2],[0.5,2]]
        
        
    if name == 'II.11.20' or name == 63:
        symbol = n, pd, Ef, kb, T = symbols('n p_d E_f k_b T')
        expr = n*pd**2*Ef/(3*kb*T)
        f = lambda x: x[:,[0]]*x[:,[1]]**2*x[:,[2]]/(3*x[:,[3]]*x[:,[4]])
        ranges = [[0,1],[0,1],[0,1],[0.5,2],[0.5,2]]
        
    if name == 'II.11.27' or name == 64:
        symbol = n, alpha, eps, Ef = symbols('n alpha epsilon E_f')
        expr = n*alpha/(1-n*alpha/3)*eps*Ef
        f = lambda x: x[:,[0]]*x[:,[1]]/(1-x[:,[0]]*x[:,[1]]/3)*x[:,[2]]*x[:,[3]]
        ranges = [[0,1],[0,2],[0,1],[0,1]]
        
    if name == 'II.11.28' or name == 65:
        symbol = n, alpha = symbols('n alpha')
        expr = 1 + n*alpha/(1-n*alpha/3)
        f = lambda x: 1 + x[:,[0]]*x[:,[1]]/(1-x[:,[0]]*x[:,[1]]/3)
        ranges = [[0,1],[0,2]]
        
    if name == 'II.13.17' or name == 66:
        symbol = eps, c, l, r = symbols('epsilon c l r')
        expr = 1/(4*pi*eps*c**2)*(2*l/r)
        f = lambda x: 1/(4*tpi*x[:,[0]]*x[:,[1]]**2)*(2*x[:,[2]]/x[:,[3]])
        ranges = [[0.5,2],[0.5,2],[0,1],[0.5,2]]
        
    if name == 'II.13.23' or name == 67:
        symbol = rho, v, c = symbols('rho v c')
        expr = rho/sqrt(1-v**2/c**2)
        f = lambda x: x[:,[0]]/torch.sqrt(1-x[:,[1]]**2/x[:,[2]]**2)
        ranges = [[0,1],[0,1],[1,2]]
        
    if name == 'II.13.34' or name == 68:
        symbol = rho, v, c = symbols('rho v c')
        expr = rho*v/sqrt(1-v**2/c**2)
        f = lambda x: x[:,[0]]*x[:,[1]]/torch.sqrt(1-x[:,[1]]**2/x[:,[2]]**2)
        ranges = [[0,1],[0,1],[1,2]]
        
    if name == 'II.15.4' or name == 69:
        symbol = muM, B, theta = symbols('mu_M B theta')
        expr = - muM * B * cos(theta)
        f = lambda x: - x[:,[0]]*x[:,[1]]*torch.cos(x[:,[2]])
        ranges = [[0,1],[0,1],[0,2*tpi]]
        
    if name == 'II.15.5' or name == 70:
        symbol = pd, Ef, theta = symbols('p_d E_f theta')
        expr = - pd * Ef * cos(theta)
        f = lambda x: - x[:,[0]]*x[:,[1]]*torch.cos(x[:,[2]])
        ranges = [[0,1],[0,1],[0,2*tpi]]
        
    if name == 'II.21.32' or name == 71:
        symbol = q, eps, r, v, c = symbols('q epsilon r v c')
        expr = q/(4*pi*eps*r*(1-v/c))
        f = lambda x: x[:,[0]]/(4*tpi*x[:,[1]]*x[:,[2]]*(1-x[:,[3]]/x[:,[4]]))
        ranges = [[0,1],[0.5,2],[0.5,2],[0,1],[1,2]]
        
    if name == 'II.24.17' or name == 72:
        symbol = omega, c, d = symbols('omega c d')
        expr = sqrt(omega**2/c**2-pi**2/d**2)
        f = lambda x: torch.sqrt(x[:,[0]]**2/x[:,[1]]**2-tpi**2/x[:,[2]]**2)
        ranges = [[1,1.5],[0.75,1],[1*tpi,1.5*tpi]]
        
    if name == 'II.27.16' or name == 73:
        symbol = eps, c, Ef = symbols('epsilon c E_f')
        expr = eps * c * Ef**2
        f = lambda x: x[:,[0]]*x[:,[1]]*x[:,[2]]**2
        ranges = [[0,1],[0,1],[-1,1]]
        
    if name == 'II.27.18' or name == 74:
        symbol = eps, Ef = symbols('epsilon E_f')
        expr = eps * Ef**2
        f = lambda x: x[:,[0]]*x[:,[1]]**2
        ranges = [[0,1],[-1,1]]
        
    if name == 'II.34.2a' or name == 75:
        symbol = q, v, r = symbols('q v r')
        expr = q*v/(2*pi*r)
        f = lambda x: x[:,[0]]*x[:,[1]]/(2*tpi*x[:,[2]])
        ranges = [[0,1],[0,1],[0.5,2]]
        
    if name == 'II.34.2' or name == 76:
        symbol = q, v, r = symbols('q v r')
        expr = q*v*r/2
        f = lambda x: x[:,[0]]*x[:,[1]]*x[:,[2]]/2
        ranges = [[0,1],[0,1],[0,1]]
        
    if name == 'II.34.11' or name == 77:
        symbol = g, q, B, m = symbols('g q B m')
        expr = g*q*B/(2*m)
        f = lambda x: x[:,[0]]*x[:,[1]]*x[:,[2]]/(2*x[:,[3]])
        ranges = [[0,1],[0,1],[0,1],[0.5,2]]
        
    if name == 'II.34.29a' or name == 78:
        symbol = q, h, m = symbols('q h m')
        expr = q*h/(4*pi*m)
        f = lambda x: x[:,[0]]*x[:,[1]]/(4*tpi*x[:,[2]])
        ranges = [[0,1],[0,1],[0.5,2]]
        
    if name == 'II.34.29b' or name == 79:
        symbol = g, mu, B, J, hbar = symbols('g mu B J hbar')
        expr = g*mu*B*J/hbar
        f = lambda x: x[:,[0]]*x[:,[1]]*x[:,[2]]*x[:,[3]]/x[:,[4]]
        ranges = [[0,1],[0,1],[0,1],[0,1],[0.5,2]]
        
    if name == 'II.35.18' or name == 80:
        symbol = n0, mu, B, kb, T = symbols('n0 mu B k_b T')
        expr = n0/(exp(mu*B/(kb*T))+exp(-mu*B/(kb*T)))
        f = lambda x: x[:,[0]]/(torch.exp(x[:,[1]]*x[:,[2]]/(x[:,[3]]*x[:,[4]]))+torch.exp(-x[:,[1]]*x[:,[2]]/(x[:,[3]]*x[:,[4]])))
        ranges = [[0,1],[0,1],[0,1],[0.5,2],[0.5,2]]
        
    if name == 'II.35.21' or name == 81:
        symbol = n, mu, B, kb, T = symbols('n mu B k_b T')
        expr = n*mu*tanh(mu*B/(kb*T))
        f = lambda x: x[:,[0]]*x[:,[1]]*torch.tanh(x[:,[1]]*x[:,[2]]/(x[:,[3]]*x[:,[4]]))
        ranges = [[0,1],[0,1],[0,1],[0.5,2],[0.5,2]]
        
    if name == 'II.36.38' or name == 82:
        symbol = mu, B, kb, T, alpha, M, eps, c = symbols('mu B k_b T alpha M epsilon c')
        expr = mu*B/(kb*T) + mu*alpha*M/(eps*c**2*kb*T)
        f = lambda x: x[:,[0]]*x[:,[1]]/(x[:,[2]]*x[:,[3]]) + x[:,[0]]*x[:,[4]]*x[:,[5]]/(x[:,[6]]*x[:,[7]]**2*x[:,[2]]*x[:,[3]])
        ranges = [[0,1],[0,1],[0.5,2],[0.5,2],[0,1],[0,1],[0.5,2],[0.5,2]]
        
    if name == 'II.37.1' or name == 83:
        symbol = mu, chi, B = symbols('mu chi B')
        expr = mu*(1+chi)*B
        f = lambda x: x[:,[0]]*(1+x[:,[1]])*x[:,[2]]
        ranges = [[0,1],[0,1],[0,1]]
        
    if name == 'II.38.3' or name == 84:
        symbol = Y, A, x, d = symbols('Y A x d')
        expr = Y*A*x/d
        f = lambda x: x[:,[0]]*x[:,[1]]*x[:,[2]]/x[:,[3]]
        ranges = [[0,1],[0,1],[0,1],[0.5,2]]
        
    if name == 'II.38.14' or name == 85:
        symbol = Y, sigma = symbols('Y sigma')
        expr = Y/(2*(1+sigma))
        f = lambda x: x[:,[0]]/(2*(1+x[:,[1]]))
        ranges = [[0,1],[0,1]]
        
    if name == 'III.4.32' or name == 86:
        symbol = hbar, omega, kb, T = symbols('hbar omega k_b T')
        expr = 1/(exp(hbar*omega/(kb*T))-1)
        f = lambda x: 1/(torch.exp(x[:,[0]]*x[:,[1]]/(x[:,[2]]*x[:,[3]]))-1)
        ranges = [[0.5,1],[0.5,1],[0.5,2],[0.5,2]]
        
    if name == 'III.4.33' or name == 87:
        symbol = hbar, omega, kb, T = symbols('hbar omega k_b T')
        expr = hbar*omega/(exp(hbar*omega/(kb*T))-1)
        f = lambda x: x[:,[0]]*x[:,[1]]/(torch.exp(x[:,[0]]*x[:,[1]]/(x[:,[2]]*x[:,[3]]))-1)
        ranges = [[0,1],[0,1],[0.5,2],[0.5,2]]
        
    if name == 'III.7.38' or name == 88:
        symbol = mu, B, hbar = symbols('mu B hbar')
        expr = 2*mu*B/hbar
        f = lambda x: 2*x[:,[0]]*x[:,[1]]/x[:,[2]]
        ranges = [[0,1],[0,1],[0.5,2]]
        
    if name == 'III.8.54' or name == 89:
        symbol = E, t, hbar = symbols('E t hbar')
        expr = sin(E*t/hbar)**2
        f = lambda x: torch.sin(x[:,[0]]*x[:,[1]]/x[:,[2]])**2
        ranges = [[0,2*tpi],[0,1],[0.5,2]]
        
    if name == 'III.9.52' or name == 90:
        symbol = pd, Ef, t, hbar, omega, omega0 = symbols('p_d E_f t hbar omega omega_0')
        expr = pd*Ef*t/hbar*sin((omega-omega0)*t/2)**2/((omega-omega0)*t/2)**2
        f = lambda x: x[:,[0]]*x[:,[1]]*x[:,[2]]/x[:,[3]]*torch.sin((x[:,[4]]-x[:,[5]])*x[:,[2]]/2)**2/((x[:,[4]]-x[:,[5]])*x[:,[2]]/2)**2
        ranges = [[0,1],[0,1],[0,1],[0.5,2],[0,tpi],[0,tpi]]
        
    if name == 'III.10.19' or name == 91:
        symbol = mu, Bx, By, Bz = symbols('mu B_x B_y B_z')
        expr = mu*sqrt(Bx**2+By**2+Bz**2)
        f = lambda x: x[:,[0]]*torch.sqrt(x[:,[1]]**2+x[:,[2]]**2+x[:,[3]]**2)
        ranges = [[0,1],[0,1],[0,1],[0,1]]
        
    if name == 'III.12.43' or name == 92:
        symbol = n, hbar = symbols('n hbar')
        expr = n * hbar
        f = lambda x: x[:,[0]]*x[:,[1]]
        ranges = [[0,1],[0,1]]
        
    if name == 'III.13.18' or name == 93:
        symbol = E, d, k, hbar = symbols('E d k hbar')
        expr = 2*E*d**2*k/hbar
        f = lambda x: 2*x[:,[0]]*x[:,[1]]**2*x[:,[2]]/x[:,[3]]
        ranges = [[0,1],[0,1],[0,1],[0.5,2]]
        
    if name == 'III.14.14' or name == 94:
        symbol = I0, q, Ve, kb, T = symbols('I_0 q V_e k_b T')
        expr = I0 * (exp(q*Ve/(kb*T))-1)
        f = lambda x: x[:,[0]]*(torch.exp(x[:,[1]]*x[:,[2]]/(x[:,[3]]*x[:,[4]]))-1)
        ranges = [[0,1],[0,1],[0,1],[0.5,2],[0.5,2]]
        
    if name == 'III.15.12' or name == 95:
        symbol = U, k, d = symbols('U k d')
        expr = 2*U*(1-cos(k*d))
        f = lambda x: 2*x[:,[0]]*(1-torch.cos(x[:,[1]]*x[:,[2]]))
        ranges = [[0,1],[0,2*tpi],[0,1]]
        
    if name == 'III.15.14' or name == 96:
        symbol = hbar, E, d = symbols('hbar E d')
        expr = hbar**2/(2*E*d**2)
        f = lambda x: x[:,[0]]**2/(2*x[:,[1]]*x[:,[2]]**2)
        ranges = [[0,1],[0.5,2],[0.5,2]]
        
    if name == 'III.15.27' or name == 97:
        symbol = alpha, n, d = symbols('alpha n d')
        expr = 2*pi*alpha/(n*d)
        f = lambda x: 2*tpi*x[:,[0]]/(x[:,[1]]*x[:,[2]])
        ranges = [[0,1],[0.5,2],[0.5,2]]
        
    if name == 'III.17.37' or name == 98:
        symbol = beta, alpha, theta = symbols('beta alpha theta')
        expr = beta * (1+alpha*cos(theta))
        f = lambda x: x[:,[0]]*(1+x[:,[1]]*torch.cos(x[:,[2]]))
        ranges = [[0,1],[0,1],[0,2*tpi]]
        
    if name == 'III.19.51' or name == 99:
        symbol = m, q, eps, hbar, n = symbols('m q epsilon hbar n')
        expr = - m * q**4/(2*(4*pi*eps)**2*hbar**2)*1/n**2
        f = lambda x: - x[:,[0]]*x[:,[1]]**4/(2*(4*tpi*x[:,[2]])**2*x[:,[3]]**2)*1/x[:,[4]]**2
        ranges = [[0,1],[0,1],[0.5,2],[0.5,2],[0.5,2]]
        
    if name == 'III.21.20' or name == 100:
        symbol = rho, q, A, m = symbols('rho q A m')
        expr = - rho*q*A/m
        f = lambda x: - x[:,[0]]*x[:,[1]]*x[:,[2]]/x[:,[3]]
        ranges = [[0,1],[0,1],[0,1],[0.5,2]]
        
    if name == 'Rutherforld scattering' or name == 101:
        symbol = Z1, Z2, alpha, hbar, c, E, theta = symbols('Z_1 Z_2 alpha hbar c E theta')
        expr = (Z1*Z2*alpha*hbar*c/(4*E*sin(theta/2)**2))**2
        f = lambda x: (x[:,[0]]*x[:,[1]]*x[:,[2]]*x[:,[3]]*x[:,[4]]/(4*x[:,[5]]*torch.sin(x[:,[6]]/2)**2))**2
        ranges = [[0,1],[0,1],[0,1],[0,1],[0,1],[0.5,2],[0.1*tpi,0.9*tpi]]
        
    if name == 'Friedman equation' or name == 102:
        symbol = G, rho, kf, c, af = symbols('G rho k_f c a_f')
        expr = sqrt(8*pi*G/3*rho-kf*c**2/af**2)
        f = lambda x: torch.sqrt(8*tpi*x[:,[0]]/3*x[:,[1]] - x[:,[2]]*x[:,[3]]**2/x[:,[4]]**2)
        ranges = [[1,2],[1,2],[0,1],[0,1],[1,2]]
        
    if name == 'Compton scattering' or name == 103:
        symbol = E, m, c, theta = symbols('E m c theta')
        expr = E/(1+E/(m*c**2)*(1-cos(theta)))
        f = lambda x: x[:,[0]]/(1+x[:,[0]]/(x[:,[1]]*x[:,[2]]**2)*(1-torch.cos(x[:,[3]])))
        ranges = [[0,1],[0.5,2],[0.5,2],[0,2*tpi]]
        
    if name == 'Radiated gravitational wave power' or name == 104:
        symbol = G, c, m1, m2, r = symbols('G c m_1 m_2 r')
        expr = -32/5*G**4/c**5*(m1*m2)**2*(m1+m2)/r**5
        f = lambda x: -32/5*x[:,[0]]**4/x[:,[1]]**5*(x[:,[2]]*x[:,[3]])**2*(x[:,[2]]+x[:,[3]])/x[:,[4]]**5
        ranges = [[0,1],[0.5,2],[0,1],[0,1],[0.5,2]]
        
    if name == 'Relativistic aberration' or name == 105:
        symbol = theta2, v, c = symbols('theta_2 v c')
        expr = acos((cos(theta2)-v/c)/(1-v/c*cos(theta2)))
        f = lambda x: torch.arccos((torch.cos(x[:,[0]])-x[:,[1]]/x[:,[2]])/(1-x[:,[1]]/x[:,[2]]*torch.cos(x[:,[0]])))
        ranges = [[0,tpi],[0,1],[1,2]]
        
    if name == 'N-slit diffraction' or name == 106:
        symbol = I0, alpha, delta, N = symbols('I_0 alpha delta N')
        expr = I0 * (sin(alpha/2)/(alpha/2)*sin(N*delta/2)/sin(delta/2))**2
        f = lambda x: x[:,[0]] * (torch.sin(x[:,[1]]/2)/(x[:,[1]]/2)*torch.sin(x[:,[3]]*x[:,[2]]/2)/torch.sin(x[:,[2]]/2))**2
        ranges = [[0,1],[0.1*tpi,0.9*tpi],[0.1*tpi,0.9*tpi],[0.5,1]]
        
    if name == 'Goldstein 3.16' or name == 107:
        symbol = m, E, U, L, r = symbols('m E U L r')
        expr = sqrt(2/m*(E-U-L**2/(2*m*r**2)))
        f = lambda x: torch.sqrt(2/x[:,[0]]*(x[:,[1]]-x[:,[2]]-x[:,[3]]**2/(2*x[:,[0]]*x[:,[4]]**2)))
        ranges = [[1,2],[2,3],[0,1],[0,1],[1,2]]
        
    if name == 'Goldstein 3.55' or name == 108:
        symbol = m, kG, L, E, theta1, theta2 = symbols('m k_G L E theta_1 theta_2')
        expr = m*kG/L**2*(1+sqrt(1+2*E*L**2/(m*kG**2))*cos(theta1-theta2))
        f = lambda x: x[:,[0]]*x[:,[1]]/x[:,[2]]**2*(1+torch.sqrt(1+2*x[:,[3]]*x[:,[2]]**2/(x[:,[0]]*x[:,[1]]**2))*torch.cos(x[:,[4]]-x[:,[5]]))
        ranges = [[0.5,2],[0.5,2],[0.5,2],[0,1],[0,2*tpi],[0,2*tpi]]
        
    if name == 'Goldstein 3.64 (ellipse)' or name == 109:
        symbol = d, alpha, theta1, theta2 = symbols('d alpha theta_1 theta_2')
        expr = d*(1-alpha**2)/(1+alpha*cos(theta2-theta1))
        f = lambda x: x[:,[0]]*(1-x[:,[1]]**2)/(1+x[:,[1]]*torch.cos(x[:,[2]]-x[:,[3]]))
        ranges = [[0,1],[0,0.9],[0,2*tpi],[0,2*tpi]]
        
    if name == 'Goldstein 3.74 (Kepler)' or name == 110:
        symbol = d, G, m1, m2 = symbols('d G m_1 m_2')
        expr = 2*pi*d**(3/2)/sqrt(G*(m1+m2))
        f = lambda x: 2*tpi*x[:,[0]]**(3/2)/torch.sqrt(x[:,[1]]*(x[:,[2]]+x[:,[3]]))
        ranges = [[0,1],[0.5,2],[0.5,2],[0.5,2]]
        
    if name == 'Goldstein 3.99' or name == 111:
        symbol = eps, E, L, m, Z1, Z2, q = symbols('epsilon E L m Z_1 Z_2 q')
        expr = sqrt(1+2*eps**2*E*L**2/(m*(Z1*Z2*q**2)**2))
        f = lambda x: torch.sqrt(1+2*x[:,[0]]**2*x[:,[1]]*x[:,[2]]**2/(x[:,[3]]*(x[:,[4]]*x[:,[5]]*x[:,[6]]**2)**2))
        ranges = [[0,1],[0,1],[0,1],[0.5,2],[0.5,2],[0.5,2],[0.5,2]]
        
    if name == 'Goldstein 8.56' or name == 112:
        symbol = p, q, A, c, m, Ve = symbols('p q A c m V_e')
        expr = sqrt((p-q*A)**2*c**2+m**2*c**4) + q*Ve
        f = lambda x: torch.sqrt((x[:,[0]]-x[:,[1]]*x[:,[2]])**2*x[:,[3]]**2+x[:,[4]]**2*x[:,[3]]**4) + x[:,[1]]*x[:,[5]]
        ranges = [0,1]
        
    if name == 'Goldstein 12.80' or name == 113:
        symbol = m, p, omega, x, alpha, y = symbols('m p omega x alpha y')
        expr = 1/(2*m)*(p**2+m**2*omega**2*x**2*(1+alpha*y/x))
        f = lambda x: 1/(2*x[:,[0]]) * (x[:,[1]]**2+x[:,[0]]**2*x[:,[2]]**2*x[:,[3]]**2*(1+x[:,[4]]*x[:,[3]]/x[:,[5]]))
        ranges = [[0.5,2],[0,1],[0,1],[0,1],[0,1],[0.5,2]]
        
    if name == 'Jackson 2.11' or name == 114:
        symbol = q, eps, y, Ve, d = symbols('q epsilon y V_e d')
        expr = q/(4*pi*eps*y**2)*(4*pi*eps*Ve*d-q*d*y**3/(y**2-d**2)**2)
        f = lambda x: x[:,[0]]/(4*tpi*x[:,[1]]*x[:,x[:,[2]]]**2)*(4*tpi*x[:,[1]]*x[:,[3]]*x[:,[4]]-x[:,[0]]*x[:,[4]]*x[:,[2]]**3/(x[:,[2]]**2-x[:,[4]]**2)**2)
        ranges = [[0,1],[0.5,2],[1,2],[0,1],[0,1]]
        
    if name == 'Jackson 3.45' or name == 115:
        symbol = q, r, d, alpha = symbols('q r d alpha')
        expr = q/sqrt(r**2+d**2-2*d*r*cos(alpha))
        f = lambda x: x[:,[0]]/torch.sqrt(x[:,[1]]**2+x[:,[2]]**2-2*x[:,[1]]*x[:,[2]]*torch.cos(x[:,[3]]))
        ranges = [[0,1],[0,1],[0,1],[0,2*tpi]]
        
    if name == 'Jackson 4.60' or name == 116:
        symbol = Ef, theta, alpha, d, r = symbols('E_f theta alpha d r')
        expr = Ef * cos(theta) * ((alpha-1)/(alpha+2) * d**3/r**2 - r)
        f = lambda x: x[:,[0]] * torch.cos(x[:,[1]]) * ((x[:,[2]]-1)/(x[:,[2]]+2) * x[:,[3]]**3/x[:,[4]]**2 - x[:,[4]])
        ranges = [[0,1],[0,2*tpi],[0,2],[0,1],[0.5,2]]
        
    if name == 'Jackson 11.38 (Doppler)' or name == 117:
        symbol = omega, v, c, theta = symbols('omega v c theta')
        expr = sqrt(1-v**2/c**2)/(1+v/c*cos(theta))*omega
        f = lambda x: torch.sqrt(1-x[:,[1]]**2/x[:,[2]]**2)/(1+x[:,[1]]/x[:,[2]]*torch.cos(x[:,[3]]))*x[:,[0]]
        ranges = [[0,1],[0,1],[1,2],[0,2*tpi]]
        
    if name == 'Weinberg 15.2.1' or name == 118:
        symbol = G, c, kf, af, H = symbols('G c k_f a_f H')
        expr = 3/(8*pi*G)*(c**2*kf/af**2+H**2)
        f = lambda x: 3/(8*tpi*x[:,[0]])*(x[:,[1]]**2*x[:,[2]]/x[:,[3]]**2+x[:,[4]]**2)
        ranges = [[0.5,2],[0,1],[0,1],[0.5,2],[0,1]]
        
    if name == 'Weinberg 15.2.2' or name == 119:
        symbol = G, c, kf, af, H, alpha = symbols('G c k_f a_f H alpha')
        expr = -1/(8*pi*G)*(c**4*kf/af**2+c**2*H**2*(1-2*alpha))
        f = lambda x: -1/(8*tpi*x[:,[0]])*(x[:,[1]]**4*x[:,[2]]/x[:,[3]]**2 + x[:,[1]]**2*x[:,[4]]**2*(1-2*x[:,[5]]))
        ranges = [[0.5,2],[0,1],[0,1],[0.5,2],[0,1],[0,1]]
        
    if name == 'Schwarz 13.132 (Klein-Nishina)' or name == 120:
        symbol = alpha, hbar, m, c, omega0, omega, theta = symbols('alpha hbar m c omega_0 omega theta')
        expr = pi*alpha**2*hbar**2/m**2/c**2*(omega0/omega)**2*(omega0/omega+omega/omega0-sin(theta)**2)
        f = lambda x: tpi*x[:,[0]]**2*x[:,[1]]**2/x[:,[2]]**2/x[:,[3]]**2*(x[:,[4]]/x[:,[5]])**2*(x[:,[4]]/x[:,[5]]+x[:,[5]]/x[:,[4]]-torch.sin(x[:,[6]])**2)
        ranges = [[0,1],[0,1],[0.5,2],[0.5,2],[0.5,2],[0.5,2],[0,2*tpi]]
        
    return symbol, expr, f, ranges

四、总结与思考

        KAN神经网络通过融合数学定理与深度学习,为科学计算和可解释AI提供了新思路。尽管在高维应用中仍需突破,但其在低维复杂函数建模上的潜力值得关注。未来可能通过改进计算效率、扩展理论边界,成为MLP的重要补充。

        1. KAN网络架构

  • 关键设计可学习的激活函数:每个网络连接的“权重”被替换为单变量函数(如样条、多项式),而非固定激活函数(如ReLU)。分层结构:输入层和隐藏层之间、隐藏层与输出层之间均通过单变量函数连接,形成多层叠加。参数效率:由于理论保证,KAN可能用更少的参数达到与MLP相当或更好的逼近效果。

  • 示例结构输入层 → 隐藏层:每个输入节点通过单变量函数\phi_{q,i} \left( x_{i} \right) 连接到隐藏节点。隐藏层 → 输出层:隐藏节点通过另一组单变量函数\psi_{q}组合得到输出。

        2. 优势与特点

  • 高逼近效率:基于数学定理,理论上能以更少参数逼近复杂函数;在低维科学计算任务(如微分方程求解)中表现优异。

  • 可解释性:单变量函数可可视化,便于分析输入变量与输出的关系;网络结构直接对应函数分解过程,逻辑清晰。

  • 灵活的函数学习:激活函数可自适应调整(如学习平滑或非平滑函数);支持符号公式提取(例如从数据中恢复物理定律)。

        3. 挑战与局限

  • 计算复杂度:单变量函数的学习(如样条参数化)可能增加训练时间和内存消耗。需要优化高阶连续函数,对硬件和算法提出更高要求。

  • 泛化能力:在高维数据(如图像、文本)中的表现尚未充分验证,可能逊色于传统MLP。

  • 训练难度:需设计新的优化策略,避免单变量函数的过拟合或欠拟合。

        4. 应用场景

  • 科学计算:求解微分方程、物理建模、化学模拟等需要高精度函数逼近的任务。

  • 可解释性需求领域:医疗诊断、金融风控等需明确输入输出关系的场景。

  • 符号回归:从数据中自动发现数学公式(如物理定律)。

        5. 与传统MLP的对比

        6. 研究进展

  • 近期论文:2024年,MIT等团队提出KAN架构(如论文《KAN: Kolmogorov-Arnold Networks》),在低维任务中验证了其高效性和可解释性。

  • 开源实现:已有PyTorch等框架的初步实现。


【作者声明】

        本文分享的论文内容及观点均来源于《KAN: Kolmogorov-Arnold Networks》原文,旨在介绍和探讨该研究的创新成果和应用价值。作者尊重并遵循学术规范,确保内容的准确性和客观性。如有任何疑问或需要进一步的信息,请参考论文原文或联系相关作者。


 【关注我们】

        如果您对神经网络、群智能算法及人工智能技术感兴趣,请关注【灵犀拾荒者】,获取更多前沿技术文章、实战案例及技术分享!

猜你喜欢

转载自blog.csdn.net/2303_77200324/article/details/146352467
今日推荐