2025年12月28日

PPO(近端策略优化)训练流程总结

PPO 训练核心是“采样-计算-多轮裁剪优化”的循环,以“旧策略采样、新策略裁剪更新”为核心逻辑,兼顾稳定性、样本效率和实现简洁性,完整流程如下:

一、前置初始化

  1. 网络构建:搭建 Actor-Critic 双分支网络
  • Actor(策略网络):输入状态 (s),输出动作概率分布(离散动作输出类别 logits,连续动作输出高斯分布均值+标准差),用于生成动作。
  • Critic(价值网络):输入状态 (s),输出状态价值 (V(s))(预测该状态的长期期望奖励),用于计算优势值。
  • 可选共享底层特征层(如 MLP/CNN),减少参数冗余。
  1. 超参数设置
  • 折扣因子 (\gamma = 0.99)、GAE 参数 (\lambda = 0.95)、裁剪系数 (\epsilon = 0.2)
  • 多轮优化次数 (K = 10)、单智能体采样步数 (T = 2048)、并行智能体数 (N = 8)
  • 小批量大小 (batch_size = 64)、优化器(Adam,学习率 (lr = 3e-4))
  1. 环境与参数初始化:创建训练环境(如 CartPole/MuJoCo),初始化网络参数 (\theta)。

二、核心迭代循环(每轮 = 1 次策略更新)

循环执行至策略收敛(奖励达标/迭代上限),每轮包含 4 关键步骤:

步骤 1:并行采样数据(用旧策略收集样本)

  1. 启动 N 个并行智能体,每个智能体基于当前旧策略 (\pi_{\theta_{old}})(未更新的网络参数)与环境交互 (T) 步。
  2. 每步记录核心数据:状态 (s_t)、动作 (a_t)、即时奖励 (r_t)、旧策略动作对数概率 (\log\pi_{\theta_{old}}(a_t|s_t))、下一状态 (s_{t+1})、是否结束 (done_t)(回合结束为 True)。
  3. 最终收集 (N \times T) 条样本(如 (8 \times 2048 = 16384) 条),保证数据多样性。

步骤 2:计算优势值 \hat{A}_t 与目标价值 (V_{\text{targ}})

核心是用 GAE(广义优势估计)给每个动作“打分”,修正价值网络预测:

  1. 用 Critic 网络预测所有采样状态的价值: (V(s_t) = \text{Critic}(s_t))(V(s_{t+1}) = \text{Critic}(s_{t+1}))
  2. 计算时序差分误差 (\delta_t)
    \delta_t = r_t + \gamma \cdot V(s_{t+1}) \cdot (1 - done_t) - V(s_t)
  3. 从后往前倒推计算 GAE 优势值 (\hat{A}t)(平衡偏差与方差): \hat{A}_t = \delta_t + \gamma\lambda \cdot \hat{A}{t+1} \cdot (1 - done_t)
  4. 标准化优势值(提升优化稳定性):
    \hat{A}_t = \frac{\hat{A}_t - \text{mean}(\hat{A}_t)}{\text{std}(\hat{A}_t) + 1e-8}
  5. 计算目标价值 (V_{\text{targ}})(用于训练 Critic 网络):
    V_{\text{targ}} = \hat{A}_t + V(s_t)

步骤 3:保存旧策略与数据格式化

  1. 保存当前网络参数 (\theta)(\theta_{old}),作为后续计算“新旧策略概率比”的锚点,避免策略突变。
  2. 将所有样本((s_t, a_t, \log\pi_{\theta_{old}}(a_t|s_t), \hat{A}<em>t, V</em>{\text{targ}}))转换为张量,按小批量大小分组,生成数据加载器。

步骤 4:多轮小批量裁剪优化(核心!)

重复 (K) 轮迭代,充分利用采样数据更新网络,通过“裁剪”限制策略更新幅度:

  1. 遍历小批量样本:逐批计算损失并反向传播。
  2. 计算策略损失(Actor 优化)
  • 新策略动作对数概率: (\log\pi_{\theta}(a_t|s_t) = \text{Actor}(s_t, a_t))
  • 新旧策略概率比(用指数避免除法数值不稳定):
    r_t = \exp\left(\log\pi_{\theta}(a_t|s_t) - \log\pi_{\theta_{old}}(a_t|s_t)\right)
  • 裁剪前后的替代损失:
    surr1 = r_t \cdot \hat{A}_t, \quad surr2 = \text{clip}(r_t, 1-\epsilon, 1+\epsilon) \cdot \hat{A}_t
  • 策略损失(取最小值实现悲观估计,负号适配优化器最小化):
    L_{\text{policy}} = -\frac{1}{M} \sum_{i=1}^M \min(surr1, surr2)
  1. 计算价值损失(Critic 优化)
    用均方误差让 Critic 预测的 (V(s_t)) 逼近目标价值 (V_{\text{targ}})
    L_{\text{value}} = \text{MSELoss}\left(V(s_t), V_{\text{targ}}\right) = \frac{1}{M} \sum_{i=1}^M \left(V(s_t) - V_{\text{targ}}\right)^2
  2. 计算熵损失(鼓励探索)
    增加动作分布的熵,避免策略过早收敛到局部最优:
    L_{\text{entropy}} = -\beta \cdot \text{Entropy}\left(\pi_{\theta}(a|s_t)\right)
    (\beta = 0.01),熵越大,动作选择越随机)
  3. 总损失与参数更新
    平衡三类损失(权重系数 (c_1 = 0.5, c_2 = 0.01)):
    L_{\text{total}} = L_{\text{policy}} + c_1 \cdot L_{\text{value}} + c_2 \cdot L_{\text{entropy}}
    执行反向传播

三、终止条件

满足以下任一即停止训练:

  1. 策略性能达标(如 CartPole 奖励 ≥ 500、机器人稳定行走);
  2. 迭代次数达到上限(如 100 轮);
  3. 总损失不再下降(策略收敛)。
Share

You may also like...

发表评论

您的电子邮箱地址不会被公开。