跳转至

元强化学习

概述

元强化学习(Meta-RL)旨在学习如何学习——通过在大量相关任务上训练,使智能体获得快速适应新任务的能力。

问题设定

给定任务分布 \(p(\mathcal{T})\),每个任务 \(\mathcal{T}_i\) 是一个MDP \((\mathcal{S}, \mathcal{A}, P_i, R_i, \gamma)\),其中转移动力学或奖励函数可能不同。

目标:学习一个元策略/元学习器,使其在面对新任务时能够用极少的交互快速适应。

与标准RL的区别

特性 标准RL Meta-RL
训练 单个任务 任务分布
目标 单任务最优 跨任务快速适应
泛化 状态空间内 任务空间内
样本需求 大量(每个任务) 少量(新任务)

RL²:学习强化学习

核心思想

Duan et al. (2016) 和 Wang et al. (2016) 同时提出:将整个RL算法编码在RNN的权重中。

架构

将多个episode串联成一个长序列,用RNN处理:

\[h_t = f_\theta(h_{t-1}, s_t, a_{t-1}, r_{t-1}, d_{t-1})\]
\[a_t \sim \pi_\theta(\cdot | h_t)\]

其中:

  • \(h_t\):RNN隐状态,编码了任务信息
  • \(d_{t-1}\):上一步的终止标志
  • \(\theta\):元参数,通过大量任务上的RL训练

关键洞察

  • RNN的隐状态隐式地进行任务推断——通过观察奖励和转移来推断当前任务
  • 整个"学习算法"被编码在RNN的前向传播中
  • 元训练时使用外层RL(如PPO),内层"学习"通过RNN的隐状态实现

训练过程

元训练循环:
    采样任务 T ~ p(T)
    重置RNN隐状态 h₀
    运行K个episode(隐状态持续传递):
        for episode k = 1 to K:
            for step t = 1 to H:
                a_t = π(h_t)
                s_{t+1}, r_t = env.step(a_t)
                h_{t+1} = RNN(h_t, s_t, a_t, r_t)
    使用所有episode的回报更新元参数θ

局限性

  • 受限于RNN的记忆容量
  • 元训练需要大量计算
  • 泛化到与训练分布差距大的任务时表现下降

MAML for RL

核心思想

Finn et al. (2017) 将MAML(Model-Agnostic Meta-Learning)应用于RL,学习一组好的初始参数,使得少量梯度更新就能适应新任务。

双层优化

内层更新(任务适应):

\[\theta'_i = \theta + \alpha \nabla_\theta J_{\mathcal{T}_i}(\pi_\theta)\]

对每个任务 \(\mathcal{T}_i\),从元参数 \(\theta\) 出发进行一步(或少数几步)策略梯度更新。

外层更新(元优化):

\[\theta \leftarrow \theta + \beta \sum_{\mathcal{T}_i} \nabla_\theta J_{\mathcal{T}_i}(\pi_{\theta'_i})\]

在适应后的参数 \(\theta'_i\) 上评估性能,优化元参数 \(\theta\)

算法流程

  1. 初始化元参数 \(\theta\)
  2. 采样一批任务 \(\{\mathcal{T}_i\}\)
  3. 对每个任务: a. 采集少量轨迹 b. 计算策略梯度 c. 执行内层更新得到 \(\theta'_i\)
  4. 使用 \(\theta'_i\) 采集新轨迹
  5. 计算外层梯度并更新 \(\theta\)

变体

  • MAML + TRPO:使用TRPO作为内外层优化器
  • ProMP:使用概率推断视角改进MAML
  • E-MAML:考虑探索的元学习

优缺点

优点

  • 模型无关——适用于任何可微分策略
  • 理论优雅——学到了"好的起始点"
  • 适应过程可解释——就是梯度下降

缺点

  • 需要计算二阶梯度(Hessian-vector product)
  • 内层更新步数受限
  • 对内层学习率敏感

基于上下文的方法:PEARL

动机

Rakelly et al. (2019) 提出PEARL,通过概率推断的方式进行任务推断,避免了MAML的梯度通过问题。

架构

上下文编码器:从少量经验中推断任务表示

\[z \sim q_\phi(z | c)\]

其中上下文 \(c = \{(s_j, a_j, r_j, s'_j)\}_{j=1}^N\) 是少量交互经验。

条件策略:以任务表示为条件

\[a = \pi_\theta(s, z)\]

条件价值函数

\[Q_\psi(s, a, z)\]

概率框架

使用变分推断框架:

\[q_\phi(z | c) = \prod_{j=1}^N q_\phi(z | s_j, a_j, r_j, s'_j)\]

假设后验是因子化的(乘积形式),使得可以随着新经验的到来逐步更新任务表示。

训练目标

结合RL目标和变分推断:

\[\max_{\theta, \psi, \phi} \; \mathbb{E}_{\mathcal{T} \sim p(\mathcal{T})} \left[\mathbb{E}_{z \sim q_\phi(z|c)} [J(\pi_\theta(\cdot, z))] - \beta D_{\text{KL}}(q_\phi(z|c) \| p(z))\right]\]

与RL²和MAML的对比

方法 适应机制 任务推断 离策略训练
RL² RNN隐状态 隐式
MAML 梯度更新 通过梯度
PEARL 概率推断 显式后验

PEARL的关键优势是支持离策略训练(使用SAC),大幅提高样本效率。

任务推断

显式任务推断

学习一个任务推断模型 \(p(z | \tau_{1:t})\),从历史轨迹中推断任务身份或参数:

  • 贝叶斯方法:维护任务参数的后验分布
  • 神经网络方法:训练编码器直接输出任务表示

隐式任务推断

通过模型结构隐式进行任务推断:

  • RL²的RNN隐状态
  • Transformer的注意力机制
  • 记忆增强网络

少样本适应

适应效率度量

Meta-RL的关键指标是适应效率——在新任务上使用K个episode后的表现:

\[\text{Performance}(K) = \mathbb{E}_{\mathcal{T} \sim p(\mathcal{T})} [J(\pi_{\text{adapted}}^K)]\]

理想的meta-RL方法应在K很小时就达到高性能。

Zero-Shot vs Few-Shot

  • Zero-Shot:不需要任何新任务经验(通过任务描述或先验推断)
  • One-Shot:仅需一个episode
  • Few-Shot:需要少数几个episode

实践建议

场景 推荐方法
任务结构简单 RL²
需要快速适应 MAML
需要高样本效率 PEARL
连续控制任务 PEARL + SAC
离散动作任务 RL² + PPO

参考文献

  • Duan et al., "RL²: Fast Reinforcement Learning via Slow Reinforcement Learning" (2016)
  • Wang et al., "Learning to Reinforcement Learn" (2016)
  • Finn et al., "Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks" (ICML 2017)
  • Rakelly et al., "Efficient Off-Policy Meta-Reinforcement Learning via Probabilistic Context Variables" (ICML 2019)
  • Rothfuss et al., "ProMP: Proximal Meta-Policy Search" (ICLR 2019)
  • Beck et al., "A Survey of Meta-Reinforcement Learning" (2023)

评论 #