运用scipy库进行线性拟合

一、概述

此代码运用 Python 的numpymatplotlibscipy库,完成对带噪声线性数据的线性模型拟合与可视化。具体步骤为定义线性模型函数,生成带有噪声的线性数据,利用curve_fit函数进行数据拟合,最后借助matplotlib库将原始数据与拟合直线进行可视化展示。

二、依赖库说明

  • numpy:用于数值计算,像生成数组、进行随机数生成等操作。
  • matplotlib.pyplot:用于数据可视化,可绘制散点图、折线图等。
  • scipy.optimize.curve_fit:用于曲线拟合,能找出与给定数据最佳匹配的模型参数。

三、代码详细解释

1. 导入必要的库

python

import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit

  • 导入numpy库并简称为np,便于后续使用numpy的函数和类。
  • 导入matplotlib.pyplot库并简称为plt,用于绘图。
  • scipy.optimize模块导入curve_fit函数,用于曲线拟合。

2. 定义模型函数

python

def linear_model(x, a, b):
    return a * x + b

  • 定义一个名为linear_model的函数,它代表一个线性模型。
  • 该函数接受三个参数:自变量x、斜率a和截距b
  • 函数返回值是线性方程a * x + b的计算结果。

3. 生成数据

python

x = np.linspace(0, 10, 100)
y = 2 * x + 1 + np.random.normal(0, 1, x.shape)

  • np.linspace(0, 10, 100):生成一个包含 100 个元素的一维数组x,元素范围从 0 到 10 且均匀分布。
  • 2 * x + 1 + np.random.normal(0, 1, x.shape):生成因变量y,它基于真实的线性关系2 * x + 1,并添加了均值为 0、标准差为 1 的高斯噪声。np.random.normal(0, 1, x.shape)用于生成与x相同形状的随机噪声数组。

4. 拟合数据

python

params, _ = curve_fit(linear_model, x, y)
a, b = params

  • curve_fit(linear_model, x, y):调用curve_fit函数对linear_model进行拟合,传入自变量x和因变量y
  • params:存储拟合得到的参数,即线性模型的斜率a和截距b
  • _:表示忽略curve_fit函数返回的协方差矩阵。
  • a, b = params:将拟合得到的参数解包,分别赋值给变量ab

5. 可视化结果

python

plt.scatter(x, y, label='Data')
plt.plot(x, linear_model(x, a, b), color='red', label='Fitted line')
plt.legend()
plt.show()

  • plt.scatter(x, y, label='Data'):绘制xy的散点图,标签为Data,用于展示原始数据点。
  • plt.plot(x, linear_model(x, a, b), color='red', label='Fitted line'):绘制拟合直线,使用拟合得到的参数ab计算直线上的点,颜色为红色,标签为Fitted line
  • plt.legend():显示图例,方便区分原始数据和拟合直线。
  • plt.show():显示绘制好的图形。

四、注意事项

  • 此代码假设数据符合线性关系,若数据呈现非线性特征,需重新定义模型函数。
  • 噪声的标准差可根据实际情况调整np.random.normal函数的参数。
  • 对于curve_fit函数的使用,若数据存在异常值,可能会影响拟合效果,可考虑对数据进行预处理或使用更鲁棒的拟合方法。

完整代码

import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit

# 定义模型函数
def linear_model(x, a, b):
    return a * x + b

# 生成数据
x = np.linspace(0, 10, 100)
y = 2 * x + 1 + np.random.normal(0, 1, x.shape)

# 拟合数据
params, _ = curve_fit(linear_model, x, y)
a, b = params

# 可视化结果
plt.scatter(x, y, label='Data')
plt.plot(x, linear_model(x, a, b), color='red', label='Fitted line')
plt.legend()
plt.show()