190121 Pitfalls of Matrix Division (left division vs inv in MATLAB and Python)

Copyright notice: This is an original article by the blogger. Please credit the source when reposting. https://blog.csdn.net/qq_33039859/article/details/86585458

Conclusion

A\b is faster than inv(A)*b and gives a smaller error!

References:

  1. NumPy for MATLAB users
  2. What is the difference between inv() and \ (the backslash)?
  3. inv() vs \
  4. Left Matrix Division and NumPy Solve

If A is an n x n matrix and B is n x m, solving A\B is tantamount to solving m*n equations in m*n unknowns. Finding the inverse of A is equivalent to finding A\eye(n), and hence is similar to solving n*n equations in n*n unknowns. If the number of columns, m, in B is less than n, it therefore takes less time to solve m*n equations than to compute inv(A)*B, which would involve n*n equations combined with a matrix multiplication.
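To make the cost argument concrete, here is a small NumPy/SciPy sketch (the sizes n = 200, m = 3, the seed and the random test matrix are arbitrary choices for illustration). It checks that inv(A) is the same thing as solving A X = I, and shows the usual alternative when one A is paired with several right-hand sides: factor A once with scipy.linalg.lu_factor and reuse the factors with lu_solve instead of forming inv(A).

```python
import numpy as np
from scipy.linalg import inv, solve, lu_factor, lu_solve

rng = np.random.default_rng(0)
n, m = 200, 3  # arbitrary sizes: B has fewer columns than A
A = rng.standard_normal((n, n))
B = rng.standard_normal((n, m))

# inv(A) is the solution X of A X = I, i.e. MATLAB's A\eye(n)
A_inv = inv(A)
X_eye = solve(A, np.eye(n))
print(np.allclose(A_inv, X_eye))

# Factor A once (LU with partial pivoting), then do cheap triangular
# solves for each right-hand side instead of multiplying by inv(A)
lu_piv = lu_factor(A)
X1 = lu_solve(lu_piv, B)
X2 = solve(A, B)
print(np.allclose(X1, X2))
```

The factor-once pattern is what makes "I need to reuse A many times" a poor reason to compute inv(A).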
If A is n x p and not square with p < n, solving A\B requires solving m*n equations with only m*p unknowns and is overdetermined, so A\B will simply find the best least-squares approximation to a solution, which makes it different from ‘inv’, which will produce an error.
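On the Python side, the counterpart of A\B for a tall matrix is a least-squares routine such as numpy.linalg.lstsq, while numpy.linalg.inv refuses a non-square input. A minimal sketch with a random 6 x 3 system (sizes and seed are arbitrary): the least-squares solution agrees with the normal-equations solution for this small full-rank example, and inv raises LinAlgError.

```python
import numpy as np

rng = np.random.default_rng(1)
A = rng.standard_normal((6, 3))   # n = 6 equations, p = 3 unknowns (p < n)
b = rng.standard_normal((6, 1))

# Least-squares solution, the analogue of MATLAB's A\b for a tall A
x_ls, resid, rank, sv = np.linalg.lstsq(A, b, rcond=None)

# Cross-check only: for a full-rank, well-conditioned A this matches the
# normal equations (A^T A) x = A^T b (which square the condition number,
# so they are not a good general-purpose method)
x_ne = np.linalg.solve(A.T @ A, A.T @ b)
print(np.allclose(x_ls, x_ne))

# inv() only accepts square matrices
try:
    np.linalg.inv(A)
except np.linalg.LinAlgError as err:
    print("inv failed:", err)
```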
On the other hand, if p > n, the number of unknowns exceeds the number of equations and the system is underdetermined. Hence A\B will pick one particular solution: MATLAB returns a basic solution, in which at most rank(A) of the unknowns are nonzero and the rest are set to zero. In this it also differs from the ‘inv’ function, which will again give an error.
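For a wide matrix the two environments do not return the same solution: numpy.linalg.lstsq gives the minimum-norm solution (the one the pseudoinverse gives), not MATLAB's basic solution. A sketch with a random 3 x 6 system (sizes and seed are arbitrary) showing that the equations hold exactly and that adding a null-space vector gives another, longer solution:

```python
import numpy as np

rng = np.random.default_rng(2)
A = rng.standard_normal((3, 6))   # n = 3 equations, p = 6 unknowns (p > n)
b = rng.standard_normal(3)

# lstsq returns the solution of A x = b with the smallest 2-norm
x_mn, _, rank, _ = np.linalg.lstsq(A, b, rcond=None)
print(np.allclose(A @ x_mn, b))                   # equations satisfied
print(np.allclose(x_mn, np.linalg.pinv(A) @ b))   # same as the pseudoinverse solution

# The last right singular vector lies in the null space of A, so adding it
# leaves A x unchanged but makes the solution longer
null_vec = np.linalg.svd(A)[2][-1]
x_other = x_mn + null_vec
print(np.allclose(A @ x_other, b), np.linalg.norm(x_other) > np.linalg.norm(x_mn))
```

If MATLAB code relies on the sparsity of the basic solution, porting it to lstsq changes the answer even though both satisfy the equations.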

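The backslash computation is faster, and its residual is several orders of magnitude smaller. The fact that err_inv and err_bs are both on the order of 1e-6 directly reflects the condition number of the matrix.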
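The test matrix in the experiment is built as A = Q·diag(d)·Qᵀ with orthogonal Q and d running from 1 down to 1e-10, so its singular values are exactly d and cond(A) = 1e10. The usual rule of thumb that the relative error of a computed solution can be as large as about cond(A)·eps then gives roughly 2e-6, which matches the order quoted above. A quick check (the seed is an arbitrary choice):

```python
import numpy as np
from scipy.linalg import orth

rng = np.random.default_rng(0)
n = 50
Q = orth(rng.standard_normal((n, n)))   # random orthogonal matrix
d = np.logspace(0, -10, n)              # singular values from 1 down to 1e-10
A = Q @ np.diag(d) @ Q.T

kappa = np.linalg.cond(A)               # 2-norm condition number, d.max() / d.min()
eps = np.finfo(float).eps               # about 2.2e-16 for float64
print(f"cond(A)       = {kappa:.2e}")
print(f"cond(A) * eps = {kappa * eps:.1e}")
```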

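The behavior in this example is typical. Using A\b instead of inv(A)*b is two to three times faster, and it produces residuals on the order of machine accuracy relative to the magnitude of the data.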
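A residual measured relative to the magnitude of the data is usually the scaled residual ||b - A·x̂|| / (||A||·||x̂||). The sketch below computes it for scipy.linalg.solve and for inv(A) @ b on the same ill-conditioned matrix; the helper name scaled_residual is hypothetical, and the seed is arbitrary.

```python
import numpy as np
from scipy.linalg import orth, inv, solve, norm

def scaled_residual(A, x_hat, b):
    """Hypothetical helper: ||b - A x_hat|| / (||A||_2 * ||x_hat||)."""
    return norm(b - A @ x_hat) / (norm(A, 2) * norm(x_hat))

rng = np.random.default_rng(0)
n = 50
Q = orth(rng.standard_normal((n, n)))
A = Q @ np.diag(np.logspace(0, -10, n)) @ Q.T   # cond(A) = 1e10
x = rng.standard_normal((n, 1))
b = A @ x

r_solve = scaled_residual(A, solve(A, b), b)
r_inv = scaled_residual(A, inv(A) @ b, b)
print(f"solve: {r_solve:.1e}   inv: {r_inv:.1e}   eps: {np.finfo(float).eps:.1e}")
```

A scaled residual near eps means the computed x̂ solves a nearby system exactly, which is the sense in which solve is backward stable even when the forward error is limited by cond(A).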

Experiment

import numpy as np
from scipy.linalg import orth,inv,norm,solve
import time

#%% using inv
n = 50
Q = orth(np.random.randn(n,n))
d = np.logspace(0,-10,n)
A = Q.dot(np.diag(d)).dot(Q.T)
x = np.random.randn(n,1)
b = np.dot(A,x)

ts = time.time()
y = inv(A).dot(b)
te = time.time()
print('Time cost {:f}'.format(te-ts))

err_inv = norm((y-x),2)
res_inv = norm((A.dot(y)-b),2)

print('err_inv: ',err_inv)
print('res_inv: ',res_inv)




#%% using matrix.I
print('\n')
ts = time.time()
w = np.asmatrix(A).I.dot(b)  # np.mat was removed in NumPy 2.0; np.asmatrix still works
te = time.time()

print('Time cost {:f}'.format(te-ts))

err_inv = norm((w-x),2)
res_inv = norm((A.dot(w)-b),2)


print('err_inv: ',err_inv)
print('res_inv: ',res_inv)


#%% using solve (SciPy's counterpart of MATLAB's backslash for a square A)
print('\n')
ts = time.time()
z = solve(A,b)
te = time.time()
print('Time cost {:f}'.format(te-ts))

err_inv = norm((z-x),2)
res_inv = norm((A.dot(z)-b),2)


print('err_inv: ',err_inv)
print('res_inv: ',res_inv)

#%% using np.linalg.lstsq(A,b,rcond=None)
print('\n')
ts = time.time()
u, resid, rank, s = np.linalg.lstsq(A,b,rcond=None)
te = time.time()
print('Time cost {:f}'.format(te-ts))

err_inv = norm((u-x),2)
res_inv = norm((A.dot(u)-b),2)


print('err_inv: ',err_inv)
print('res_inv: ',res_inv)

Results

scipy.linalg.inv:
Time cost 0.001216
err_inv:  4.4275202788253137e-07
res_inv:  5.3371879782231606e-08

matrix.I:
Time cost 0.000187
err_inv:  4.4275202788253137e-07
res_inv:  5.3371879782231606e-08

scipy.linalg.solve:
Time cost 0.000783
err_inv:  3.076198269358094e-07
res_inv:  3.265153280995946e-16

np.linalg.lstsq:
Time cost 0.001050
err_inv:  1.2900433256244122e-06
res_inv:  2.081483125776792e-15
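A single time.time() measurement of a 50 x 50 solve is dominated by call overhead and timer noise, which may explain why matrix.I happens to come out fastest in the run above. A sketch of a steadier comparison with the standard-library timeit module (the seed and repetition counts are arbitrary choices; absolute numbers will vary by machine and BLAS):

```python
import timeit
import numpy as np
from scipy.linalg import orth, inv, solve

rng = np.random.default_rng(0)
n = 50
Q = orth(rng.standard_normal((n, n)))
A = Q @ np.diag(np.logspace(0, -10, n)) @ Q.T
b = A @ rng.standard_normal((n, 1))

candidates = {
    "inv(A) @ b": lambda: inv(A) @ b,
    "np.asmatrix(A).I @ b": lambda: np.asmatrix(A).I @ b,
    "solve(A, b)": lambda: solve(A, b),
    "lstsq(A, b)": lambda: np.linalg.lstsq(A, b, rcond=None),
}

best = {}
for name, fn in candidates.items():
    # best of 5 repeats, each timing 200 calls; report the time per call
    per_call = min(timeit.repeat(fn, number=200, repeat=5)) / 200
    best[name] = per_call
    print(f"{name:<22} {per_call * 1e6:8.1f} us per call")
```

Taking the minimum over repeats filters out interference from other processes, which only ever makes a run slower.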

The following is a side note, not directly related to the main content of this post:

  • Expressions in MATLAB
b = (C*eye(size(c',1)) + c'*c) \ (c');  % Good
b = inv(C*eye(size(c',1)) + c'*c) * (c');  % Bad
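  • Expressions in Python
import numpy as np
from scipy.linalg import inv
# T3 (a 2-D array) and C (a scalar) are assumed to be defined already
# Expression 1 <Bad>
beta = inv(np.dot(T3.T,T3) + np.eye(T3.T.shape[0]) * C).dot(T3.T)
# Expression 2 <Bad>
beta = np.asmatrix(np.dot(T3.T,T3) + np.eye(T3.T.shape[0]) * C).I.dot(T3.T)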
# Expression 3 <Good>
AA = np.dot(T3.T, T3) + np.eye(T3.T.shape[0]) * C
bb = T3.T
beta = np.linalg.solve(AA, bb)

# Expression 4 <Good>
AA = np.dot(T3.T, T3) + np.eye(T3.T.shape[0]) * C
bb = T3.T
beta, resid, rank, s = np.linalg.lstsq(AA, bb, rcond=None)
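When C > 0, the system matrix T3ᵀT3 + C·I in these expressions is symmetric positive definite, so a Cholesky-based solver is a further option. A minimal sketch using scipy.linalg.cho_factor / cho_solve, compared with scipy.linalg.solve(..., assume_a="pos"); the shape of T3 and the value of C below are hypothetical placeholders, not values from the post:

```python
import numpy as np
from scipy.linalg import cho_factor, cho_solve, solve

rng = np.random.default_rng(0)
T3 = rng.standard_normal((500, 40))   # hypothetical: 500 samples, 40 features
C = 0.1                               # hypothetical regularization constant, C > 0

AA = T3.T @ T3 + C * np.eye(T3.shape[1])   # symmetric positive definite when C > 0
bb = T3.T

# Cholesky factorization AA = L L^T, then two triangular solves per column of bb
beta_chol = cho_solve(cho_factor(AA), bb)

# The general solver, told that AA is positive definite (it then uses Cholesky too)
beta_pos = solve(AA, bb, assume_a="pos")

print(beta_chol.shape, np.allclose(beta_chol, beta_pos))
```

Cholesky needs roughly half the work of an LU factorization and no pivoting, which is why it is the usual choice for regularized normal equations like these.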
