TRPO and Natural Gradient
Overview
TRPO (Trust Region Policy Optimization) is a milestone algorithm in policy optimization that keeps each policy update safe through a trust region constraint. Its theoretical foundation is the natural gradient, which uses the Fisher information matrix to measure the "true" distance between policies rather than the Euclidean distance between parameters.
This article starts from the mathematical principles of natural gradients, derives TRPO's optimization objective, and discusses its relationship with PPO.
1. Challenges in Policy Optimization
1.1 Why Is the Standard Gradient Not Enough?
Standard gradient update (gradient ascent on the policy objective \(J(\theta)\)):
\[\theta_{k+1} = \theta_k + \alpha\, \nabla_\theta J(\theta_k)\]
Problem: Small changes in parameter space can cause dramatic changes in policy space (and vice versa).
Intuitive Example
Consider a softmax policy \(\pi_\theta(a|s) = \frac{e^{\theta_a}}{\sum_b e^{\theta_b}}\):
- When all \(\theta\) values are close, tiny parameter changes cause dramatic policy shifts
- When one \(\theta\) is much larger, large parameter changes barely affect the policy
Gradient descent is blind to this inconsistency.
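A tiny numerical sketch of this effect (the logit values are arbitrary and chosen only for illustration):

```python
import torch

def softmax_policy(theta):
    return torch.softmax(theta, dim=0)

# Near-uniform logits: a modest parameter change shifts the policy noticeably
theta_flat = torch.tensor([0.1, 0.0, 0.0])
print(softmax_policy(theta_flat))                                    # ~[0.36, 0.32, 0.32]
print(softmax_policy(theta_flat + torch.tensor([0.5, 0.0, 0.0])))    # ~[0.48, 0.26, 0.26]

# Saturated logits: a much larger change barely moves the policy
theta_peaked = torch.tensor([10.0, 0.0, 0.0])
print(softmax_policy(theta_peaked))                                   # ~[0.9999, ...]
print(softmax_policy(theta_peaked + torch.tensor([5.0, 0.0, 0.0])))   # still ~[1.0, ...]
```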
1.2 What Do We Need?
- After each update, the behavioral difference between new and old policies must be controllable
- Step size should be measured in policy space, not parameter space
- Guarantee monotonic non-decrease in policy performance (or at least bounded decrease)
2. Natural Gradient
2.1 Fisher Information Matrix
The Fisher information matrix is defined as the expected outer product of the policy log-likelihood gradient:
\[F(\theta) = \mathbb{E}_{s \sim d^{\pi_\theta},\, a \sim \pi_\theta(\cdot|s)}\left[\nabla_\theta \log \pi_\theta(a|s)\, \nabla_\theta \log \pi_\theta(a|s)^T\right]\]
Properties:
- \(F\) is positive semi-definite
- \(F\) measures the policy distribution change caused by parameter changes
- \(F\) is a second-order approximation to KL divergence: \(D_{KL}(\pi_\theta \| \pi_{\theta + \Delta\theta}) \approx \frac{1}{2} \Delta\theta^T F \Delta\theta\)
2.2 From KL Divergence to Fisher Matrix
Second-order Taylor expansion of the KL divergence at \(\theta\):
\[D_{KL}(\pi_\theta \,\|\, \pi_{\theta+\Delta\theta}) \approx D_{KL}(\pi_\theta \,\|\, \pi_\theta) + \nabla_{\theta'} D_{KL}(\pi_\theta \,\|\, \pi_{\theta'})\big|_{\theta'=\theta}^{T}\, \Delta\theta + \frac{1}{2}\Delta\theta^T H \Delta\theta\]
Since:
- \(D_{KL}(\pi_\theta \| \pi_\theta) = 0\)
- \(\nabla D_{KL} \big|_{\Delta\theta=0} = 0\) (the gradient vanishes at the minimum)
- \(H = F\) (the Hessian of the KL divergence equals the Fisher matrix)
Therefore:
\[D_{KL}(\pi_\theta \,\|\, \pi_{\theta+\Delta\theta}) \approx \frac{1}{2}\Delta\theta^T F \Delta\theta\]
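The third point is the standard information-matrix identity: for a fixed state \(s\), the Hessian of the KL divergence at \(\theta' = \theta\) reduces to the Fisher matrix, because the expected Hessian of the log-likelihood equals minus the expected outer product of its gradient:
\[
\nabla^2_{\theta'} D_{KL}(\pi_\theta \,\|\, \pi_{\theta'})\Big|_{\theta'=\theta}
= -\,\mathbb{E}_{a \sim \pi_\theta}\!\left[\nabla^2_{\theta'} \log \pi_{\theta'}(a|s)\right]\Big|_{\theta'=\theta}
= \mathbb{E}_{a \sim \pi_\theta}\!\left[\nabla_\theta \log \pi_\theta(a|s)\, \nabla_\theta \log \pi_\theta(a|s)^T\right]
\]
Averaging over states gives exactly the \(F\) defined in Section 2.1.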
2.3 Natural Gradient Derivation
Objective: Maximize the objective function subject to a KL divergence constraint:
\[\max_{\Delta\theta}\; J(\theta + \Delta\theta) \quad \text{s.t.} \quad D_{KL}(\pi_\theta \,\|\, \pi_{\theta+\Delta\theta}) \leq \epsilon\]
Using Lagrange multipliers (with a first-order approximation of the objective and the second-order approximation of the KL term from above):
\[\mathcal{L}(\Delta\theta, \lambda) = \nabla_\theta J(\theta)^T \Delta\theta - \lambda\left(\frac{1}{2}\Delta\theta^T F \Delta\theta - \epsilon\right)\]
Taking the derivative with respect to \(\Delta\theta\) and setting it to zero:
\[\nabla_\theta J(\theta) - \lambda F \Delta\theta = 0 \quad\Longrightarrow\quad \Delta\theta \propto F^{-1}\nabla_\theta J(\theta)\]
Natural gradient:
\[\tilde{\nabla}_\theta J(\theta) = F^{-1}\, \nabla_\theta J(\theta)\]
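As a concrete sanity check, here is a minimal sketch (an assumed setup, not from the original text) that builds \(F\) exactly for a three-action softmax policy and applies \(F^{-1}\) to a placeholder objective gradient; a small damping term is added because \(F\) is singular for softmax parameterizations:

```python
import torch

# Toy single-state softmax policy pi(a) = softmax(theta)
theta = torch.tensor([1.0, 0.5, -0.5], requires_grad=True)
pi = torch.softmax(theta, dim=0)

# F = E_{a~pi}[ grad log pi(a) grad log pi(a)^T ], computed exactly over the 3 actions
F = torch.zeros(3, 3)
for a in range(3):
    log_pi_a = torch.log_softmax(theta, dim=0)[a]
    grad_a = torch.autograd.grad(log_pi_a, theta, retain_graph=True)[0]
    F += pi[a].detach() * torch.outer(grad_a, grad_a)

g = torch.tensor([1.0, 0.0, 0.0])                          # placeholder for ∇_θ J(θ)
nat_grad = torch.linalg.solve(F + 1e-3 * torch.eye(3), g)  # damped F⁻¹ g
print("standard gradient:", g)
print("natural gradient: ", nat_grad)
```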
2.4 Natural Gradient vs Standard Gradient
| Dimension | Standard Gradient | Natural Gradient |
|---|---|---|
| Metric | Euclidean distance \(\|\Delta\theta\|_2\) | KL divergence \(D_{KL}(\pi_\theta \| \pi_{\theta'})\) |
| Step size | Fixed in parameter space | Fixed in policy space |
| Parameterization invariance | No | Yes (insensitive to parameterization) |
| Computational cost | \(O(n)\) | \(O(n^2)\) or higher |
| Convergence behavior | May oscillate | Smoother |
Parameterization Invariance
A key advantage of the natural gradient is its insensitivity to the policy's parameterization. Regardless of how you parameterize the policy \(\pi_\theta\), the natural gradient's behavior in policy space is consistent.
3. TRPO: Trust Region Policy Optimization
3.1 Theoretical Foundation: Policy Performance Bound
The key result from Kakade & Langford (2002) expresses the new policy's performance in terms of the old policy's advantage function:
\[J(\pi') = J(\pi) + \mathbb{E}_{s \sim d^{\pi'},\, a \sim \pi'(\cdot|s)}\left[A^{\pi}(s,a)\right]\]
Schulman et al. (2015) build on this to obtain a lower bound with a KL penalty:
\[J(\pi') \geq J(\pi) + \mathbb{E}_{s \sim d^{\pi},\, a \sim \pi'(\cdot|s)}\left[A^{\pi}(s,a)\right] - C\, D_{KL}^{\max}(\pi, \pi')\]
where:
- \(A^\pi(s,a) = Q^\pi(s,a) - V^\pi(s)\) is the advantage function
- \(d^{\pi'}(s)\) is the state visitation distribution of the new policy
- \(C\) is a constant depending on the discount factor
- \(D_{KL}^{\max} = \max_s D_{KL}(\pi(\cdot|s) \| \pi'(\cdot|s))\)
Problem: The exact expression requires \(d^{\pi'}(s)\), but the new policy has not been executed yet, so it cannot be estimated from sampled data.
3.2 TRPO's Approximations
TRPO makes two key approximations:
- Substitute old policy distribution: \(d^{\pi'}(s) \approx d^{\pi}(s)\)
- Replace max KL with average KL: \(D_{KL}^{\max} \to \bar{D}_{KL}\)
Yielding the surrogate objective with an average-KL trust region constraint:
\[\max_\theta\; L_{\theta_{old}}(\theta) = \mathbb{E}_{s \sim d^{\pi_{\theta_{old}}},\, a \sim \pi_{\theta_{old}}}\left[\frac{\pi_\theta(a|s)}{\pi_{\theta_{old}}(a|s)}\, A^{\pi_{\theta_{old}}}(s,a)\right] \quad \text{s.t.} \quad \bar{D}_{KL}(\pi_{\theta_{old}} \,\|\, \pi_\theta) \leq \delta\]
where \(\frac{\pi_\theta(a|s)}{\pi_{\theta_{old}}(a|s)}\) is the importance sampling ratio.
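As a minimal sketch (assuming per-sample log-probabilities and advantage estimates are already available as flat tensors), the surrogate objective is simply the mean of ratio-weighted advantages:

```python
import torch

def surrogate_objective(new_log_probs, old_log_probs, advantages):
    """L(θ) = E[(π_θ / π_θ_old) · A], estimated from samples collected under π_θ_old."""
    ratio = torch.exp(new_log_probs - old_log_probs)  # importance sampling ratio
    return (ratio * advantages).mean()
```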
3.3 Optimization: Conjugate Gradient + Line Search
Directly solving the constrained optimization problem is computationally expensive (requires computing and storing \(F^{-1}\)). TRPO uses a two-step approximation:
Step 1: Conjugate Gradient (CG)
Solve the linear system for the update direction without explicitly computing \(F^{-1}\):
\[F x = g \quad\Longrightarrow\quad x \approx F^{-1} g\]
where \(g = \nabla_\theta L(\theta)\) is the gradient of the surrogate objective.
Key insight: Only the Fisher matrix-vector product \(Fv\) is needed, which can be computed efficiently via automatic differentiation (see Section 6.2):
\[Fv = \nabla_\theta\left[\left(\nabla_\theta \bar{D}_{KL}(\pi_{\theta_{old}} \,\|\, \pi_\theta)\right)^T v\right]\]
Step 2: Line Search
After computing the update direction \(x\), determine the step size through backtracking line search:
\[\theta_{new} = \theta_{old} + \alpha^j\, \beta\, x, \qquad \beta = \sqrt{\frac{2\delta}{x^T F x}}\]
where \(\alpha \in (0,1)\) is the backtracking coefficient and \(j\) is the smallest non-negative integer satisfying:
1. \(L(\theta_{new}) \geq L(\theta_{old})\) (performance improvement)
2. \(\bar{D}_{KL}(\pi_{\theta_{old}} \| \pi_{\theta_{new}}) \leq \delta\) (constraint satisfaction)
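A minimal sketch of this backtracking loop, assuming hypothetical helpers `get_flat_params` / `set_flat_params` (flatten and load policy parameters), `surrogate_loss` (the objective \(L\)), and `mean_kl` (average KL against the old policy):

```python
def backtracking_line_search(policy, full_step, max_backtracks=10, delta=0.01, alpha=0.5):
    """Try step fractions 1, α, α², ... until both acceptance conditions hold."""
    old_params = get_flat_params(policy)
    old_loss = surrogate_loss(policy)
    for j in range(max_backtracks):
        set_flat_params(policy, old_params + (alpha ** j) * full_step)
        improved = surrogate_loss(policy) >= old_loss      # condition 1: L(θ_new) ≥ L(θ_old)
        within_trust_region = mean_kl(policy) <= delta     # condition 2: KL constraint
        if improved and within_trust_region:
            return True
    set_flat_params(policy, old_params)  # no acceptable step found: revert to the old policy
    return False
```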
3.4 TRPO Algorithm Flow
```
repeat:
    1. Collect trajectory data with the current policy π_θ
    2. Estimate the advantage function A(s,a) (typically using GAE)
    3. Compute the policy gradient g = ∇_θ L(θ)
    4. Solve x ≈ F⁻¹ g using conjugate gradient
    5. Compute the maximum step size β = √(2δ / x^T F x)
    6. Backtracking line search: θ_new = θ_old + α^j · β · x,
       increasing j until the improvement condition and the KL constraint are satisfied
    7. Update θ ← θ_new
```
3.5 Generalized Advantage Estimation (GAE)
TRPO is typically combined with GAE (Generalized Advantage Estimation):
\[\hat{A}_t^{GAE(\gamma,\lambda)} = \sum_{l=0}^{\infty} (\gamma\lambda)^l\, \delta_{t+l}\]
where \(\delta_t = r_t + \gamma V(s_{t+1}) - V(s_t)\) is the TD error.
\(\lambda\) controls the bias-variance tradeoff:
- \(\lambda = 0\): Pure TD estimate (low variance, high bias)
- \(\lambda = 1\): Monte Carlo estimate (high variance, low bias)
- Typical values: \(\lambda = 0.95 \sim 0.97\)
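A minimal sketch of the GAE recursion, assuming `rewards` and `dones` have length \(T\) and `values` has length \(T+1\) (the extra entry is the bootstrap value of the final state):

```python
import torch

def compute_gae(rewards, values, dones, gamma=0.99, lam=0.97):
    """GAE: A_t = Σ_l (γλ)^l δ_{t+l}, computed backwards in time."""
    T = len(rewards)
    advantages = torch.zeros(T)
    gae = 0.0
    for t in reversed(range(T)):
        nonterminal = 0.0 if dones[t] else 1.0
        delta = rewards[t] + gamma * values[t + 1] * nonterminal - values[t]  # TD error δ_t
        gae = delta + gamma * lam * nonterminal * gae
        advantages[t] = gae
    return advantages
```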
4. Mathematical Derivation Details
4.1 Importance Sampling and Surrogate Objective
The expected return of the new policy can be estimated using data collected under the old policy. Using importance sampling over actions:
\[\mathbb{E}_{a \sim \pi_\theta(\cdot|s)}\left[A^{\pi_{\theta_{old}}}(s,a)\right] = \mathbb{E}_{a \sim \pi_{\theta_{old}}(\cdot|s)}\left[\frac{\pi_\theta(a|s)}{\pi_{\theta_{old}}(a|s)}\, A^{\pi_{\theta_{old}}}(s,a)\right]\]
4.2 Second-Order Approximation of KL Constraint
First-order expansion of the surrogate objective at \(\theta = \theta_{old}\):
\[L(\theta) \approx L(\theta_{old}) + g^T (\theta - \theta_{old}), \qquad g = \nabla_\theta L(\theta)\big|_{\theta = \theta_{old}}\]
Second-order expansion of the KL constraint:
\[\bar{D}_{KL}(\pi_{\theta_{old}} \,\|\, \pi_\theta) \approx \frac{1}{2}(\theta - \theta_{old})^T F\, (\theta - \theta_{old})\]
The problem becomes (writing \(\Delta\theta = \theta - \theta_{old}\)):
\[\max_{\Delta\theta}\; g^T \Delta\theta \quad \text{s.t.} \quad \frac{1}{2}\Delta\theta^T F \Delta\theta \leq \delta\]
The solution is:
\[\Delta\theta^* = \sqrt{\frac{2\delta}{g^T F^{-1} g}}\; F^{-1} g\]
4.3 Efficient Fisher-Vector Product Computation
No need to explicitly construct the \(F\) matrix. Using automatic differentiation:
- Compute \(v_1 = \nabla_\theta D_{KL}(\pi_{\theta_{old}} \| \pi_\theta)\)
- Compute \(Fv = \nabla_\theta (v_1^T v)\)
This requires only two backward passes, with space complexity \(O(n)\) where \(n\) is the number of parameters.
5. From TRPO to PPO
5.1 TRPO's Limitations
- Implementation complexity: Conjugate gradient + line search is cumbersome
- Computational expense: Multiple forward/backward passes per step
- Incompatible with shared architectures: Difficult to combine with shared-parameter Actor-Critic networks
- Incompatible with dropout/noise: Constrained optimization conflicts with stochastic regularization
5.2 PPO's Clipped Alternative
PPO replaces the hard constraint with a clipped objective:
\[L^{CLIP}(\theta) = \mathbb{E}_t\left[\min\!\left(r_t(\theta)\,\hat{A}_t,\ \text{clip}\!\left(r_t(\theta),\, 1-\epsilon,\, 1+\epsilon\right)\hat{A}_t\right)\right]\]
where \(r_t(\theta) = \frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)}\).
Intuition:
- When \(\hat{A}_t > 0\) (good action): limit \(r_t\) to not exceed \(1+\epsilon\), preventing excessive probability increase
- When \(\hat{A}_t < 0\) (bad action): limit \(r_t\) to not go below \(1-\epsilon\), preventing excessive probability decrease
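For comparison with the TRPO machinery above, a minimal sketch of the clipped loss (per-sample log-probabilities and advantages are assumed as inputs; the sign is flipped so it can be minimized with a standard optimizer):

```python
import torch

def ppo_clip_loss(new_log_probs, old_log_probs, advantages, epsilon=0.2):
    """-E[min(r_t · A_t, clip(r_t, 1-ε, 1+ε) · A_t)]"""
    ratio = torch.exp(new_log_probs - old_log_probs)
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1.0 - epsilon, 1.0 + epsilon) * advantages
    return -torch.min(unclipped, clipped).mean()
```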
5.3 TRPO vs PPO Comparison
| Dimension | TRPO | PPO (Clip) |
|---|---|---|
| Constraint type | Hard constraint (KL ≤ δ) | Soft constraint (clipping) |
| Optimization method | Conjugate gradient + line search | Standard SGD / Adam |
| Implementation difficulty | High | Low |
| Computational cost | High | Low |
| Theoretical guarantee | Yes (monotonic improvement) | Weak (empirically effective) |
| Practical performance | Good | Good (often better) |
| Multi-epoch updates | Not applicable | Applicable |
| Industrial applications | Few | Widespread (RLHF, etc.) |
Practical Recommendation
In the vast majority of scenarios, PPO is recommended over TRPO. PPO is simpler, faster, and usually performs better.
TRPO's value lies in its theoretical insights: understanding why policy constraints matter and the geometric meaning of natural gradients.
5.4 PPO-KL Variant
PPO also has a variant using an adaptive KL penalty:
\[L^{KLPEN}(\theta) = \mathbb{E}_t\left[ r_t(\theta)\, \hat{A}_t - \beta\, D_{KL}\!\left(\pi_{\theta_{old}}(\cdot|s_t)\,\|\,\pi_\theta(\cdot|s_t)\right)\right]\]
where \(\beta\) is adaptively adjusted based on the measured KL divergence:
- If \(D_{KL} > 1.5 \times d_{targ}\): Increase \(\beta\)
- If \(D_{KL} < d_{targ} / 1.5\): Decrease \(\beta\)
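A minimal sketch of this adaptation rule (the doubling and halving factors follow the PPO paper; the function name is illustrative):

```python
def update_kl_penalty(beta, measured_kl, kl_target=0.01):
    """Adapt the KL penalty coefficient β based on the KL divergence of the last update."""
    if measured_kl > kl_target * 1.5:
        beta *= 2.0    # policy moved too far: strengthen the penalty
    elif measured_kl < kl_target / 1.5:
        beta /= 2.0    # policy barely moved: relax the penalty
    return beta
```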
6. Implementation Details
6.1 Conjugate Gradient Implementation
```python
import torch

def conjugate_gradient(Fvp_func, g, num_steps=10, residual_tol=1e-10):
    """Solve F x = g with the conjugate gradient method; Fvp_func(v) returns F @ v."""
    x = torch.zeros_like(g)
    r = g.clone()            # residual r = g - F x (x starts at zero)
    p = g.clone()            # search direction
    rdotr = torch.dot(r, r)
    for _ in range(num_steps):
        Fp = Fvp_func(p)
        alpha = rdotr / (torch.dot(p, Fp) + 1e-8)
        x += alpha * p
        r -= alpha * Fp
        new_rdotr = torch.dot(r, r)
        if new_rdotr < residual_tol:
            break
        beta = new_rdotr / rdotr
        p = r + beta * p
        rdotr = new_rdotr
    return x
```
6.2 Fisher-Vector Product
```python
def fisher_vector_product(policy, old_policy, states, v, damping=0.1):
    """Compute F @ v, where F is the Fisher information matrix of the current policy."""
    kl = compute_kl(policy, old_policy, states)   # mean KL between old and current policy
    kl_grad = torch.autograd.grad(kl, policy.parameters(), create_graph=True)
    kl_grad_flat = torch.cat([g.view(-1) for g in kl_grad])
    kl_v = torch.dot(kl_grad_flat, v)             # (∇_θ KL)ᵀ v
    kl_v_grad = torch.autograd.grad(kl_v, policy.parameters())
    fvp = torch.cat([g.view(-1) for g in kl_v_grad])
    return fvp + damping * v                      # add damping for numerical stability
```
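A sketch of how the two helpers above are typically glued together to form the TRPO update direction; `flat_grad_of_surrogate` and the data tensors are hypothetical placeholders:

```python
# Hypothetical glue code combining the helpers above (variable names assumed).
g = flat_grad_of_surrogate(policy, states, actions, advantages)   # flattened ∇_θ L(θ)

Fvp = lambda v: fisher_vector_product(policy, old_policy, states, v)
x = conjugate_gradient(Fvp, g)                                    # x ≈ F⁻¹ g

delta = 0.01                                                      # KL trust-region radius
beta = torch.sqrt(2 * delta / (torch.dot(x, Fvp(x)) + 1e-8))      # maximum step size
full_step = beta * x                                              # then refined by line search
```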
6.3 Hyperparameter Recommendations
| Hyperparameter | Typical Value | Description |
|---|---|---|
| \(\delta\) | 0.01 | KL divergence constraint upper bound |
| CG steps | 10 | Conjugate gradient iterations |
| Line search backtrack coeff | 0.5 | Halve step size each time |
| Max line search steps | 10 | Maximum attempts |
| GAE \(\lambda\) | 0.97 | Advantage estimation parameter |
| \(\gamma\) | 0.99 | Discount factor |
| Damping coefficient | 0.1 | FVP damping |
7. Theoretical Significance and Historical Position
7.1 TRPO's Contributions
- Monotonic improvement guarantee: First theoretical guarantee for deep policy optimization
- Practical natural gradient: Made natural gradient scalable through CG and FVP
- Trust region paradigm: Inspired PPO and a vast body of subsequent work
- Policy-space thinking: Shifted optimization perspective from parameter space to policy space
7.2 Development Lineage
Natural Gradient (Amari, 1998)
→ Natural Actor-Critic (Peters & Schaal, 2008)
→ TRPO (Schulman et al., 2015)
→ PPO (Schulman et al., 2017)
→ RLHF/PPO (Ouyang et al., 2022)
→ GRPO (Shao et al., 2024)
References
- Amari, S. (1998). Natural Gradient Works Efficiently in Learning. Neural Computation.
- Kakade, S. (2001). A Natural Policy Gradient. NeurIPS.
- Kakade, S. & Langford, J. (2002). Approximately Optimal Approximate Reinforcement Learning. ICML.
- Schulman, J. et al. (2015). Trust Region Policy Optimization. ICML.
- Schulman, J. et al. (2016). High-Dimensional Continuous Control Using Generalized Advantage Estimation.
- Schulman, J. et al. (2017). Proximal Policy Optimization Algorithms.
Further Reading
- PPO Algorithm — TRPO's practical successor
- Policy Gradient Methods — Policy gradient foundations
- SAC Algorithm — Maximum entropy framework
- TD3 and DDPG — Deterministic policy gradient methods