8. Actor-Critic、DDPG、A3C

Actor-Critic

上一章讲到的蒙特卡罗策略梯度REINFORCE算法更新需要完整的状态序列,而且是每个序列单独对策略$\theta$进行更新,不太容易收敛。因此本章讨论策略(Policy Based)和价值(Value Based)相结合的方法Actor-Critic算法,解耦成生成动作并与环境交互的Actor,和评估Actor的表现并指导Actor下一阶段动作的Critic,用期望值Q代替蒙特卡洛采样得到的G实现单步更新,增大了学习效率。

对于Actor-Critic算法,我们需要

Actor,策略函数的近似,参数$\theta$,

\[{\pi _\theta }\left( {s,a} \right) = P\left( {a|s,\theta } \right)\]

Critic,价值函数的近似,参数$w$,

\[\begin{array}{l}
\hat v(s,w) \approx {v_\pi }(s)\\
\hat q(s,a,w) \approx {q_\pi }(s,a)
\end{array}\]

我们用两套神经网络来代替,流程是Critic通过Q网络计算状态的最优价值$v_t$,而Actor利用$v_t$这个最优价值迭代更新策略函数的参数,进而选择动作,并得到反馈到新的状态,Critic使用反馈和新的状态更新Q网络参数$\theta$, 在后面Critic会使用新的网络参数$w$来帮Actor计算状态的最优价值。

Actor-Critic算法


1. 初始化Actor网络的策略函数参数$\theta$,Critic网络的价值函数参数$w$,

2. 对每一个episode:

3.  从$s_0$开始,对episode中的每一步:

a.    在Actor网络中输入状态$s$,输出动作$a$,

b.    采取动作$a$,得到新的状态$s'$和即时奖励$r$

c.    在Critic网络中分别输入状态$s$和$s'$,得到值函数输出$v(s)$和$v(s')$

d.    计算TD误差$\delta  = r + \gamma v(s') - v(s)$

e.    更新Critic网络参数$w$,通过均方差损失函数梯度更新

\[w \leftarrow w + \beta \delta \phi (s,a)\]

f.    更新Actor网络参数$\theta$,

\[\theta  \leftarrow \theta  + \alpha {\nabla _\theta }\log {\pi _\theta }\left( {{s_t},{a_t}} \right)\delta \]

4. 重复以上步骤,从许多个episode中的每一步中不断学习。


DDPG

Actor-Critic 涉及到了两个神经网络,而且每次都是在连续状态中更新参数,每次参数更新前后都存在相关性,导致神经网络只能片面的看待问题,甚至导致神经网络学不到东西。为了解决这个问题,和之前我们讲到的DQN类似,Google DeepMind引入经验回放和双网络的方法来改进Actor-Critic难收敛的问题,提出了Deep Deterministic Policy Gradient[1]。

 DDPG算法


1. 初始化Actor当前网络$Q^a$的参数$\theta$,Actor目标网络$Q'^a$的参数$\theta'$,Critic当前网络$Q^c$的参数$w$,Critic目标网络$Q'^c$的参数$w'$,空的经验回放的集合D

2. 对每一个episode:

3.  从$s_0$开始,对episode中的每一步:

A.    在Actor当前网络中输入状态$s$,得到动作$a = {\pi _\theta }\left( {\phi (s)} \right) + N$,

B.    执行动作$a$,得到新的状态$s'$和即时奖励$r$,是否终止状态$is\_end$

C.    将$\left\{ {\phi (s),a,r,\phi (s'),is\_end} \right\}$五元组存入经验回放集合D

D.    从经验回放集合D中采样m个样本$\left\{ {\phi ({s_j}),{a_j},{r_j},\phi (s{'_j}),is\_en{d_j}} \right\}$,$j = 1,2,...,m$来更新两个网络参数:

a).      根据Actor目标网络$Q'^a$,依据采样样本中下一状态$s'$的最优下一动作$a'$,

\[a' = {\pi _{\theta '}}\left( {\phi \left( {s'} \right)} \right)\]

b).      根据Critic目标网络$Q'^c$,依据$s'$,$a'$计算当前目标Q值

\[{y_t} = \left\{ \begin{array}{l}
{r_t}{\rm{ is\_en}}{{\rm{d}}_t}{\rm{ is true}}\\
{r_t} + \gamma Q'\left( {\phi \left( {s{'_t}} \right),a',w'} \right){\rm{ is\_en}}{{\rm{d}}_t}{\rm{ is false}}
\end{array} \right.\]

c).      计算TD-error$\delta$

\[\delta  = {y_t} - Q\left( {\phi \left( {{s_t}} \right),a,w} \right)\]

d).      根据TD-error$\delta$,通过均方差损失函数梯度更新Critic当前网络参数$w$,

\[w \leftarrow w + \beta \delta \phi (s,a)\]

e).      更新Actor当前网络参数$\theta$,

\[\theta  \leftarrow \theta  + \alpha {\nabla _\theta }\log {\pi _\theta }\left( {{s_t},{a_t}} \right)\delta \]

E.    如果t达到设定的目标网络参数更新频率,则更新Actor目标网络和Actor目标网络参数

\[\begin{array}{l}
w' \leftarrow \tau w + (1 - \tau )w'\\
\theta ' \leftarrow \tau \theta + (1 - \tau )\theta '
\end{array}\]


A3C

[1] Lillicrap T P, Hunt J J, Pritzel A, et al. Continuous control with deep reinforcement learning[J]. international conference on learning representations, 2016.
 
[2] Mnih V, Badia A P, Mirza M, et al. Asynchronous methods for deep reinforcement learning[J]. international conference on machine learning, 2016: 1928-1937.

猜你喜欢

转载自www.cnblogs.com/yijuncheng/p/10509691.html