Representation Learning and RL

Motivation

In RL from pixels or high-dimensional sensor inputs, the quality of state representations directly determines learning efficiency and final performance. Good representations should:

  • Capture information relevant to decision-making
  • Ignore irrelevant visual details (e.g., background textures)
  • Generalize to visually novel but behaviorally similar situations
  • Support efficient policy and value function learning

Data Augmentation Methods

DrQ (Data-regularized Q)

Kostrikov et al. (2021) found that simple image data augmentation alone can significantly improve pixel-based RL performance.

Core method: Apply random shifts to observation images

\[\tilde{s} = \text{RandomShift}(s, \text{pad}=4)\]

Application in Q-learning:

\[Q_{\text{target}} = r + \gamma Q_{\bar{\theta}}(\text{aug}(s'), \pi(\text{aug}(s')))\]

Augmented observations are used both when computing the target and when evaluating the current Q-function; the full method additionally averages each over several independent augmentations to reduce variance.
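
A minimal sketch of the random-shift operation, assuming PyTorch image batches of shape \((B, C, H, W)\); the per-image loop is for clarity (reference implementations typically vectorize this, e.g. with grid_sample), and the function name is illustrative:

```python
import torch
import torch.nn.functional as F

def random_shift(imgs: torch.Tensor, pad: int = 4) -> torch.Tensor:
    """Pad each image by `pad` pixels (edge replication), then crop back
    to the original size at a random offset: a shift of up to +/- pad pixels.
    imgs: (B, C, H, W) float tensor."""
    b, c, h, w = imgs.shape
    padded = F.pad(imgs, (pad, pad, pad, pad), mode="replicate")
    out = torch.empty_like(imgs)
    for i in range(b):  # independent random offset per image in the batch
        top = torch.randint(0, 2 * pad + 1, (1,)).item()
        left = torch.randint(0, 2 * pad + 1, (1,)).item()
        out[i] = padded[i, :, top:top + h, left:left + w]
    return out
```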

DrQ-v2: an improved version by Yarats et al. (2022) combining:

  • Random shift augmentation
  • \(n\)-step returns
  • Starting from DDPG rather than SAC (simpler and more efficient)

RAD (Reinforcement Learning with Augmented Data)

Laskin et al. (2020) systematically studied the effects of various data augmentation methods in RL:

| Augmentation | Description | Effectiveness |
| --- | --- | --- |
| Random crop | Random cropping | Most effective |
| Random shift | Random translation | Very effective |
| Color jitter | Color perturbation | Partially effective |
| Random convolution | Random convolutional filters | Effective |
| Grayscale | Convert to grayscale | Task-dependent |
| Cutout | Random masking | Partially effective |
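
To make the table concrete, here are sketches of two of these augmentations (random crop and cutout) for \((B, C, H, W)\) PyTorch batches; the function names are illustrative, not RAD's reference code:

```python
import torch

def random_crop(imgs: torch.Tensor, out_size: int) -> torch.Tensor:
    """Crop a random (out_size x out_size) patch from each image.
    imgs: (B, C, H, W) with H, W >= out_size."""
    b, c, h, w = imgs.shape
    out = torch.empty(b, c, out_size, out_size, dtype=imgs.dtype, device=imgs.device)
    for i in range(b):
        top = torch.randint(0, h - out_size + 1, (1,)).item()
        left = torch.randint(0, w - out_size + 1, (1,)).item()
        out[i] = imgs[i, :, top:top + out_size, left:left + out_size]
    return out

def cutout(imgs: torch.Tensor, size: int = 16) -> torch.Tensor:
    """Zero out a random (size x size) square in each image (random masking)."""
    b, c, h, w = imgs.shape
    out = imgs.clone()
    for i in range(b):
        top = torch.randint(0, h - size + 1, (1,)).item()
        left = torch.randint(0, w - size + 1, (1,)).item()
        out[i, :, top:top + size, left:left + size] = 0.0
    return out
```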

Key findings:

  • Simple crop/shift augmentations can match or surpass complex representation learning methods
  • Different environments benefit from different augmentation strategies
  • Data augmentation is often the simplest and most effective first step for improving pixel-based RL performance

Contrastive Learning Methods

CURL (Contrastive Unsupervised Representations for RL)

Laskin et al. (2020) introduced contrastive learning for RL representation learning:

Positive pairs: Two different augmentations of the same observation

\[s_q = \text{aug}_1(s), \quad s_k = \text{aug}_2(s)\]

Contrastive loss (InfoNCE):

\[\mathcal{L}_{\text{CURL}} = -\log \frac{\exp(q^T W k_+)}{\exp(q^T W k_+) + \sum_{j} \exp(q^T W k_j^-)}\]

where \(q = f_q(s_q)\), \(k_+ = f_k(s_k)\), and \(W\) is a learnable bilinear matrix.
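
A minimal PyTorch sketch of this loss, where each batch element's matching key is the positive and the remaining \(B-1\) keys act as negatives (as in CURL); the function name and explicit max-subtraction are illustrative:

```python
import torch
import torch.nn.functional as F

def curl_infonce_loss(q: torch.Tensor, k: torch.Tensor, W: torch.Tensor) -> torch.Tensor:
    """InfoNCE with a learnable bilinear similarity q^T W k.
    q: (B, D) queries from the online encoder f_q
    k: (B, D) keys from the momentum encoder f_k (treated as constants)
    W: (D, D) learnable bilinear matrix
    The (B, B) logit matrix has positives on the diagonal."""
    logits = q @ W @ k.detach().t()
    logits = logits - logits.max(dim=1, keepdim=True).values  # numerical stability
    labels = torch.arange(q.size(0), device=q.device)
    return F.cross_entropy(logits, labels)
```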

Integration with RL:

  • Encoder \(f_q\) doubles as the feature extractor for policy/value networks
  • Contrastive loss serves as an auxiliary task, trained jointly with the RL loss
  • Momentum encoder \(f_k\) is updated via EMA

Limitations of Contrastive Learning

  • Negative sample selection affects performance
  • Contrastive objectives may not align with RL objectives
  • Increased computational overhead

Bisimulation Metrics

Core Idea

Two states should have similar representations if they are behaviorally indistinguishable.

Bisimulation relation: States \(s_1\) and \(s_2\) are bisimilar if:

  1. They have identical immediate rewards: \(R(s_1, a) = R(s_2, a), \forall a\)
  2. Their transition distributions are identical over bisimulation equivalence classes

Policy Bisimulation Metric

Zhang et al. (2021) proposed a policy-specific bisimulation metric:

\[d_\pi(s_1, s_2) = (1 - c) |R^\pi(s_1) - R^\pi(s_2)| + c \cdot W_1(d_\pi)(P^\pi(\cdot|s_1), P^\pi(\cdot|s_2))\]

where \(W_1\) is the 1-Wasserstein distance (computed under the metric \(d_\pi\) itself, so the definition is a fixed point) and \(c \in [0, 1)\) is a weighting coefficient that plays the role of a discount factor.
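
In practice, Zhang et al.'s DBC algorithm trains the encoder so that latent \(L_1\) distances between randomly paired states match this metric, replacing \(W_1\) with the closed-form \(W_2\) between diagonal-Gaussian latent transition predictions. A hedged sketch under those assumptions (all shapes and names are illustrative):

```python
import torch
import torch.nn.functional as F

def bisim_loss(z_i, z_j, r_i, r_j, mu_i, std_i, mu_j, std_j, c=0.99):
    """Pull ||z_i - z_j||_1 toward a bisimulation-style target built from
    reward differences and a Wasserstein term between Gaussian latent
    transition predictions (closed form for diagonal Gaussians).
    z: (B, D) latents for two randomly paired batches of states
    r: (B,) rewards; mu/std: (B, D) predicted next-latent Gaussians."""
    dist = torch.abs(z_i - z_j).sum(dim=1)
    r_diff = torch.abs(r_i - r_j)
    w2 = torch.sqrt(((mu_i - mu_j) ** 2).sum(1) + ((std_i - std_j) ** 2).sum(1) + 1e-8)
    target = (1 - c) * r_diff + c * w2
    return F.mse_loss(dist, target.detach())  # no gradient through the target
```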

DeepMDP

Gelada et al. (2019) learn representations \(\phi\) whose latent distances approximate a bisimulation-style metric \(d\):

\[\|\phi(s_1) - \phi(s_2)\| \approx d(s_1, s_2)\]

Training objectives:

  • Reward prediction: \(\hat{R}(\phi(s), a) \approx R(s, a)\)
  • Transition prediction: \(\hat{P}(\phi(s), a) \approx \phi(s')\)
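
A minimal sketch of these two objectives, assuming reward_head and trans_model are small learned networks over the latent space (the names and the stop-gradient on the next latent are illustrative choices):

```python
import torch.nn.functional as F

def deepmdp_losses(phi_s, phi_s_next, a, r, reward_head, trans_model):
    """DeepMDP auxiliary objectives on latent states phi(s):
    predict the scalar reward and the next latent state."""
    r_hat = reward_head(phi_s, a)        # \hat{R}(phi(s), a)
    z_hat = trans_model(phi_s, a)        # \hat{P}(phi(s), a)
    loss_reward = F.mse_loss(r_hat, r)
    loss_transition = F.mse_loss(z_hat, phi_s_next.detach())
    return loss_reward + loss_transition
```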

Advantages

  • Theoretical guarantees linking the prediction losses to representation quality
  • Automatically ignores decision-irrelevant information (e.g., background changes)
  • Suitable for scenarios requiring generalization across visual appearances

World Model Representations

Dreamer's Latent Space

Dreamer (Hafner et al., 2020) learns compressed latent state representations:

RSSM (Recurrent State-Space Model):

  • Deterministic path: \(h_t = f(h_{t-1}, z_{t-1}, a_{t-1})\)
  • Stochastic path: \(z_t \sim q(z_t | h_t, o_t)\) (posterior)
  • Prior: \(\hat{z}_t \sim p(z_t | h_t)\)
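
A simplified single-step RSSM in PyTorch: a GRU cell carries the deterministic state, and linear heads emit diagonal-Gaussian parameters for the prior and posterior (dimensions and layer choices are illustrative, not Dreamer's exact architecture):

```python
import torch
import torch.nn as nn

class RSSMStep(nn.Module):
    def __init__(self, z_dim=32, h_dim=256, a_dim=6, embed_dim=256):
        super().__init__()
        self.gru = nn.GRUCell(z_dim + a_dim, h_dim)           # deterministic path
        self.prior = nn.Linear(h_dim, 2 * z_dim)              # p(z_t | h_t)
        self.post = nn.Linear(h_dim + embed_dim, 2 * z_dim)   # q(z_t | h_t, o_t)

    def forward(self, h, z, a, obs_embed):
        h = self.gru(torch.cat([z, a], dim=-1), h)            # h_t = f(h_{t-1}, z_{t-1}, a_{t-1})
        prior_mu, prior_logstd = self.prior(h).chunk(2, dim=-1)
        post_mu, post_logstd = self.post(torch.cat([h, obs_embed], dim=-1)).chunk(2, dim=-1)
        z = post_mu + torch.randn_like(post_mu) * post_logstd.exp()  # sample the posterior
        return h, z, (prior_mu, prior_logstd), (post_mu, post_logstd)
```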

Latent space properties:

  • Captures information relevant to predicting future rewards and observations
  • Compresses high-dimensional observations into a low-dimensional latent space
  • Supports imagination-based planning in the latent space

Connection to Representation Learning

World model representation learning combines reconstruction, reward-prediction, and KL (posterior-to-prior) regularization objectives to learn useful representations:

\[\mathcal{L} = \mathcal{L}_{\text{recon}} + \mathcal{L}_{\text{reward}} + \mathcal{L}_{\text{KL}}\]

See Model-Based RL for details.

Self-Predictive Representations (SPR)

Core Idea

Schwarzer et al. (2021) proposed SPR (Self-Predictive Representations), learning by predicting one's own future representations:

\[\mathcal{L}_{\text{SPR}} = \sum_{k=1}^{K} \left\| \bar{f}(\phi(s_{t+k})) - g(\hat{z}_{t+k}) \right\|^2\]

where:

  • \(\phi(s_{t+k})\): the target encoder's encoding of the future state (target weights are an EMA of the online encoder)
  • \(\hat{z}_{t+k}\): the future representation predicted from the current state by a learned transition model
  • \(g\): online projection (and prediction) head
  • \(\bar{f}\): target projection head

Both sides are \(\ell_2\)-normalized, so this squared error equals the negative cosine similarity used in the paper, up to an additive constant.
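
A minimal sketch of this loss, assuming both branches have already been encoded and projected; the stop-gradient on the target branch and the \(\ell_2\) normalization follow the description above (shapes and names are illustrative):

```python
import torch
import torch.nn.functional as F

def spr_loss(pred_z: torch.Tensor, target_z: torch.Tensor) -> torch.Tensor:
    """Self-prediction loss over K future steps.
    pred_z:   (B, K, D) online-branch predictions g(z_hat_{t+1..t+K})
    target_z: (B, K, D) target-branch projections f_bar(phi(s_{t+1..t+K}))"""
    pred = F.normalize(pred_z, dim=-1)
    target = F.normalize(target_z.detach(), dim=-1)  # stop-gradient on target branch
    return ((pred - target) ** 2).sum(dim=-1).mean()  # sum over D, mean over B and K
```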

Relationship to Other Methods

| Method | Prediction Target | Learning Signal |
| --- | --- | --- |
| Reconstruction | Raw pixels | Pixel-level error |
| CURL | Augmentations of the same observation | Contrastive loss |
| SPR | Future representations | Prediction loss |
| Bisimulation | Behavioral equivalence | Metric distance |

Advantages

  • No pixel reconstruction needed (avoids pixel-level details)
  • No negative samples needed (a challenge for contrastive methods)
  • Naturally compatible with temporal difference learning

Summary and Selection Guide

Complexity-Performance Tradeoff

Method complexity (simple to complex):
Data augmentation < Contrastive learning < Self-prediction < Bisimulation < World models

Recommended path:
1. First try data augmentation (DrQ-v2)
2. If insufficient, add contrastive/self-predictive auxiliary tasks
3. If generalization is needed, consider bisimulation metrics
4. If planning is needed, use world models

Practical Recommendations

| Scenario | Recommended Method |
| --- | --- |
| Pixel input, rapid prototyping | DrQ-v2 |
| Need sample efficiency | CURL / SPR |
| Visual generalization (background changes) | Bisimulation metrics |
| Need model predictions | Dreamer |
| Discrete actions (Atari) | SPR + Rainbow |
| Continuous control | DrQ-v2 |

References

  • Kostrikov et al., "Image Augmentation Is All You Need: Regularizing Deep Reinforcement Learning from Pixels" (ICLR 2021)
  • Yarats et al., "Mastering Visual Continuous Control: Improved Data-Augmented Reinforcement Learning" (ICLR 2022)
  • Laskin et al., "Reinforcement Learning with Augmented Data" (NeurIPS 2020)
  • Laskin et al., "CURL: Contrastive Unsupervised Representations for Reinforcement Learning" (ICML 2020)
  • Zhang et al., "Learning Invariant Representations for Reinforcement Learning without Reconstruction" (ICLR 2021)
  • Schwarzer et al., "Data-Efficient Reinforcement Learning with Self-Predictive Representations" (ICLR 2021)
  • Hafner et al., "Dream to Control: Learning Behaviors by Latent Imagination" (ICLR 2020)
  • Gelada et al., "DeepMDP: Learning Continuous Latent Space Models for Representation Learning" (ICML 2019)
