Python【梯度上升】简洁代码

1、图示

import matplotlib.pyplot as mp, numpy as np
# 原函数
origin = lambda x: 2 * x - x ** 2
x = np.linspace(0, 2, 9999)
mp.plot(x, origin(x), c='black')  # 可视化
# 原函数的导数
derivative = lambda x: 2 - 2 * x
# 梯度上升求
extreme_point = 0  # 初始值
alpha = 0.1  # 步长,即学习速率
presision = 0.001  # 允许误差范围
while True:
    mp.scatter(extreme_point, origin(extreme_point))  # 可视化
    error = alpha * derivative(extreme_point)  # 爬坡步伐
    extreme_point += error  # 爬坡
    if abs(error) < presision:
        break  # 误差较小时退出迭代
mp.show()

这里写图片描述

2、原理

求函数的极值点
f ( x ) = x 2 + 2 x
求导数:
f ( x ) = 2 x + 2

计算机使用迭代的方法,类似爬坡,一步一步逼近极值:
x i + 1 = x i + α f ( x i ) x i

python代码表示(alpha为步长、derivative为导数):
x = x + a l p h a d e r i v a t i v e ( x )
derivative = lambda x: 2 - 2 * x  # (2-x**2)的导数
extreme_point = 0  # 初始值
alpha = 0.1  # 步长,即学习速率
presision = 0.0001  # 允许误差范围
while True:
    error = alpha * derivative(extreme_point)  # 爬坡步伐
    extreme_point += error  # 爬坡
    print(extreme_point)
    if abs(error) < presision:
        break  # 误差较小时退出迭代

3、改进学习速率,先快后慢

  • 示例:对 sin(x) 函数,随机选一点,求该点附近的极值点
import numpy as np
derivative = lambda x: np.cos(x)  # sin(x)的导数
extreme_point = np.random.randint(0, 100)  # 初始化极值点
print(extreme_point)
for i in range(1, 99999):
    alpha = 1 / i  # 步长(先快后慢)
    extreme_point += alpha * derivative(extreme_point)
print(extreme_point)
print(extreme_point % (np.pi * 2))
打印结果(之一):
5
7.852861107256883
1.569675800077297
意思表示为:
x i = 5 附近存在极值点 x j = 2.5 π (7.852861107256883)

4、附录

知识进阶

翻译

alpha
希腊字母的第一个字母: α
presision
精度
extreme point
极值点
derivative
导数
partial derivatives
偏导
gradient ascent
梯度上升

猜你喜欢

转载自blog.csdn.net/Yellow_python/article/details/81414581