元强化学习
概述
元强化学习(Meta-RL)旨在学习如何学习——通过在大量相关任务上训练,使智能体获得快速适应新任务的能力。
问题设定
给定任务分布 \(p(\mathcal{T})\),每个任务 \(\mathcal{T}_i\) 是一个MDP \((\mathcal{S}, \mathcal{A}, P_i, R_i, \gamma)\),其中转移动力学或奖励函数可能不同。
目标:学习一个元策略/元学习器,使其在面对新任务时能够用极少的交互快速适应。
与标准RL的区别
| 特性 | 标准RL | Meta-RL |
|---|---|---|
| 训练 | 单个任务 | 任务分布 |
| 目标 | 单任务最优 | 跨任务快速适应 |
| 泛化 | 状态空间内 | 任务空间内 |
| 样本需求 | 大量(每个任务) | 少量(新任务) |
RL²:学习强化学习
核心思想
Duan et al. (2016) 和 Wang et al. (2016) 同时提出:将整个RL算法编码在RNN的权重中。
架构
将多个episode串联成一个长序列,用RNN处理:
其中:
- \(h_t\):RNN隐状态,编码了任务信息
- \(d_{t-1}\):上一步的终止标志
- \(\theta\):元参数,通过大量任务上的RL训练
关键洞察
- RNN的隐状态隐式地进行任务推断——通过观察奖励和转移来推断当前任务
- 整个"学习算法"被编码在RNN的前向传播中
- 元训练时使用外层RL(如PPO),内层"学习"通过RNN的隐状态实现
训练过程
元训练循环:
采样任务 T ~ p(T)
重置RNN隐状态 h₀
运行K个episode(隐状态持续传递):
for episode k = 1 to K:
for step t = 1 to H:
a_t = π(h_t)
s_{t+1}, r_t = env.step(a_t)
h_{t+1} = RNN(h_t, s_t, a_t, r_t)
使用所有episode的回报更新元参数θ
局限性
- 受限于RNN的记忆容量
- 元训练需要大量计算
- 泛化到与训练分布差距大的任务时表现下降
MAML for RL
核心思想
Finn et al. (2017) 将MAML(Model-Agnostic Meta-Learning)应用于RL,学习一组好的初始参数,使得少量梯度更新就能适应新任务。
双层优化
内层更新(任务适应):
对每个任务 \(\mathcal{T}_i\),从元参数 \(\theta\) 出发进行一步(或少数几步)策略梯度更新。
外层更新(元优化):
在适应后的参数 \(\theta'_i\) 上评估性能,优化元参数 \(\theta\)。
算法流程
- 初始化元参数 \(\theta\)
- 采样一批任务 \(\{\mathcal{T}_i\}\)
- 对每个任务: a. 采集少量轨迹 b. 计算策略梯度 c. 执行内层更新得到 \(\theta'_i\)
- 使用 \(\theta'_i\) 采集新轨迹
- 计算外层梯度并更新 \(\theta\)
变体
- MAML + TRPO:使用TRPO作为内外层优化器
- ProMP:使用概率推断视角改进MAML
- E-MAML:考虑探索的元学习
优缺点
优点:
- 模型无关——适用于任何可微分策略
- 理论优雅——学到了"好的起始点"
- 适应过程可解释——就是梯度下降
缺点:
- 需要计算二阶梯度(Hessian-vector product)
- 内层更新步数受限
- 对内层学习率敏感
基于上下文的方法:PEARL
动机
Rakelly et al. (2019) 提出PEARL,通过概率推断的方式进行任务推断,避免了MAML的梯度通过问题。
架构
上下文编码器:从少量经验中推断任务表示
其中上下文 \(c = \{(s_j, a_j, r_j, s'_j)\}_{j=1}^N\) 是少量交互经验。
条件策略:以任务表示为条件
条件价值函数:
概率框架
使用变分推断框架:
假设后验是因子化的(乘积形式),使得可以随着新经验的到来逐步更新任务表示。
训练目标
结合RL目标和变分推断:
与RL²和MAML的对比
| 方法 | 适应机制 | 任务推断 | 离策略训练 |
|---|---|---|---|
| RL² | RNN隐状态 | 隐式 | 否 |
| MAML | 梯度更新 | 通过梯度 | 否 |
| PEARL | 概率推断 | 显式后验 | 是 |
PEARL的关键优势是支持离策略训练(使用SAC),大幅提高样本效率。
任务推断
显式任务推断
学习一个任务推断模型 \(p(z | \tau_{1:t})\),从历史轨迹中推断任务身份或参数:
- 贝叶斯方法:维护任务参数的后验分布
- 神经网络方法:训练编码器直接输出任务表示
隐式任务推断
通过模型结构隐式进行任务推断:
- RL²的RNN隐状态
- Transformer的注意力机制
- 记忆增强网络
少样本适应
适应效率度量
Meta-RL的关键指标是适应效率——在新任务上使用K个episode后的表现:
理想的meta-RL方法应在K很小时就达到高性能。
Zero-Shot vs Few-Shot
- Zero-Shot:不需要任何新任务经验(通过任务描述或先验推断)
- One-Shot:仅需一个episode
- Few-Shot:需要少数几个episode
实践建议
| 场景 | 推荐方法 |
|---|---|
| 任务结构简单 | RL² |
| 需要快速适应 | MAML |
| 需要高样本效率 | PEARL |
| 连续控制任务 | PEARL + SAC |
| 离散动作任务 | RL² + PPO |
参考文献
- Duan et al., "RL²: Fast Reinforcement Learning via Slow Reinforcement Learning" (2016)
- Wang et al., "Learning to Reinforcement Learn" (2016)
- Finn et al., "Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks" (ICML 2017)
- Rakelly et al., "Efficient Off-Policy Meta-Reinforcement Learning via Probabilistic Context Variables" (ICML 2019)
- Rothfuss et al., "ProMP: Proximal Meta-Policy Search" (ICLR 2019)
- Beck et al., "A Survey of Meta-Reinforcement Learning" (2023)