跳转至

TRPO 与自然梯度

概述

TRPO (Trust Region Policy Optimization) 是策略优化中的里程碑算法,通过信赖域约束确保每次策略更新都是安全的。其理论基础是 自然梯度 (Natural Gradient),利用 Fisher 信息矩阵度量参数空间中的"真实距离"。

本文从自然梯度的数学原理出发,推导 TRPO 的优化目标,并讨论其与 PPO 的关系。


1. 策略优化的挑战

1.1 为什么普通梯度不够好?

标准梯度下降更新:

\[\theta_{t+1} = \theta_t + \alpha \nabla_\theta J(\theta_t)\]

问题: 参数空间中的小变化可能导致策略空间中的巨大变化(反之亦然)。

直觉例子

考虑 softmax 策略 \(\pi_\theta(a|s) = \frac{e^{\theta_a}}{\sum_b e^{\theta_b}}\):

  • 当所有 \(\theta\) 接近时,微小参数变化导致策略剧变
  • 当某个 \(\theta\) 远大于其他时,大参数变化几乎不改变策略

梯度下降对这种不一致性视而不见。

1.2 我们需要什么?

  • 每次更新后,新策略与旧策略在行为上的差异是可控的
  • 更新步长应该在策略空间而非参数空间中度量
  • 保证策略性能单调不降(或至少不降太多)

2. 自然梯度

2.1 Fisher 信息矩阵

Fisher 信息矩阵 定义为策略对数似然梯度的外积期望:

\[F = \mathbb{E}_{s \sim d^\pi, a \sim \pi_\theta} \left[ \nabla_\theta \log \pi_\theta(a|s) \nabla_\theta \log \pi_\theta(a|s)^T \right]\]

性质:

  • \(F\) 是半正定矩阵
  • \(F\) 度量参数变化导致的策略分布变化
  • \(F\) 是 KL 散度的二阶近似: \(D_{KL}(\pi_\theta \| \pi_{\theta + \Delta\theta}) \approx \frac{1}{2} \Delta\theta^T F \Delta\theta\)

2.2 从 KL 散度到 Fisher 矩阵

对 KL 散度在 \(\theta\) 处做二阶 Taylor 展开:

\[D_{KL}(\pi_\theta \| \pi_{\theta + \Delta\theta}) \approx D_{KL}(\pi_\theta \| \pi_\theta) + \nabla_\theta D_{KL} \big|_{\Delta\theta=0} \cdot \Delta\theta + \frac{1}{2} \Delta\theta^T H \Delta\theta\]

由于: - \(D_{KL}(\pi_\theta \| \pi_\theta) = 0\) - \(\nabla_\theta D_{KL} \big|_{\Delta\theta=0} = 0\) (最小值处梯度为零) - \(H = F\) (Hessian 等于 Fisher 矩阵)

所以:

\[D_{KL}(\pi_\theta \| \pi_{\theta + \Delta\theta}) \approx \frac{1}{2} \Delta\theta^T F \Delta\theta\]

2.3 自然梯度推导

目标: 在 KL 散度约束下最大化目标函数。

\[\max_{\Delta\theta} \nabla_\theta J(\theta)^T \Delta\theta \quad \text{s.t.} \quad \frac{1}{2} \Delta\theta^T F \Delta\theta \leq \delta\]

使用 Lagrange 乘子法:

\[\mathcal{L}(\Delta\theta, \lambda) = \nabla_\theta J^T \Delta\theta - \lambda \left( \frac{1}{2} \Delta\theta^T F \Delta\theta - \delta \right)\]

\(\Delta\theta\) 求导并令其为零:

\[\nabla_\theta J - \lambda F \Delta\theta = 0\]
\[\Delta\theta = \frac{1}{\lambda} F^{-1} \nabla_\theta J\]

自然梯度:

\[\boxed{\Delta\theta_{\text{natural}} = F^{-1} \nabla_\theta J(\theta)}\]

2.4 自然梯度 vs 普通梯度

维度 普通梯度 自然梯度
度量 欧氏距离 \(\|\Delta\theta\|_2\) KL 散度 \(D_{KL}(\pi_\theta \| \pi_{\theta'})\)
步长 参数空间固定 策略空间固定
参数化不变性 有 (对参数化方式不敏感)
计算成本 \(O(n)\) \(O(n^2)\) 或更高
收敛行为 可能振荡 更加平稳

参数化不变性

自然梯度的一个关键优势是它对策略的参数化方式不敏感。无论你如何参数化策略 \(\pi_\theta\),自然梯度在策略空间中的行为是一致的。


3. TRPO: 信赖域策略优化

3.1 理论基础: 策略性能界

Kakade & Langford (2002) 的重要结果:

\[J(\pi') \geq J(\pi) + \sum_s d^{\pi'}(s) \sum_a \pi'(a|s) A^\pi(s,a) - C \cdot D_{KL}^{\max}(\pi \| \pi')\]

其中: - \(A^\pi(s,a) = Q^\pi(s,a) - V^\pi(s)\) 是优势函数 - \(d^{\pi'}(s)\) 是新策略的状态访问分布 - \(C\) 是依赖于折扣因子的常数 - \(D_{KL}^{\max} = \max_s D_{KL}(\pi(\cdot|s) \| \pi'(\cdot|s))\)

问题: 这个界需要 \(d^{\pi'}(s)\),但新策略还没有执行,无法获得。

3.2 TRPO 的近似

TRPO 做了两个关键近似:

  1. 用旧策略分布替代: \(d^{\pi'}(s) \approx d^{\pi}(s)\)
  2. 用平均 KL 替代最大 KL: \(D_{KL}^{\max} \to \bar{D}_{KL}\)

得到代理目标:

\[\max_\theta L(\theta) = \mathbb{E}_{s \sim d^{\pi_{\theta_{old}}}, a \sim \pi_{\theta_{old}}} \left[ \frac{\pi_\theta(a|s)}{\pi_{\theta_{old}}(a|s)} A^{\pi_{\theta_{old}}}(s,a) \right]\]
\[\text{s.t.} \quad \bar{D}_{KL}(\pi_{\theta_{old}} \| \pi_\theta) \leq \delta\]

其中 \(\frac{\pi_\theta(a|s)}{\pi_{\theta_{old}}(a|s)}\) 是重要性采样比率。

3.3 优化方法: 共轭梯度 + 线搜索

直接求解带约束优化问题计算量巨大(需要计算和存储 \(F^{-1}\))。TRPO 使用两步近似:

第一步: 共轭梯度法 (Conjugate Gradient)

求解 \(F \cdot x = g\) 而不显式计算 \(F^{-1}\):

\[x \approx F^{-1} g\]

其中 \(g = \nabla_\theta L(\theta)\) 是代理目标的梯度。

关键: 只需要计算 Fisher 矩阵与向量的乘积 \(Fv\),可以通过自动微分高效计算。

\[Fv = \nabla_\theta \left[ (\nabla_\theta D_{KL})^T v \right]\]

计算更新方向后,通过回溯线搜索确定步长:

\[\theta_{new} = \theta_{old} + \alpha^j \sqrt{\frac{2\delta}{x^T F x}} \cdot x\]

其中 \(j\) 是使以下条件满足的最小整数: 1. \(L(\theta_{new}) \geq L(\theta_{old})\) (性能改善) 2. \(\bar{D}_{KL}(\pi_{\theta_{old}} \| \pi_{\theta_{new}}) \leq \delta\) (满足约束)

3.4 TRPO 算法流程

repeat:
    1. 用当前策略 π_θ 收集轨迹数据
    2. 估计优势函数 A(s,a) (通常用 GAE)
    3. 计算策略梯度 g = ∇_θ L(θ)
    4. 用共轭梯度法求解 x ≈ F⁻¹g
    5. 计算最大步长 β = √(2δ / x^T F x)
    6. 回溯线搜索: θ_new = θ_old + α^j · β · x
       直到满足改进条件和 KL 约束
    7. 更新 θ ← θ_new

3.5 广义优势估计 (GAE)

TRPO 通常与 GAE (Generalized Advantage Estimation) 结合使用:

\[\hat{A}_t^{GAE(\gamma,\lambda)} = \sum_{l=0}^{\infty} (\gamma\lambda)^l \delta_{t+l}\]

其中 \(\delta_t = r_t + \gamma V(s_{t+1}) - V(s_t)\) 是 TD 误差。

\(\lambda\) 控制偏差-方差权衡: - \(\lambda = 0\): 纯 TD 估计(低方差,高偏差) - \(\lambda = 1\): 蒙特卡洛估计(高方差,低偏差) - 典型值: \(\lambda = 0.95 \sim 0.97\)


4. 数学推导细节

4.1 重要性采样与代理目标

新策略的期望回报可以用旧策略的数据估计:

\[J(\pi_\theta) = J(\pi_{\theta_{old}}) + \mathbb{E}_{s \sim d^{\pi_\theta}} \left[ \sum_a \pi_\theta(a|s) A^{\pi_{\theta_{old}}}(s,a) \right]\]

用重要性采样:

\[= J(\pi_{\theta_{old}}) + \mathbb{E}_{s \sim d^{\pi_{\theta_{old}}}, a \sim \pi_{\theta_{old}}} \left[ \frac{\pi_\theta(a|s)}{\pi_{\theta_{old}}(a|s)} A^{\pi_{\theta_{old}}}(s,a) \right]\]

4.2 KL 约束的二阶近似

代理目标在 \(\theta = \theta_{old}\) 处的一阶展开:

\[L(\theta) \approx g^T (\theta - \theta_{old})\]

KL 约束的二阶展开:

\[\bar{D}_{KL}(\pi_{\theta_{old}} \| \pi_\theta) \approx \frac{1}{2} (\theta - \theta_{old})^T F (\theta - \theta_{old})\]

于是问题变为:

\[\max_{\Delta\theta} g^T \Delta\theta \quad \text{s.t.} \quad \frac{1}{2} \Delta\theta^T F \Delta\theta \leq \delta\]

解为:

\[\Delta\theta^* = \sqrt{\frac{2\delta}{g^T F^{-1} g}} F^{-1} g\]

4.3 Fisher 向量积的高效计算

不需要显式构建 \(F\) 矩阵。利用自动微分:

  1. 计算 \(v_1 = \nabla_\theta D_{KL}(\pi_{\theta_{old}} \| \pi_\theta)\)
  2. 计算 \(Fv = \nabla_\theta (v_1^T v)\)

这只需要两次反向传播,空间复杂度为 \(O(n)\),其中 \(n\) 是参数数量。


5. 从 TRPO 到 PPO

5.1 TRPO 的局限

  1. 实现复杂: 共轭梯度 + 线搜索实现繁琐
  2. 计算昂贵: 每步需要多次前向/反向传播
  3. 不兼容共享架构: 难以与共享参数的 Actor-Critic 网络结合
  4. 不兼容 dropout/噪声: 约束优化与随机正则化冲突

5.2 PPO 的裁剪替代

PPO 用裁剪目标替代硬约束:

\[L^{CLIP}(\theta) = \mathbb{E} \left[ \min\left( r_t(\theta) \hat{A}_t, \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon) \hat{A}_t \right) \right]\]

其中 \(r_t(\theta) = \frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)}\)

直觉: - 当 \(\hat{A}_t > 0\) (好动作): 限制 \(r_t\) 不超过 \(1+\epsilon\),防止过度增大概率 - 当 \(\hat{A}_t < 0\) (坏动作): 限制 \(r_t\) 不低于 \(1-\epsilon\),防止过度减小概率

5.3 TRPO vs PPO 对比

维度 TRPO PPO (Clip)
约束方式 硬约束 (KL ≤ δ) 软约束 (裁剪)
优化方法 共轭梯度 + 线搜索 标准 SGD / Adam
实现难度
计算成本
理论保证 有 (单调改进) 弱 (经验上有效)
实际性能 好(通常更好)
多 epoch 更新 不适用 适用
工业应用 广泛 (RLHF 等)

实践建议

在绝大多数场景中,推荐使用 PPO 而非 TRPO。PPO 更简单、更快、通常效果更好。

TRPO 的价值在于理论洞见:理解为什么策略约束很重要,以及自然梯度的几何意义。

5.4 PPO-KL 变体

PPO 还有一个使用自适应 KL 惩罚的变体:

\[L^{KL}(\theta) = \mathbb{E} \left[ r_t(\theta) \hat{A}_t - \beta D_{KL}(\pi_{\theta_{old}} \| \pi_\theta) \right]\]

其中 \(\beta\) 根据实际 KL 散度自适应调整: - 如果 \(D_{KL} > d_{targ} \times 1.5\): 增大 \(\beta\) - 如果 \(D_{KL} < d_{targ} / 1.5\): 减小 \(\beta\)


6. 实现要点

6.1 共轭梯度法实现

def conjugate_gradient(Fvp_func, g, num_steps=10, residual_tol=1e-10):
    """
    求解 Fx = g,其中 Fvp_func 计算 F @ v
    """
    x = torch.zeros_like(g)
    r = g.clone()
    p = g.clone()
    rdotr = torch.dot(r, r)

    for i in range(num_steps):
        Fp = Fvp_func(p)
        alpha = rdotr / (torch.dot(p, Fp) + 1e-8)
        x += alpha * p
        r -= alpha * Fp
        new_rdotr = torch.dot(r, r)
        if new_rdotr < residual_tol:
            break
        beta = new_rdotr / rdotr
        p = r + beta * p
        rdotr = new_rdotr

    return x

6.2 Fisher 向量积

def fisher_vector_product(policy, states, v, damping=0.1):
    """
    计算 F @ v,其中 F 是 Fisher 信息矩阵
    """
    kl = compute_kl(policy, old_policy, states)
    kl_grad = torch.autograd.grad(kl, policy.parameters(),
                                   create_graph=True)
    kl_grad_flat = torch.cat([g.view(-1) for g in kl_grad])

    kl_v = torch.dot(kl_grad_flat, v)
    kl_v_grad = torch.autograd.grad(kl_v, policy.parameters())
    fvp = torch.cat([g.view(-1) for g in kl_v_grad])

    return fvp + damping * v  # 加阻尼提高数值稳定性

6.3 超参数建议

超参数 典型值 说明
\(\delta\) 0.01 KL 散度约束上界
CG 步数 10 共轭梯度迭代次数
线搜索回溯系数 0.5 每次减半步长
最大线搜索步数 10 最多尝试次数
GAE \(\lambda\) 0.97 优势估计参数
\(\gamma\) 0.99 折扣因子
阻尼系数 0.1 FVP 阻尼

7. 理论意义与历史地位

7.1 TRPO 的贡献

  1. 单调改进保证: 首次为深度策略优化提供理论保证
  2. 自然梯度的实用化: 通过 CG 和 FVP 使自然梯度可扩展
  3. 信赖域范式: 启发了 PPO 和后续大量工作
  4. 策略空间思维: 从参数空间转向策略空间的优化视角

7.2 发展脉络

自然梯度 (Amari, 1998)
  → Natural Actor-Critic (Peters & Schaal, 2008)
    → TRPO (Schulman et al., 2015)
      → PPO (Schulman et al., 2017)
        → RLHF/PPO (Ouyang et al., 2022)
          → GRPO (Shao et al., 2024)

参考资料

  • Amari, S. (1998). Natural Gradient Works Efficiently in Learning. Neural Computation.
  • Kakade, S. (2001). A Natural Policy Gradient. NeurIPS.
  • Kakade, S. & Langford, J. (2002). Approximately Optimal Approximate Reinforcement Learning. ICML.
  • Schulman, J. et al. (2015). Trust Region Policy Optimization. ICML.
  • Schulman, J. et al. (2016). High-Dimensional Continuous Control Using Generalized Advantage Estimation.
  • Schulman, J. et al. (2017). Proximal Policy Optimization Algorithms.

延伸阅读


评论 #