跳转至

近端策略优化(PPO)

PPO(Proximal Policy Optimization)是目前强化学习工业界最主流的算法。它是 OpenAI 提出的对策略梯度方法的改进,也是 ChatGPT 等大模型 RLHF 训练的核心底座。PPO 的核心思想只有一句话:更新策略时,步子不能迈得太大。


一、为什么需要 PPO?

1.1 策略梯度的不稳定性

标准策略梯度(如 REINFORCE、Actor-Critic)有一个严重问题:

一次更新可能导致策略发生剧烈变化 → 性能暴跌 → 难以恢复

灾难性更新

想象你在训练一个走路的机器人。某次梯度更新恰好过大,策略突然从"正常行走"变成了"原地打转"。然后在这个糟糕的策略下收集的数据全是低质量的 → 下次更新更差 → 恶性循环。

1.2 信任域方法(TRPO)

PPO 的前身 TRPO(Trust Region Policy Optimization)提出了解决方案:限制每次更新后新旧策略的差异不能太大

TRPO 用 KL 散度约束:

\[ \max_\theta \; \mathbb{E}\left[\frac{\pi_\theta(a|s)}{\pi_{\theta_{\text{old}}}(a|s)} \hat{A}(s,a)\right] \quad \text{s.t.} \quad \mathbb{E}\left[D_{KL}(\pi_{\theta_{\text{old}}} \| \pi_\theta)\right] \leq \delta \]

但 TRPO 实现复杂(需要二阶优化、共轭梯度),不利于大规模应用。

1.3 PPO 的改进

PPO 的核心贡献:用一个简单的 Clip(截断)操作替代 TRPO 的 KL 约束,效果相当但实现极其简单。


二、PPO 的核心机制 ⭐

2.1 重要性采样比率

定义新旧策略的概率比:

\[ r_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_{\text{old}}}(a_t \mid s_t)} \]
  • \(r_t = 1\):新旧策略一样
  • \(r_t > 1\):新策略更倾向于选这个动作
  • \(r_t < 1\):新策略更不倾向于选这个动作

2.2 PPO-Clip 目标函数

\[ L^{\text{CLIP}}(\theta) = \mathbb{E}_t \left[ \min \left( r_t(\theta) \hat{A}_t, \;\; \text{clip}(r_t(\theta), \; 1-\epsilon, \; 1+\epsilon) \cdot \hat{A}_t \right) \right] \]

其中 \(\epsilon\) 通常取 0.2(即允许概率比在 \([0.8, 1.2]\) 范围内变化)。

2.3 逐行拆解 Clip 机制

这个公式看起来复杂,但逻辑很清晰。分两种情况:

情况 1:优势 \(\hat{A}_t > 0\)(动作比平均好,要增大概率)

\[ L = \min\left(r_t \cdot \hat{A}_t, \;\; \min(r_t, 1+\epsilon) \cdot \hat{A}_t \right) = \min(r_t, 1+\epsilon) \cdot \hat{A}_t \]

→ 即使动作很好,概率比 \(r_t\) 最多只能是 \(1+\epsilon\)防止过度增大

情况 2:优势 \(\hat{A}_t < 0\)(动作不如平均,要减小概率)

\[ L = \min\left(r_t \cdot \hat{A}_t, \;\; \max(r_t, 1-\epsilon) \cdot \hat{A}_t \right) = \max(r_t, 1-\epsilon) \cdot \hat{A}_t \]

→ 即使动作很差,概率比 \(r_t\) 最少只能是 \(1-\epsilon\)防止过度减小

一句话理解 PPO-Clip

好动作可以多选,但别太贪;坏动作可以少选,但别太狠。 这就是"近端"(Proximal)的含义——留在旧策略附近。

2.4 图解 Clip 机制

当 A > 0 时(好动作):
     目标
      │          ╱ 无限制
      │        ╱
      │      ╱ ← 截断处 (1+ε)
      │    ╱-------- 平坦区(不再增长)
      │  ╱
      │╱
  ────┼────────────── r_t

当 A < 0 时(坏动作):
  ────┼────────────── r_t
      │╲
      │  ╲
      │    ╲--------- 平坦区(不再下降)
      │      ╲ ← 截断处 (1-ε)
      │        ╲
      │          ╲ 无限制
     目标

三、PPO 完整算法

3.1 总损失函数

PPO 的总损失由三部分组成:

\[ L(\theta) = L^{\text{CLIP}}(\theta) - c_1 \cdot L^{\text{VF}}(\theta) + c_2 \cdot S[\pi_\theta](s) \]
公式 作用
\(L^{\text{CLIP}}\) Clip 策略损失 更新 Actor
\(L^{\text{VF}}\) \((V_\theta(s) - V_{\text{target}})^2\) 更新 Critic(价值函数)
\(S[\pi_\theta]\) \(-\sum_a \pi(a \mid s) \log \pi(a \mid s)\) 熵正则化,鼓励探索
\(c_1\) 通常 0.5 价值损失权重
\(c_2\) 通常 0.01 熵奖励权重

熵正则化的作用

熵越大 → 策略越"均匀"、越随机 → 探索越充分。加入熵奖励可以防止策略过早收敛到次优的确定性策略。

3.2 算法伪代码

初始化策略网络 πθ(既是 Actor 又有 Critic 头)

循环 K 个迭代:
  ──── 数据采集阶段 ────
  1. 用当前策略 πθ_old 与环境交互 T 步,收集数据:
     {(sₜ, aₜ, rₜ, sₜ₊₁, log πθ_old(aₜ|sₜ))}

  2. 用 GAE 计算每个时间步的优势估计 Â_t 和回报 R̂_t

  ──── 策略优化阶段 ────
  3. 对收集的数据进行 M 个 epoch 的小批量更新:
     for epoch in range(M):       ← 通常 M = 3~10
       for batch in mini_batches:
         a. 计算概率比 rₜ = πθ(aₜ|sₜ) / πθ_old(aₜ|sₜ)
         b. 计算 Clip 损失
         c. 计算价值函数损失
         d. 计算熵奖励
         e. 反向传播更新 θ

  4. θ_old ← θ

数据重用是 PPO 的关键

收集一批数据后,PPO 可以对其进行多次(M 个 epoch)优化——这比 REINFORCE(一批数据只能用一次)高效得多。Clip 机制正是为了保证多次优化不会偏离太远。

3.3 核心代码实现

import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical

class PPOAgent(nn.Module):
    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.shared = nn.Sequential(
            nn.Linear(state_dim, 64), nn.Tanh(),
            nn.Linear(64, 64), nn.Tanh()
        )
        self.actor = nn.Linear(64, action_dim)
        self.critic = nn.Linear(64, 1)

    def forward(self, x):
        features = self.shared(x)
        return Categorical(logits=self.actor(features)), self.critic(features)

    def get_action(self, state):
        dist, value = self.forward(state)
        action = dist.sample()
        return action, dist.log_prob(action), value

def ppo_update(agent, optimizer, states, actions, old_log_probs, 
               returns, advantages, clip_eps=0.2, epochs=4, batch_size=64):

    for _ in range(epochs):
        # 随机打乱数据,分成小批量
        indices = torch.randperm(len(states))
        for start in range(0, len(states), batch_size):
            idx = indices[start:start + batch_size]

            dist, values = agent(states[idx])
            new_log_probs = dist.log_prob(actions[idx])
            entropy = dist.entropy().mean()

            # 计算概率比
            ratio = (new_log_probs - old_log_probs[idx]).exp()

            # PPO-Clip 损失
            adv = advantages[idx]
            surr1 = ratio * adv
            surr2 = torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps) * adv
            actor_loss = -torch.min(surr1, surr2).mean()

            # 价值函数损失
            critic_loss = (returns[idx] - values.squeeze()).pow(2).mean()

            # 总损失
            loss = actor_loss + 0.5 * critic_loss - 0.01 * entropy

            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(agent.parameters(), 0.5)  # 梯度裁剪
            optimizer.step()

四、PPO 的两个变体

4.1 PPO-Clip(最常用)

就是上面介绍的截断版本,也是实践中几乎唯一使用的形式。

4.2 PPO-Penalty

不用 clip,而是直接在目标函数中加 KL 惩罚:

\[ L(\theta) = \mathbb{E}_t \left[ r_t(\theta) \hat{A}_t - \beta \cdot D_{KL}(\pi_{\theta_{\text{old}}} \| \pi_\theta) \right] \]

\(\beta\) 自适应调整:如果 KL 散度超过目标值,就增大 \(\beta\);反之减小。

实践中的选择

PPO-Clip 实现更简单、效果更稳定,是绝对主流。PPO-Penalty 主要存在于论文中。


五、PPO 超参数调优指南

超参数 推荐值 说明
\(\epsilon\)(clip 范围) 0.1 ~ 0.3 常用 0.2。越小越保守
\(\gamma\)(折扣因子) 0.99 看重长远回报
\(\lambda\)(GAE 参数) 0.95 偏差-方差平衡
学习率 \(3 \times 10^{-4}\) Adam 优化器
Mini-batch 大小 64 ~ 256 视 GPU 显存而定
训练轮数(epoch) 3 ~ 10 每批数据重复优化次数
熵系数 \(c_2\) 0.01 太大过度探索,太小过早收敛
价值损失系数 \(c_1\) 0.5 通常固定
梯度裁剪 0.5 防止梯度爆炸
收集步数(T) 2048 每次交互收集的样本数

六、PPO 在 RLHF 中的应用

PPO 是大语言模型人类反馈强化学习(RLHF)的核心组件:

预训练 LLM
    │ SFT(监督微调)
SFT 模型
    │ 训练奖励模型 RM(人类标注排序数据)
    │ PPO 优化(RM 作为奖励信号)
对齐后的 LLM(如 ChatGPT)

在 RLHF 场景中:

  • Actor:待训练的 LLM
  • Critic:评估 LLM 输出质量的价值网络
  • 环境:给定 prompt,生成回答
  • 奖励:奖励模型 RM 的打分
  • PPO 的作用:让 LLM 的输出既符合人类偏好,又不偏离 SFT 模型太远

为什么 RLHF 选 PPO?

  1. PPO 的 Clip 机制天然防止模型偏离太远(不会"忘记"预训练知识)
  2. 实现简单,易于大规模分布式训练
  3. 经过大量实验验证,稳定性好

七、PPO 与其他算法对比

REINFORCE A2C TRPO PPO
更新方式 整条轨迹 单步 TD 信任域约束 Clip 截断
数据效率 低(一次性) 高(可重用)
实现难度 简单 中等 困难 简单
稳定性 一般
工业应用 较少 较少 ✅ 主流
连续动作

关键公式速查

名称 公式
概率比 \(r_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_{\text{old}}}(a_t \mid s_t)}\)
PPO-Clip \(L = \mathbb{E}\left[\min(r_t \hat{A}_t, \; \text{clip}(r_t, 1 \pm \epsilon) \hat{A}_t)\right]\)
GAE \(\hat{A}_t = \sum_{l=0}^{\infty}(\gamma\lambda)^l \delta_{t+l}\)
总损失 \(L = L^{\text{CLIP}} - c_1 L^{\text{VF}} + c_2 S[\pi_\theta]\)