TRPO 与自然梯度
概述
TRPO (Trust Region Policy Optimization) 是策略优化中的里程碑算法,通过信赖域约束确保每次策略更新都是安全的。其理论基础是 自然梯度 (Natural Gradient),利用 Fisher 信息矩阵度量参数空间中的"真实距离"。
本文从自然梯度的数学原理出发,推导 TRPO 的优化目标,并讨论其与 PPO 的关系。
1. 策略优化的挑战
1.1 为什么普通梯度不够好?
标准梯度下降更新:
问题: 参数空间中的小变化可能导致策略空间中的巨大变化(反之亦然)。
直觉例子
考虑 softmax 策略 \(\pi_\theta(a|s) = \frac{e^{\theta_a}}{\sum_b e^{\theta_b}}\):
- 当所有 \(\theta\) 接近时,微小参数变化导致策略剧变
- 当某个 \(\theta\) 远大于其他时,大参数变化几乎不改变策略
梯度下降对这种不一致性视而不见。
1.2 我们需要什么?
- 每次更新后,新策略与旧策略在行为上的差异是可控的
- 更新步长应该在策略空间而非参数空间中度量
- 保证策略性能单调不降(或至少不降太多)
2. 自然梯度
2.1 Fisher 信息矩阵
Fisher 信息矩阵 定义为策略对数似然梯度的外积期望:
性质:
- \(F\) 是半正定矩阵
- \(F\) 度量参数变化导致的策略分布变化
- \(F\) 是 KL 散度的二阶近似: \(D_{KL}(\pi_\theta \| \pi_{\theta + \Delta\theta}) \approx \frac{1}{2} \Delta\theta^T F \Delta\theta\)
2.2 从 KL 散度到 Fisher 矩阵
对 KL 散度在 \(\theta\) 处做二阶 Taylor 展开:
由于: - \(D_{KL}(\pi_\theta \| \pi_\theta) = 0\) - \(\nabla_\theta D_{KL} \big|_{\Delta\theta=0} = 0\) (最小值处梯度为零) - \(H = F\) (Hessian 等于 Fisher 矩阵)
所以:
2.3 自然梯度推导
目标: 在 KL 散度约束下最大化目标函数。
使用 Lagrange 乘子法:
对 \(\Delta\theta\) 求导并令其为零:
自然梯度:
2.4 自然梯度 vs 普通梯度
| 维度 | 普通梯度 | 自然梯度 |
|---|---|---|
| 度量 | 欧氏距离 \(\|\Delta\theta\|_2\) | KL 散度 \(D_{KL}(\pi_\theta \| \pi_{\theta'})\) |
| 步长 | 参数空间固定 | 策略空间固定 |
| 参数化不变性 | 无 | 有 (对参数化方式不敏感) |
| 计算成本 | \(O(n)\) | \(O(n^2)\) 或更高 |
| 收敛行为 | 可能振荡 | 更加平稳 |
参数化不变性
自然梯度的一个关键优势是它对策略的参数化方式不敏感。无论你如何参数化策略 \(\pi_\theta\),自然梯度在策略空间中的行为是一致的。
3. TRPO: 信赖域策略优化
3.1 理论基础: 策略性能界
Kakade & Langford (2002) 的重要结果:
其中: - \(A^\pi(s,a) = Q^\pi(s,a) - V^\pi(s)\) 是优势函数 - \(d^{\pi'}(s)\) 是新策略的状态访问分布 - \(C\) 是依赖于折扣因子的常数 - \(D_{KL}^{\max} = \max_s D_{KL}(\pi(\cdot|s) \| \pi'(\cdot|s))\)
问题: 这个界需要 \(d^{\pi'}(s)\),但新策略还没有执行,无法获得。
3.2 TRPO 的近似
TRPO 做了两个关键近似:
- 用旧策略分布替代: \(d^{\pi'}(s) \approx d^{\pi}(s)\)
- 用平均 KL 替代最大 KL: \(D_{KL}^{\max} \to \bar{D}_{KL}\)
得到代理目标:
其中 \(\frac{\pi_\theta(a|s)}{\pi_{\theta_{old}}(a|s)}\) 是重要性采样比率。
3.3 优化方法: 共轭梯度 + 线搜索
直接求解带约束优化问题计算量巨大(需要计算和存储 \(F^{-1}\))。TRPO 使用两步近似:
第一步: 共轭梯度法 (Conjugate Gradient)
求解 \(F \cdot x = g\) 而不显式计算 \(F^{-1}\):
其中 \(g = \nabla_\theta L(\theta)\) 是代理目标的梯度。
关键: 只需要计算 Fisher 矩阵与向量的乘积 \(Fv\),可以通过自动微分高效计算。
第二步: 线搜索 (Line Search)
计算更新方向后,通过回溯线搜索确定步长:
其中 \(j\) 是使以下条件满足的最小整数: 1. \(L(\theta_{new}) \geq L(\theta_{old})\) (性能改善) 2. \(\bar{D}_{KL}(\pi_{\theta_{old}} \| \pi_{\theta_{new}}) \leq \delta\) (满足约束)
3.4 TRPO 算法流程
repeat:
1. 用当前策略 π_θ 收集轨迹数据
2. 估计优势函数 A(s,a) (通常用 GAE)
3. 计算策略梯度 g = ∇_θ L(θ)
4. 用共轭梯度法求解 x ≈ F⁻¹g
5. 计算最大步长 β = √(2δ / x^T F x)
6. 回溯线搜索: θ_new = θ_old + α^j · β · x
直到满足改进条件和 KL 约束
7. 更新 θ ← θ_new
3.5 广义优势估计 (GAE)
TRPO 通常与 GAE (Generalized Advantage Estimation) 结合使用:
其中 \(\delta_t = r_t + \gamma V(s_{t+1}) - V(s_t)\) 是 TD 误差。
\(\lambda\) 控制偏差-方差权衡: - \(\lambda = 0\): 纯 TD 估计(低方差,高偏差) - \(\lambda = 1\): 蒙特卡洛估计(高方差,低偏差) - 典型值: \(\lambda = 0.95 \sim 0.97\)
4. 数学推导细节
4.1 重要性采样与代理目标
新策略的期望回报可以用旧策略的数据估计:
用重要性采样:
4.2 KL 约束的二阶近似
代理目标在 \(\theta = \theta_{old}\) 处的一阶展开:
KL 约束的二阶展开:
于是问题变为:
解为:
4.3 Fisher 向量积的高效计算
不需要显式构建 \(F\) 矩阵。利用自动微分:
- 计算 \(v_1 = \nabla_\theta D_{KL}(\pi_{\theta_{old}} \| \pi_\theta)\)
- 计算 \(Fv = \nabla_\theta (v_1^T v)\)
这只需要两次反向传播,空间复杂度为 \(O(n)\),其中 \(n\) 是参数数量。
5. 从 TRPO 到 PPO
5.1 TRPO 的局限
- 实现复杂: 共轭梯度 + 线搜索实现繁琐
- 计算昂贵: 每步需要多次前向/反向传播
- 不兼容共享架构: 难以与共享参数的 Actor-Critic 网络结合
- 不兼容 dropout/噪声: 约束优化与随机正则化冲突
5.2 PPO 的裁剪替代
PPO 用裁剪目标替代硬约束:
其中 \(r_t(\theta) = \frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)}\)。
直觉: - 当 \(\hat{A}_t > 0\) (好动作): 限制 \(r_t\) 不超过 \(1+\epsilon\),防止过度增大概率 - 当 \(\hat{A}_t < 0\) (坏动作): 限制 \(r_t\) 不低于 \(1-\epsilon\),防止过度减小概率
5.3 TRPO vs PPO 对比
| 维度 | TRPO | PPO (Clip) |
|---|---|---|
| 约束方式 | 硬约束 (KL ≤ δ) | 软约束 (裁剪) |
| 优化方法 | 共轭梯度 + 线搜索 | 标准 SGD / Adam |
| 实现难度 | 高 | 低 |
| 计算成本 | 高 | 低 |
| 理论保证 | 有 (单调改进) | 弱 (经验上有效) |
| 实际性能 | 好 | 好(通常更好) |
| 多 epoch 更新 | 不适用 | 适用 |
| 工业应用 | 少 | 广泛 (RLHF 等) |
实践建议
在绝大多数场景中,推荐使用 PPO 而非 TRPO。PPO 更简单、更快、通常效果更好。
TRPO 的价值在于理论洞见:理解为什么策略约束很重要,以及自然梯度的几何意义。
5.4 PPO-KL 变体
PPO 还有一个使用自适应 KL 惩罚的变体:
其中 \(\beta\) 根据实际 KL 散度自适应调整: - 如果 \(D_{KL} > d_{targ} \times 1.5\): 增大 \(\beta\) - 如果 \(D_{KL} < d_{targ} / 1.5\): 减小 \(\beta\)
6. 实现要点
6.1 共轭梯度法实现
def conjugate_gradient(Fvp_func, g, num_steps=10, residual_tol=1e-10):
"""
求解 Fx = g,其中 Fvp_func 计算 F @ v
"""
x = torch.zeros_like(g)
r = g.clone()
p = g.clone()
rdotr = torch.dot(r, r)
for i in range(num_steps):
Fp = Fvp_func(p)
alpha = rdotr / (torch.dot(p, Fp) + 1e-8)
x += alpha * p
r -= alpha * Fp
new_rdotr = torch.dot(r, r)
if new_rdotr < residual_tol:
break
beta = new_rdotr / rdotr
p = r + beta * p
rdotr = new_rdotr
return x
6.2 Fisher 向量积
def fisher_vector_product(policy, states, v, damping=0.1):
"""
计算 F @ v,其中 F 是 Fisher 信息矩阵
"""
kl = compute_kl(policy, old_policy, states)
kl_grad = torch.autograd.grad(kl, policy.parameters(),
create_graph=True)
kl_grad_flat = torch.cat([g.view(-1) for g in kl_grad])
kl_v = torch.dot(kl_grad_flat, v)
kl_v_grad = torch.autograd.grad(kl_v, policy.parameters())
fvp = torch.cat([g.view(-1) for g in kl_v_grad])
return fvp + damping * v # 加阻尼提高数值稳定性
6.3 超参数建议
| 超参数 | 典型值 | 说明 |
|---|---|---|
| \(\delta\) | 0.01 | KL 散度约束上界 |
| CG 步数 | 10 | 共轭梯度迭代次数 |
| 线搜索回溯系数 | 0.5 | 每次减半步长 |
| 最大线搜索步数 | 10 | 最多尝试次数 |
| GAE \(\lambda\) | 0.97 | 优势估计参数 |
| \(\gamma\) | 0.99 | 折扣因子 |
| 阻尼系数 | 0.1 | FVP 阻尼 |
7. 理论意义与历史地位
7.1 TRPO 的贡献
- 单调改进保证: 首次为深度策略优化提供理论保证
- 自然梯度的实用化: 通过 CG 和 FVP 使自然梯度可扩展
- 信赖域范式: 启发了 PPO 和后续大量工作
- 策略空间思维: 从参数空间转向策略空间的优化视角
7.2 发展脉络
自然梯度 (Amari, 1998)
→ Natural Actor-Critic (Peters & Schaal, 2008)
→ TRPO (Schulman et al., 2015)
→ PPO (Schulman et al., 2017)
→ RLHF/PPO (Ouyang et al., 2022)
→ GRPO (Shao et al., 2024)
参考资料
- Amari, S. (1998). Natural Gradient Works Efficiently in Learning. Neural Computation.
- Kakade, S. (2001). A Natural Policy Gradient. NeurIPS.
- Kakade, S. & Langford, J. (2002). Approximately Optimal Approximate Reinforcement Learning. ICML.
- Schulman, J. et al. (2015). Trust Region Policy Optimization. ICML.
- Schulman, J. et al. (2016). High-Dimensional Continuous Control Using Generalized Advantage Estimation.
- Schulman, J. et al. (2017). Proximal Policy Optimization Algorithms.
延伸阅读
- PPO 算法 — TRPO 的实用化继承者
- 策略梯度方法 — 策略梯度基础
- SAC 算法 — 最大熵框架
- TD3 与 DDPG — 确定性策略梯度方法