线性回归中使用梯度下降法
目的:求解损失函数的唯一最优解
1. 梯度下降法
- 不是一个机器学习算法,只是一种基于搜索的最优化方法。
- 作用:最小化损失函数。
- 梯度上升法:最大化效用函数。
- 注意:
- n称为学习率(learning rate)
- n的取值影响获得最优解的速度
- n取得不合适,可能得不到最优解
- n是梯度下降算法的一个超参数
- 问题:不是所有函数都有唯一的极值点,有些函数可能有多个,运行一次可能得到的是局部最优解,不是全局最优解。
- 解决:
- 运行多次,随机化初始点。
- 梯度下降法的初始点也是一个超参数。
2. 具体代码
import numpy as np
import matplotlib.pyplot as plt
plot_x=np.linspace(-1,6,141) #140个点,把1~6平均平均分成141端
plot_x
array([-1. , -0.95, -0.9 , -0.85, -0.8 , -0.75, -0.7 , -0.65, -0.6 ,
-0.55, -0.5 , -0.45, -0.4 , -0.35, -0.3 , -0.25, -0.2 , -0.15,
-0.1 , -0.05, 0. , 0.05, 0.1 , 0.15, 0.2 , 0.25, 0.3 ,
0.35, 0.4 , 0.45, 0.5 , 0.55, 0.6 , 0.65, 0.7 , 0.75,
0.8 , 0.85, 0.9 , 0.95, 1. , 1.05, 1.1 , 1.15, 1.2 ,
1.25, 1.3 , 1.35, 1.4 , 1.45, 1.5 , 1.55, 1.6 , 1.65,
1.7 , 1.75, 1.8 , 1.85, 1.9 , 1.95, 2. , 2.05, 2.1 ,
2.15, 2.2 , 2.25, 2.3 , 2.35, 2.4 , 2.45, 2.5 , 2.55,
2.6 , 2.65, 2.7 , 2.75, 2.8 , 2.85, 2.9 , 2.95, 3. ,
3.05, 3.1 , 3.15, 3.2 , 3.25, 3.3 , 3.35, 3.4 , 3.45,
3.5 , 3.55, 3.6 , 3.65, 3.7 , 3.75, 3.8 , 3.85, 3.9 ,
3.95, 4. , 4.05, 4.1 , 4.15, 4.2 , 4.25, 4.3 , 4.35,
4.4 , 4.45, 4.5 , 4.55, 4.6 , 4.65, 4.7 , 4.75, 4.8 ,
4.85, 4.9 , 4.95, 5. , 5.05, 5.1 , 5.15, 5.2 , 5.25,
5.3 , 5.35, 5.4 , 5.45, 5.5 , 5.55, 5.6 , 5.65, 5.7 ,
5.75, 5.8 , 5.85, 5.9 , 5.95, 6. ])
plot_y=(plot_x-2.5)**2-1 #损失函数
plt.plot(plot_x,plot_y)
[<matplotlib.lines.Line2D at 0x245cd5068c8>]
def dJ(theta):#对theta求导
return 2*(theta-2.5)
def J(theta):#求损失函数
return (theta-2.5)**2-1
eta=0.1#学习率
epsilon=1e-8 #0.00000001
theta=0.0#初始化为0
while True:
gradient=dJ(theta) #梯度
last_theta=theta #保存上一次求的theta
theta= theta - eta * gradient #梯度下降
if(abs(J(theta)-J(last_theta))<epsilon):
#如果两次theta取值的损失函数相差很小(<epsilon),则跳出循环
break
print(theta)
print(J(theta))
2.499891109642585
-0.99999998814289
theta=0.0#初始化为0
theta_history= [theta]
while True:
gradient=dJ(theta) #梯度
last_theta=theta #保存上一次求的theta
theta= theta - eta * gradient #梯度下降
theta_history.append(theta)
if(abs(J(theta)-J(last_theta))<epsilon):
#如果两次theta取值的损失函数相差很小(<epsilon),则跳出循环
break
plt.plot(plot_x,J(plot_x))
plt.plot(np.array(theta_history),J(np.array(theta_history)),color='r',marker='+')
[<matplotlib.lines.Line2D at 0x245cdb941c8>]
len(theta_history)
46
3. 将上面代码进行封装
def gradient_descent(initial_theta,eta,n_iters=1e4,epsilon=1e-8):
theta=initial_theta
theta_history.append(initial_theta)
i_iter=0 #当前运行次数
while i_iter<n_iters:#n_iters为运行总次数上限
gradient=dJ(theta) #梯度
last_theta=theta #保存上一次求的theta
theta= theta - eta * gradient #梯度下降
theta_history.append(theta)#记录每次的theta
if(abs(J(theta)-J(last_theta))<epsilon):
#如果两次theta取值的损失函数相差很小(<epsilon),则跳出循环
break
i_iter+=1
def plot_theta_history():
plt.plot(plot_x,J(plot_x))#原损失函数曲线
plt.plot(np.array(theta_history),J(np.array(theta_history)),color='r',marker='+')#theta取值的曲线变化
1. eta取0.01
eta=0.01
theta_history=[]
gradient_descent(0.,eta)
plot_theta_history()
len(theta_history)
424
2. eta取0.001(过小)
eta=0.001
theta_history=[]
gradient_descent(0.,eta)
plot_theta_history()
len(theta_history)
3682
3. eta取0.8(过大,但是可以)
eta=0.8
theta_history=[]
gradient_descent(0.,eta)
plot_theta_history()
4. eta取1.1(过大,取不到极小值)
eta=1.1
theta_history=[]
gradient_descent(0.,eta)
plot_theta_history() #eta取值太大,导致数组结果过大
---------------------------------------------------------------------------
OverflowError Traceback (most recent call last)
<ipython-input-16-6409def1277e> in <module>
1 eta=1.1
2 theta_history=[]
----> 3 gradient_descent(0.,eta)
4 plot_theta_history()
<ipython-input-10-512a889f8a22> in gradient_descent(initial_theta, eta, epsilon)
9 theta= theta - eta * gradient #梯度下降
10 theta_history.append(theta)
---> 11 if(abs(J(theta)-J(last_theta))<epsilon):
12 #如果两次theta取值的损失函数相差很小(<epsilon),则跳出循环
13 break
<ipython-input-5-6a0c75abd131> in J(theta)
1 def J(theta):#求损失函数
----> 2 return (theta-2.5)**2-1
OverflowError: (34, 'Result too large')
def J(theta): #返回theta时进行一次异常检测-->过大不会报错
try:
return (theta-2.5)**2-1.
except:
return float('inf')# 返回正无穷
eta=1.1
theta_history=[]
gradient_descent(0.,eta)
len(theta_history)
10001
theta_history[-1]#数组太大——>在封装中增加参数n_iters设置循环上限次数
nan
eta=1.1
theta_history=[]
gradient_descent(0.,eta,n_iters=10)
plot_theta_history()