-
建模中经常使用线性最小二乘法,实际上就是求超定线性方程组(未知数少,方程个数多)的最小二乘解,前面已经使用pinv()求超定线性方程组的最小二乘解.下面再举两个求最小二乘解的例子,并使用**numpy.linalg模块的lstsq()**函数求解.
-
先要明确这个函数的原义是用来求超定线性方程组的:
例如下面的方程组:
系数矩阵的第一列相当于给定了x的观测值 X=[0,1,2,3].transpose
右边的结果矩阵相当于给定了y的观测值 Y=[-1,0.2,0.9,2.1].transpose
然后使用两个观测值来拟合经验函数 y=mx+c
系数矩阵的第二列存在的意义有点类似于机器学习中的偏置θ0,用于和C相乘,注意这是必要的,在只给定观测值的情况下,我们也常常需要np.ones_like(X的长度来构建有这一“无效列”的矩阵.
- **lstsq(a,b,rcond=“warn”)**函数的参数详解(下面的矩阵都是array_like(类数组对象)):
1. a是一个M行N列的系数矩阵,前面说过需要构造np.ones_like(M)
2. b是一个(M,)或者(M,K),如果b是一个M行K列的二维矩阵,函数会逐个计算每一列的最小二乘法
3. rcond这个参数是可选的,是用于奇异矩阵的处理的,感兴趣的可以自行查看源码,官方推荐我们一般用 rcond=None
返回值:以下提到的所有矩阵都是ndarray, NumPy 最重要的一个特点是其 N 维数组对象 ndarray,它是一系列同类型数据的集合,以 0 下标为开始进行集合中元素的索引):
-
x : {(N,), (N, K)} ndarray (我们所要的结果,如果前面的b是二维的,那么这里也会有k列的a和b结果)
-
residuals : {(1,), (K,), (0,)} ndarray
-
rank: int
-
a 的奇异值
返回值重点关注返回集合中的x就行,所以我们一般的用法是lstsq()[0]
官方的使用栗子:
Examples -------- Fit a line, ``y = mx + c``, through some noisy data-points: >>> x = np.array([0, 1, 2, 3]) >>> y = np.array([-1, 0.2, 0.9, 2.1]) By examining the coefficients, we see that the line should have a gradient of roughly 1 and cut the y-axis at, more or less, -1. We can rewrite the line equation as ``y = Ap``, where ``A = [[x 1]]`` and ``p = [[m], [c]]``. Now use `lstsq` to solve for `p`: >>> A = np.vstack([x, np.ones(len(x))]).T >>> A array([[ 0., 1.], [ 1., 1.], [ 2., 1.], [ 3., 1.]]) >>> m, c = np.linalg.lstsq(A, y, rcond=None)[0] >>> print(m, c) 1.0 -0.95 Plot the data along with the fitted line: >>> import matplotlib.pyplot as plt >>> plt.plot(x, y, 'o', label='Original data', markersize=10) >>> plt.plot(x, m*x + c, 'r', label='Fitted line') >>> plt.legend() >>> plt.show()
-
下面举个栗子:
给定一组实验数据0 27. 1 26.8 2 26.5 3 26.3 4 26.1 5 25.7 6 25.3 24.8 我们来进行一元线性拟合 y=at+b
import numpy as np import numpy.linalg as LA import matplotlib.pyplot as plt t=np.arange(8) y=np.array([27.0,26.8,26.5,26.3,26.1,25.7,25.3,24.8]) A=np.c_[t, np.ones_like(t)] print(np.ones_like(t)) ab=LA.lstsq(A,y,rcond=None)[0] print(ab); plt.rc('font',size=16) plt.plot(t,y,'o',label='Original data',markersize=5) plt.plot(t,A.dot(ab),'r',label="Fitted line") plt.legend(); plt.show();
-
**感兴趣的可以自己运行试试看**