MATLAB梯度下降法

梯度下降法

  梯度下降法(英语:Gradient descent)是一个一阶最优化算法。 要使用梯度下降法找到一个函数的局部极小值,必须向函数上当前点对应梯度(或者是近似梯度)的反方向的规定步长距离点进行迭代搜索。如果相反地向梯度正方向迭代进行搜索,则会接近函数的局部极大值点;这个过程则被称为梯度上升法。

梯度

  梯度的矢量,表示某一函数在该点处的方向导数沿着该方向取得最大值,即函数在该点处沿着该方向变化最快,变化率最大(为该梯度的模)。

  • 二元函数   z = f ( x , y ) \ z=f(x,y) 有连续的一阶导数,则其梯度为:
    g r a d f ( x , y ) = f x i + f y j = ( f x , f y ) grad f(x,y) = \frac{∂ f}{∂ x}\vec{i} + \frac{∂ f}{∂ y}\vec{j}=(\frac{∂ f}{∂ x}, \frac{∂ f}{∂ y})
    函数在一点沿梯度方向的变化率最大,最大值为该梯度的模。
  • 三元函数   u = f ( x , y , z ) \ u=f(x,y,z) 有连续的一阶导数,则其梯度为:
    g r a d f ( x , y , z ) = f x i + f y j + f z k = ( f x , f y , f z ) grad f(x,y,z) = \frac{∂ f}{∂ x}\vec{i} + \frac{∂ f}{∂ y}\vec{j}+\frac{∂ f}{∂ z}\vec{k}=(\frac{∂ f}{∂ x}, \frac{∂ f}{∂ y},\frac{∂ f}{∂ z})
      则有迭代过程:
    x i + 1 = x i + μ f ( x i , y i , z i ) x y i + 1 = y i + μ f ( x i , y i , z i ) y z i + 1 = z i + μ f ( x i , y i , z i ) z x^{i+1} = x^{i}+\mu \frac{∂ f(x^{i},y^{i},z^{i})}{∂ x}, y^{i+1} = y^{i}+\mu \frac{∂ f(x^{i},y^{i},z^{i})}{∂ y}, z^{i+1} = z^{i}+\mu \frac{∂ f(x^{i},y^{i},z^{i})}{∂ z}

同样,该梯度方向与取得最大方向导数的方向一致,而它的模为方向导数的最大值。

梯度下降法

  梯度下降法是最早最简单,也是最为常用的最优化方法。梯度下降法实现简单,当目标函数是凸函数时,梯度下降法的解是全局解。一般情况下,其解不保证是全局最优解。梯度下降法的优化思想是用当前位置负梯度方向作为搜索方向,因为该方向为当前位置的最快下降方向,所以也被称为是”最速下降法“。最速下降法越接近目标值,步长越小,前进越慢。

  在机器学习中,基于基本的梯度下降法发展了两种梯度下降方法,分别为随机梯度下降法和批量梯度下降法。

  这里取一个简单的例子

%% 梯度下降法
%牛顿迭代法
function [x,y,n,point] = Tidu(fun,dfunx,dfuny,x,y,EPS,p)
a = feval(fun,x,y);
b = a+1;
n=1
point(n,:) = [x y a];
while (abs(a-b) >= EPS) 
  a = feval(fun,x,y);
  x = x - p*(feval(dfunx,x,y));
  y = y - p*(feval(dfuny,x,y));
  b = feval(fun,x,y); 
  n = n+1;
  point(n,:) = [x y b]; 
end

调用函数:

% 目标函数为 z=f(x,y)=(x^2+y^2)/2
clear all
clc
fun = inline('(x^2+y^2)/2','x','y');
dfunx = inline('x','x','y');
dfuny = inline('y','x','y'); 
x0 = 2;
y0 = 2;
EPS = 0.00001;
p = 0.5;
[x,y,n,point] = Tidu(fun,dfunx,dfuny,x0,y0,EPS,p)
figure
x = -0.1:0.1:2;
y = x;
[x,y] = meshgrid(x,y);
z = (x.^2+y.^2)/2;
surf(x,y,z)    %绘制三维表面图形
% hold on
% plot3(point(:,1),point(:,2),point(:,3),'linewidth',1,'color','black')
hold on
scatter3(point(:,1),point(:,2),point(:,3),'r','*');

迭代图:(从(2,2,4)收敛到(0,0,0))
迭代过程

发布了16 篇原创文章 · 获赞 1 · 访问量 891

猜你喜欢

转载自blog.csdn.net/qq_33866593/article/details/104763433