
TRPO and Natural Gradient

Overview

TRPO (Trust Region Policy Optimization) is a milestone algorithm in policy optimization that ensures each policy update is safe through trust region constraints. Its theoretical foundation is the Natural Gradient, which uses the Fisher information matrix to measure "true distance" in parameter space.

This article starts from the mathematical principles of natural gradients, derives TRPO's optimization objective, and discusses its relationship with PPO.


1. Challenges in Policy Optimization

1.1 Why Is the Standard Gradient Not Enough?

Standard gradient descent update:

\[\theta_{t+1} = \theta_t + \alpha \nabla_\theta J(\theta_t)\]

Problem: Small changes in parameter space can cause dramatic changes in policy space (and vice versa).

Intuitive Example

Consider a softmax policy \(\pi_\theta(a|s) = \frac{e^{\theta_a}}{\sum_b e^{\theta_b}}\):

  • When all \(\theta\) values are close, tiny parameter changes cause dramatic policy shifts
  • When one \(\theta\) is much larger, large parameter changes barely affect the policy

Gradient descent is blind to this inconsistency.
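The inconsistency can be seen numerically. In this minimal sketch (the preference vectors and the step are arbitrary, chosen only for illustration), the same parameter step produces very different policy shifts in the two regimes:

```python
import numpy as np

def softmax(theta):
    """Numerically stable softmax over action preferences."""
    z = np.exp(theta - theta.max())
    return z / z.sum()

shift = np.array([0.5, 0.0, 0.0])  # identical parameter step in both regimes

# Regime 1: all preferences close -- the step moves the policy a lot.
theta_flat = np.array([0.0, 0.0, 0.0])
drift_flat = np.abs(softmax(theta_flat + shift) - softmax(theta_flat)).sum()

# Regime 2: one preference dominates -- the same step barely changes the policy.
theta_peaked = np.array([10.0, 0.0, 0.0])
drift_peaked = np.abs(softmax(theta_peaked + shift) - softmax(theta_peaked)).sum()

print(drift_flat > 10 * drift_peaked)  # True: same parameter distance, very different policy distance
```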

1.2 What Do We Need?

  • After each update, the behavioral difference between new and old policies must be controllable
  • Step size should be measured in policy space, not parameter space
  • Guarantee monotonic non-decrease in policy performance (or at least bounded decrease)

2. Natural Gradient

2.1 Fisher Information Matrix

The Fisher information matrix is defined as the expected outer product of the policy log-likelihood gradient:

\[F = \mathbb{E}_{s \sim d^\pi, a \sim \pi_\theta} \left[ \nabla_\theta \log \pi_\theta(a|s) \nabla_\theta \log \pi_\theta(a|s)^T \right]\]

Properties:

  • \(F\) is positive semi-definite
  • \(F\) measures the policy distribution change caused by parameter changes
  • \(F\) is a second-order approximation to KL divergence: \(D_{KL}(\pi_\theta \| \pi_{\theta + \Delta\theta}) \approx \frac{1}{2} \Delta\theta^T F \Delta\theta\)

2.2 From KL Divergence to Fisher Matrix

Second-order Taylor expansion of KL divergence at \(\theta\):

\[D_{KL}(\pi_\theta \| \pi_{\theta + \Delta\theta}) \approx D_{KL}(\pi_\theta \| \pi_\theta) + \nabla_{\Delta\theta} D_{KL} \big|_{\Delta\theta=0}^T \Delta\theta + \frac{1}{2} \Delta\theta^T H \Delta\theta\]

Since:

  • \(D_{KL}(\pi_\theta \| \pi_\theta) = 0\)
  • \(\nabla_{\Delta\theta} D_{KL} \big|_{\Delta\theta=0} = 0\) (the KL divergence attains its minimum at \(\Delta\theta = 0\), so the gradient vanishes there)
  • \(H = F\) (at the minimum, the Hessian of the KL divergence equals the Fisher matrix)

Therefore:

\[D_{KL}(\pi_\theta \| \pi_{\theta + \Delta\theta}) \approx \frac{1}{2} \Delta\theta^T F \Delta\theta\]
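For a softmax-parameterized categorical policy the Fisher matrix has the closed form \(F = \text{diag}(p) - pp^T\), which makes the approximation easy to check numerically (the preference vector and step below are arbitrary):

```python
import numpy as np

def softmax(theta):
    z = np.exp(theta - theta.max())
    return z / z.sum()

theta = np.array([0.2, -0.1, 0.4])
p = softmax(theta)

# Fisher matrix of a softmax-parameterized categorical: F = diag(p) - p p^T
F = np.diag(p) - np.outer(p, p)

delta = 1e-2 * np.array([1.0, -2.0, 0.5])  # a small parameter perturbation
q = softmax(theta + delta)

kl = np.sum(p * np.log(p / q))    # exact KL(pi_theta || pi_{theta+delta})
quad = 0.5 * delta @ F @ delta    # second-order approximation

print(abs(kl - quad) / kl)  # relative error shrinks with the step size
```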

2.3 Natural Gradient Derivation

Objective: Maximize the objective function subject to a KL divergence constraint.

\[\max_{\Delta\theta} \nabla_\theta J(\theta)^T \Delta\theta \quad \text{s.t.} \quad \frac{1}{2} \Delta\theta^T F \Delta\theta \leq \delta\]

Using Lagrange multipliers:

\[\mathcal{L}(\Delta\theta, \lambda) = \nabla_\theta J^T \Delta\theta - \lambda \left( \frac{1}{2} \Delta\theta^T F \Delta\theta - \delta \right)\]

Taking the derivative with respect to \(\Delta\theta\) and setting it to zero:

\[\nabla_\theta J - \lambda F \Delta\theta = 0\]
\[\Delta\theta = \frac{1}{\lambda} F^{-1} \nabla_\theta J\]

Natural gradient (absorbing the scalar \(1/\lambda\) into the step size):

\[\boxed{\Delta\theta_{\text{natural}} = F^{-1} \nabla_\theta J(\theta)}\]
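In practice \(F^{-1}\) is never formed explicitly; one solves the linear system \(F x = g\) instead. A small sketch for the softmax policy above (the objective gradient is illustrative; a damping term is added because the softmax Fisher is singular along the all-ones direction):

```python
import numpy as np

def softmax(theta):
    z = np.exp(theta - theta.max())
    return z / z.sum()

theta = np.array([0.2, -0.1, 0.4])
p = softmax(theta)
F = np.diag(p) - np.outer(p, p)   # Fisher matrix of the softmax policy

g = np.array([1.0, 0.0, -1.0])    # some objective gradient (illustrative)

# Natural gradient: solve F x = g instead of computing F^{-1} explicitly.
# Small damping keeps the system well-conditioned.
damping = 1e-6
x = np.linalg.solve(F + damping * np.eye(3), g)

print(np.allclose(F @ x, g, atol=1e-4))  # x is (approximately) F^{-1} g
```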

2.4 Natural Gradient vs Standard Gradient

| Dimension | Standard Gradient | Natural Gradient |
| --- | --- | --- |
| Metric | Euclidean distance \(\lVert\Delta\theta\rVert_2\) | KL divergence \(D_{KL}(\pi_\theta \Vert \pi_{\theta'})\) |
| Step size | Fixed in parameter space | Fixed in policy space |
| Parameterization invariance | No | Yes (insensitive to parameterization) |
| Computational cost | \(O(n)\) | \(O(n^2)\) or higher |
| Convergence behavior | May oscillate | Smoother |

Parameterization Invariance

A key advantage of the natural gradient is its insensitivity to the policy's parameterization. Regardless of how you parameterize the policy \(\pi_\theta\), the natural gradient's behavior in policy space is consistent.


3. TRPO: Trust Region Policy Optimization

3.1 Theoretical Foundation: Policy Performance Bound

The key result builds on Kakade & Langford (2002) and was extended to general stochastic policies by Schulman et al. (2015). The performance difference lemma gives the exact identity:

\[J(\pi') = J(\pi) + \sum_s d^{\pi'}(s) \sum_a \pi'(a|s) A^\pi(s,a)\]

Replacing \(d^{\pi'}\) with the old policy's state distribution \(d^{\pi}\) yields a surrogate whose error is bounded:

\[J(\pi') \geq J(\pi) + \sum_s d^{\pi}(s) \sum_a \pi'(a|s) A^\pi(s,a) - C \cdot D_{KL}^{\max}(\pi \| \pi')\]

where:

  • \(A^\pi(s,a) = Q^\pi(s,a) - V^\pi(s)\) is the advantage function
  • \(d^{\pi}(s)\) and \(d^{\pi'}(s)\) are the state visitation distributions of the old and new policies
  • \(C\) is a constant depending on the discount factor
  • \(D_{KL}^{\max} = \max_s D_{KL}(\pi(\cdot|s) \| \pi'(\cdot|s))\)

Problem: The exact identity requires \(d^{\pi'}(s)\), but the new policy hasn't been executed yet. The bound shows that substituting \(d^{\pi}(s)\) is safe as long as the two policies stay close in KL divergence.

3.2 TRPO's Approximations

TRPO makes two key approximations:

  1. Substitute old policy distribution: \(d^{\pi'}(s) \approx d^{\pi}(s)\)
  2. Replace max KL with average KL: \(D_{KL}^{\max} \to \bar{D}_{KL}\)

Yielding the surrogate objective:

\[\max_\theta L(\theta) = \mathbb{E}_{s \sim d^{\pi_{\theta_{old}}}, a \sim \pi_{\theta_{old}}} \left[ \frac{\pi_\theta(a|s)}{\pi_{\theta_{old}}(a|s)} A^{\pi_{\theta_{old}}}(s,a) \right]\]
\[\text{s.t.} \quad \bar{D}_{KL}(\pi_{\theta_{old}} \| \pi_\theta) \leq \delta\]

where \(\frac{\pi_\theta(a|s)}{\pi_{\theta_{old}}(a|s)}\) is the importance sampling ratio.
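A toy sketch of the surrogate in PyTorch (the log-probabilities and advantages are made up): at \(\theta = \theta_{old}\) every ratio equals 1, so the surrogate's value is \(\mathbb{E}[A]\) and its gradient reduces to the vanilla policy gradient.

```python
import torch

# Illustrative batch: log-probabilities under the old policy and advantages.
old_log_probs = torch.tensor([-1.2, -0.7, -2.1])
advantages = torch.tensor([0.5, -0.3, 1.0])

# Evaluate the surrogate at theta = theta_old, keeping gradients.
theta_log_probs = old_log_probs.clone().requires_grad_(True)
ratio = torch.exp(theta_log_probs - old_log_probs.detach())
surrogate = (ratio * advantages).mean()
surrogate.backward()

# At theta_old the ratios are all 1, so L(theta_old) = E[A], and the
# gradient is the vanilla policy-gradient weighting A * grad log pi.
print(torch.allclose(surrogate, advantages.mean()))
print(torch.allclose(theta_log_probs.grad, advantages / 3))
```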

3.3 Solving the Constrained Problem

Directly solving the constrained optimization problem is computationally expensive (it would require computing and storing \(F^{-1}\)). TRPO uses a two-step procedure:

Step 1: Conjugate Gradient (CG)

Solve \(F \cdot x = g\) without explicitly computing \(F^{-1}\):

\[x \approx F^{-1} g\]

where \(g = \nabla_\theta L(\theta)\) is the gradient of the surrogate objective.

Key insight: Only the Fisher matrix-vector product \(Fv\) is needed, which can be computed efficiently via automatic differentiation:

\[Fv = \nabla_\theta \left[ (\nabla_\theta D_{KL})^T v \right]\]

Step 2: Backtracking Line Search

After computing the update direction, determine the step size through backtracking line search:

\[\theta_{new} = \theta_{old} + \alpha^j \sqrt{\frac{2\delta}{x^T F x}} \cdot x\]

where \(j\) is the smallest non-negative integer satisfying:

  1. \(L(\theta_{new}) \geq L(\theta_{old})\) (surrogate improvement)
  2. \(\bar{D}_{KL}(\pi_{\theta_{old}} \| \pi_{\theta_{new}}) \leq \delta\) (constraint satisfaction)
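The line search itself is a short loop. Here is a sketch on a toy problem (the objective, the "KL" proxy, and all numbers are invented to exercise the two acceptance conditions, not taken from TRPO itself):

```python
import numpy as np

def backtracking_line_search(theta_old, full_step, surrogate_fn, kl_fn,
                             delta=0.01, backtrack_coeff=0.5, max_steps=10):
    """Shrink the proposed step until both acceptance conditions hold."""
    L_old = surrogate_fn(theta_old)
    for j in range(max_steps):
        theta_new = theta_old + (backtrack_coeff ** j) * full_step
        if surrogate_fn(theta_new) >= L_old and kl_fn(theta_new) <= delta:
            return theta_new
    return theta_old  # no acceptable step: keep the old policy

# Toy problem: maximize L(theta) = -theta^2 with a quadratic "KL" proxy.
theta = np.array([1.0])
full_step = np.array([-1.5])  # deliberately too large; must be shrunk
result = backtracking_line_search(
    theta, full_step,
    surrogate_fn=lambda t: -float(t[0] ** 2),
    kl_fn=lambda t: float((t[0] - 1.0) ** 2) / 2,
)
print(result[0])  # a shrunk step that improves L while keeping the proxy <= delta
```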

3.4 TRPO Algorithm Flow

repeat:
    1. Collect trajectory data using current policy π_θ
    2. Estimate advantage function A(s,a) (typically using GAE)
    3. Compute policy gradient g = ∇_θ L(θ)
    4. Solve x ≈ F⁻¹g using conjugate gradient
    5. Compute maximum step size β = √(2δ / x^T F x)
    6. Backtracking line search: θ_new = θ_old + α^j · β · x
       until improvement condition and KL constraint are satisfied
    7. Update θ ← θ_new

3.5 Generalized Advantage Estimation (GAE)

TRPO is typically combined with GAE (Generalized Advantage Estimation):

\[\hat{A}_t^{GAE(\gamma,\lambda)} = \sum_{l=0}^{\infty} (\gamma\lambda)^l \delta_{t+l}\]

where \(\delta_t = r_t + \gamma V(s_{t+1}) - V(s_t)\) is the TD error.

\(\lambda\) controls the bias-variance tradeoff:

  • \(\lambda = 0\): Pure TD estimate (low variance, high bias)
  • \(\lambda = 1\): Monte Carlo estimate (high variance, low bias)
  • Typical values: \(\lambda = 0.95 \sim 0.97\)
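The infinite sum can be computed in a single backward pass using the recursion \(\hat{A}_t = \delta_t + \gamma\lambda \hat{A}_{t+1}\). A minimal sketch (the reward and value arrays are made up):

```python
import numpy as np

def compute_gae(rewards, values, gamma=0.99, lam=0.97):
    """Backward recursion: A_t = delta_t + gamma * lam * A_{t+1}.

    `values` has one extra entry for the bootstrap value V(s_T).
    """
    T = len(rewards)
    advantages = np.zeros(T)
    gae = 0.0
    for t in reversed(range(T)):
        delta = rewards[t] + gamma * values[t + 1] - values[t]  # TD error
        gae = delta + gamma * lam * gae
        advantages[t] = gae
    return advantages

rewards = np.array([1.0, 1.0, 1.0])
values = np.array([0.5, 0.5, 0.5, 0.0])  # last entry bootstraps V(s_T)
adv = compute_gae(rewards, values)

# lam = 0 reduces each A_t to the one-step TD error delta_t.
td_only = compute_gae(rewards, values, lam=0.0)
print(adv, td_only)
```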


4. Mathematical Derivation Details

4.1 Importance Sampling and Surrogate Objective

The expected return of the new policy can be estimated using old policy data:

\[J(\pi_\theta) = J(\pi_{\theta_{old}}) + \mathbb{E}_{s \sim d^{\pi_\theta}} \left[ \sum_a \pi_\theta(a|s) A^{\pi_{\theta_{old}}}(s,a) \right]\]

Approximating the state distribution with the old policy's (\(d^{\pi_\theta} \approx d^{\pi_{\theta_{old}}}\)) and rewriting the sum over actions via importance sampling:

\[\approx J(\pi_{\theta_{old}}) + \mathbb{E}_{s \sim d^{\pi_{\theta_{old}}}, a \sim \pi_{\theta_{old}}} \left[ \frac{\pi_\theta(a|s)}{\pi_{\theta_{old}}(a|s)} A^{\pi_{\theta_{old}}}(s,a) \right]\]

4.2 Second-Order Approximation of KL Constraint

First-order expansion of the surrogate objective at \(\theta = \theta_{old}\):

\[L(\theta) \approx g^T (\theta - \theta_{old})\]

Second-order expansion of the KL constraint:

\[\bar{D}_{KL}(\pi_{\theta_{old}} \| \pi_\theta) \approx \frac{1}{2} (\theta - \theta_{old})^T F (\theta - \theta_{old})\]

The problem becomes:

\[\max_{\Delta\theta} g^T \Delta\theta \quad \text{s.t.} \quad \frac{1}{2} \Delta\theta^T F \Delta\theta \leq \delta\]

The solution is:

\[\Delta\theta^* = \sqrt{\frac{2\delta}{g^T F^{-1} g}} F^{-1} g\]
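A quick numeric sanity check of this closed form (with a synthetic positive-definite matrix standing in for \(F\)): the optimal step lands exactly on the trust-region boundary \(\frac{1}{2}\Delta\theta^T F \Delta\theta = \delta\).

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.normal(size=(4, 4))
F = A @ A.T + 0.1 * np.eye(4)    # synthetic positive-definite "Fisher"
g = rng.normal(size=4)
delta = 0.01

Finv_g = np.linalg.solve(F, g)   # F^{-1} g without forming the inverse
step = np.sqrt(2 * delta / (g @ Finv_g)) * Finv_g

# The optimal step saturates the KL constraint.
print(np.isclose(0.5 * step @ F @ step, delta))
```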

4.3 Efficient Fisher-Vector Product Computation

No need to explicitly construct the \(F\) matrix. Using automatic differentiation:

  1. Compute \(v_1 = \nabla_\theta D_{KL}(\pi_{\theta_{old}} \| \pi_\theta)\)
  2. Compute \(Fv = \nabla_\theta (v_1^T v)\)

This requires only two backward passes, with space complexity \(O(n)\) where \(n\) is the number of parameters.


5. From TRPO to PPO

5.1 TRPO's Limitations

  1. Implementation complexity: Conjugate gradient + line search is cumbersome
  2. Computational expense: Multiple forward/backward passes per step
  3. Incompatible with shared architectures: Difficult to combine with shared-parameter Actor-Critic networks
  4. Incompatible with dropout/noise: Constrained optimization conflicts with stochastic regularization

5.2 PPO's Clipped Alternative

PPO replaces the hard constraint with a clipped objective:

\[L^{CLIP}(\theta) = \mathbb{E} \left[ \min\left( r_t(\theta) \hat{A}_t, \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon) \hat{A}_t \right) \right]\]

where \(r_t(\theta) = \frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)}\).

Intuition:

  • When \(\hat{A}_t > 0\) (good action): limit \(r_t\) to at most \(1+\epsilon\), preventing an excessive probability increase
  • When \(\hat{A}_t < 0\) (bad action): limit \(r_t\) to at least \(1-\epsilon\), preventing an excessive probability decrease
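The clipped objective is a few lines of PyTorch. A sketch with made-up numbers, one sample per branch of the intuition above:

```python
import torch

def ppo_clip_loss(log_probs, old_log_probs, advantages, eps=0.2):
    """Negative clipped surrogate (negated so an optimizer can minimize it)."""
    ratio = torch.exp(log_probs - old_log_probs)
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1 - eps, 1 + eps) * advantages
    return -torch.min(unclipped, clipped).mean()

old_lp = torch.zeros(2)
new_lp = torch.log(torch.tensor([1.5, 0.5]))  # ratios of 1.5 and 0.5
adv = torch.tensor([1.0, -1.0])

loss = ppo_clip_loss(new_lp, old_lp, adv)
# Good action: ratio 1.5 is clipped to 1.2. Bad action: ratio 0.5 is below
# 1 - eps, and min() picks the constant clipped branch, removing any
# incentive to push the probability down further.
print(loss)  # -(1.2 - 0.8) / 2 = -0.2
```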

5.3 TRPO vs PPO Comparison

| Dimension | TRPO | PPO (Clip) |
| --- | --- | --- |
| Constraint type | Hard constraint (KL ≤ δ) | Soft constraint (clipping) |
| Optimization method | Conjugate gradient + line search | Standard SGD / Adam |
| Implementation difficulty | High | Low |
| Computational cost | High | Low |
| Theoretical guarantee | Yes (monotonic improvement) | Weak (empirically effective) |
| Practical performance | Good | Good (often better) |
| Multi-epoch updates | Not applicable | Applicable |
| Industrial applications | Few | Widespread (RLHF, etc.) |

Practical Recommendation

In the vast majority of scenarios, PPO is recommended over TRPO. PPO is simpler, faster, and usually performs better.

TRPO's value lies in its theoretical insights: understanding why policy constraints matter and the geometric meaning of natural gradients.

5.4 PPO-KL Variant

PPO also has a variant using adaptive KL penalty:

\[L^{KL}(\theta) = \mathbb{E} \left[ r_t(\theta) \hat{A}_t - \beta D_{KL}(\pi_{\theta_{old}} \| \pi_\theta) \right]\]

where \(\beta\) is adaptively adjusted based on the measured KL divergence:

  • If \(D_{KL} > 1.5 \times d_{targ}\): Increase \(\beta\)
  • If \(D_{KL} < d_{targ} / 1.5\): Decrease \(\beta\)
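The adaptation rule fits in a few lines; the multiplicative factor of 2 follows the PPO paper, and the numbers in the usage example are illustrative:

```python
def adapt_beta(beta, kl, kl_target=0.01):
    """Adaptive KL penalty coefficient; the factor 2 follows the PPO paper."""
    if kl > 1.5 * kl_target:
        beta *= 2.0   # policy moved too far: penalize KL more strongly
    elif kl < kl_target / 1.5:
        beta /= 2.0   # policy barely moved: relax the penalty
    return beta

print(adapt_beta(1.0, kl=0.05))   # 2.0
print(adapt_beta(1.0, kl=0.001))  # 0.5
print(adapt_beta(1.0, kl=0.01))   # 1.0 (within the tolerance band)
```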


6. Implementation Details

6.1 Conjugate Gradient Implementation

import torch

def conjugate_gradient(Fvp_func, g, num_steps=10, residual_tol=1e-10):
    """
    Solve Fx = g, where Fvp_func computes F @ v.
    Only matrix-vector products are needed; F is never materialized.
    """
    x = torch.zeros_like(g)
    r = g.clone()          # residual g - F @ x (x starts at zero)
    p = g.clone()          # search direction
    rdotr = torch.dot(r, r)

    for i in range(num_steps):
        Fp = Fvp_func(p)
        alpha = rdotr / (torch.dot(p, Fp) + 1e-8)
        x += alpha * p
        r -= alpha * Fp
        new_rdotr = torch.dot(r, r)
        if new_rdotr < residual_tol:
            break
        beta = new_rdotr / rdotr
        p = r + beta * p
        rdotr = new_rdotr

    return x

6.2 Fisher-Vector Product

def fisher_vector_product(policy, old_policy, states, v, damping=0.1):
    """
    Compute F @ v, where F is the Fisher information matrix,
    via a double backward pass through the KL divergence.
    """
    # compute_kl: average KL(old_policy || policy) over the batch of states
    kl = compute_kl(policy, old_policy, states)
    kl_grad = torch.autograd.grad(kl, policy.parameters(),
                                  create_graph=True)
    kl_grad_flat = torch.cat([g.view(-1) for g in kl_grad])

    kl_v = torch.dot(kl_grad_flat, v.detach())
    kl_v_grad = torch.autograd.grad(kl_v, policy.parameters())
    fvp = torch.cat([g.view(-1) for g in kl_v_grad])

    return fvp + damping * v  # add damping for numerical stability

6.3 Hyperparameter Recommendations

| Hyperparameter | Typical Value | Description |
| --- | --- | --- |
| \(\delta\) | 0.01 | Upper bound of the KL divergence constraint |
| CG steps | 10 | Conjugate gradient iterations |
| Line search backtrack coeff | 0.5 | Halve the step size each attempt |
| Max line search steps | 10 | Maximum backtracking attempts |
| GAE \(\lambda\) | 0.97 | Advantage estimation parameter |
| \(\gamma\) | 0.99 | Discount factor |
| Damping coefficient | 0.1 | FVP damping |

7. Theoretical Significance and Historical Position

7.1 TRPO's Contributions

  1. Monotonic improvement guarantee: First theoretical guarantee for deep policy optimization
  2. Practical natural gradient: Made natural gradient scalable through CG and FVP
  3. Trust region paradigm: Inspired PPO and a vast body of subsequent work
  4. Policy-space thinking: Shifted optimization perspective from parameter space to policy space

7.2 Development Lineage

Natural Gradient (Amari, 1998)
  → Natural Actor-Critic (Peters & Schaal, 2008)
    → TRPO (Schulman et al., 2015)
      → PPO (Schulman et al., 2017)
        → RLHF/PPO (Ouyang et al., 2022)
          → GRPO (Shao et al., 2024)

References

  • Amari, S. (1998). Natural Gradient Works Efficiently in Learning. Neural Computation.
  • Kakade, S. (2001). A Natural Policy Gradient. NeurIPS.
  • Kakade, S. & Langford, J. (2002). Approximately Optimal Approximate Reinforcement Learning. ICML.
  • Schulman, J. et al. (2015). Trust Region Policy Optimization. ICML.
  • Schulman, J. et al. (2016). High-Dimensional Continuous Control Using Generalized Advantage Estimation.
  • Schulman, J. et al. (2017). Proximal Policy Optimization Algorithms.
