PPO 训练核心是“采样-计算-多轮裁剪优化”的循环,以“旧策略采样、新策略裁剪更新”为核心逻辑,兼顾稳定性、样本效率和实现简洁性,完整流程如下:
一、前置初始化
- 网络构建:搭建 Actor-Critic 双分支网络
- Actor(策略网络):输入状态 (s),输出动作概率分布(离散动作输出类别 logits,连续动作输出高斯分布均值+标准差),用于生成动作。
- Critic(价值网络):输入状态 (s),输出状态价值 (V(s))(预测该状态的长期期望奖励),用于计算优势值。
- 可选共享底层特征层(如 MLP/CNN),减少参数冗余。
- 超参数设置:
- 折扣因子 (\gamma = 0.99)、GAE 参数 (\lambda = 0.95)、裁剪系数 (\epsilon = 0.2)
- 多轮优化次数 (K = 10)、单智能体采样步数 (T = 2048)、并行智能体数 (N = 8)
- 小批量大小 (batch_size = 64)、优化器(Adam,学习率 (lr = 3e-4))
- 环境与参数初始化:创建训练环境(如 CartPole/MuJoCo),初始化网络参数 (\theta)。
二、核心迭代循环(每轮 = 1 次策略更新)
循环执行至策略收敛(奖励达标/迭代上限),每轮包含 4 关键步骤:
步骤 1:并行采样数据(用旧策略收集样本)
- 启动 N 个并行智能体,每个智能体基于当前旧策略 (\pi_{\theta_{old}})(未更新的网络参数)与环境交互 (T) 步。
- 每步记录核心数据:状态 (s_t)、动作 (a_t)、即时奖励 (r_t)、旧策略动作对数概率 (\log\pi_{\theta_{old}}(a_t|s_t))、下一状态 (s_{t+1})、是否结束 (done_t)(回合结束为 True)。
- 最终收集 (N \times T) 条样本(如 (8 \times 2048 = 16384) 条),保证数据多样性。
步骤 2:计算优势值 \hat{A}_t 与目标价值 (V_{\text{targ}})
核心是用 GAE(广义优势估计)给每个动作“打分”,修正价值网络预测:
- 用 Critic 网络预测所有采样状态的价值: (V(s_t) = \text{Critic}(s_t))、(V(s_{t+1}) = \text{Critic}(s_{t+1}))
- 计算时序差分误差 (\delta_t):
\delta_t = r_t + \gamma \cdot V(s_{t+1}) \cdot (1 - done_t) - V(s_t) - 从后往前倒推计算 GAE 优势值 (\hat{A}t)(平衡偏差与方差): \hat{A}_t = \delta_t + \gamma\lambda \cdot \hat{A}{t+1} \cdot (1 - done_t)
- 标准化优势值(提升优化稳定性):
\hat{A}_t = \frac{\hat{A}_t - \text{mean}(\hat{A}_t)}{\text{std}(\hat{A}_t) + 1e-8} - 计算目标价值 (V_{\text{targ}})(用于训练 Critic 网络):
V_{\text{targ}} = \hat{A}_t + V(s_t)
步骤 3:保存旧策略与数据格式化
- 保存当前网络参数 (\theta) 为 (\theta_{old}),作为后续计算“新旧策略概率比”的锚点,避免策略突变。
- 将所有样本((s_t, a_t, \log\pi_{\theta_{old}}(a_t|s_t), \hat{A}<em>t, V</em>{\text{targ}}))转换为张量,按小批量大小分组,生成数据加载器。
步骤 4:多轮小批量裁剪优化(核心!)
重复 (K) 轮迭代,充分利用采样数据更新网络,通过“裁剪”限制策略更新幅度:
- 遍历小批量样本:逐批计算损失并反向传播。
- 计算策略损失(Actor 优化):
- 新策略动作对数概率: (\log\pi_{\theta}(a_t|s_t) = \text{Actor}(s_t, a_t))
- 新旧策略概率比(用指数避免除法数值不稳定):
r_t = \exp\left(\log\pi_{\theta}(a_t|s_t) - \log\pi_{\theta_{old}}(a_t|s_t)\right) - 裁剪前后的替代损失:
surr1 = r_t \cdot \hat{A}_t, \quad surr2 = \text{clip}(r_t, 1-\epsilon, 1+\epsilon) \cdot \hat{A}_t - 策略损失(取最小值实现悲观估计,负号适配优化器最小化):
L_{\text{policy}} = -\frac{1}{M} \sum_{i=1}^M \min(surr1, surr2)
- 计算价值损失(Critic 优化):
用均方误差让 Critic 预测的 (V(s_t)) 逼近目标价值 (V_{\text{targ}}):
L_{\text{value}} = \text{MSELoss}\left(V(s_t), V_{\text{targ}}\right) = \frac{1}{M} \sum_{i=1}^M \left(V(s_t) - V_{\text{targ}}\right)^2 - 计算熵损失(鼓励探索):
增加动作分布的熵,避免策略过早收敛到局部最优:
L_{\text{entropy}} = -\beta \cdot \text{Entropy}\left(\pi_{\theta}(a|s_t)\right)
( (\beta = 0.01),熵越大,动作选择越随机) - 总损失与参数更新:
平衡三类损失(权重系数 (c_1 = 0.5, c_2 = 0.01)):
L_{\text{total}} = L_{\text{policy}} + c_1 \cdot L_{\text{value}} + c_2 \cdot L_{\text{entropy}}
执行反向传播
三、终止条件
满足以下任一即停止训练:
- 策略性能达标(如 CartPole 奖励 ≥ 500、机器人稳定行走);
- 迭代次数达到上限(如 100 轮);
- 总损失不再下降(策略收敛)。