1、图示
import matplotlib.pyplot as mp, numpy as np
# 原函数
origin = lambda x: 2 * x - x ** 2
x = np.linspace(0, 2, 9999)
mp.plot(x, origin(x), c='black') # 可视化
# 原函数的导数
derivative = lambda x: 2 - 2 * x
# 梯度上升求
extreme_point = 0 # 初始值
alpha = 0.1 # 步长,即学习速率
presision = 0.001 # 允许误差范围
while True:
mp.scatter(extreme_point, origin(extreme_point)) # 可视化
error = alpha * derivative(extreme_point) # 爬坡步伐
extreme_point += error # 爬坡
if abs(error) < presision:
break # 误差较小时退出迭代
mp.show()
2、原理
- 求函数的极值点
-
-
求导数:
计算机使用迭代的方法,类似爬坡,一步一步逼近极值:
python代码表示(alpha为步长、derivative为导数):
derivative = lambda x: 2 - 2 * x # (2-x**2)的导数
extreme_point = 0 # 初始值
alpha = 0.1 # 步长,即学习速率
presision = 0.0001 # 允许误差范围
while True:
error = alpha * derivative(extreme_point) # 爬坡步伐
extreme_point += error # 爬坡
print(extreme_point)
if abs(error) < presision:
break # 误差较小时退出迭代
3、改进学习速率,先快后慢
- 示例:对 sin(x) 函数,随机选一点,求该点附近的极值点
import numpy as np
derivative = lambda x: np.cos(x) # sin(x)的导数
extreme_point = np.random.randint(0, 100) # 初始化极值点
print(extreme_point)
for i in range(1, 99999):
alpha = 1 / i # 步长(先快后慢)
extreme_point += alpha * derivative(extreme_point)
print(extreme_point)
print(extreme_point % (np.pi * 2))
- 打印结果(之一):
-
5
7.852861107256883
1.569675800077297 - 意思表示为:
- 附近存在极值点 (7.852861107256883)
4、附录
知识进阶
翻译
- alpha
- 希腊字母的第一个字母:
- presision
- 精度
- extreme point
- 极值点
- derivative
- 导数
- partial derivatives
- 偏导
- gradient ascent
- 梯度上升