GRU
Gated Recurrent Unit(门控循环单元),由 Cho et al. 于2014年提出,是LSTM的简化变体。它将LSTM的3个门简化为2个门,并去掉了独立的细胞状态,用更少的参数实现了与LSTM相当的性能,成为RNN家族中另一个重要的门控架构。
背景与动机
LSTM 的问题
LSTM通过引入遗忘门、输入门、输出门和独立的细胞状态 \(C_t\),成功解决了梯度消失问题。但它也带来了新的代价:
- 参数量大:4组权重矩阵(遗忘门、输入门、候选值、输出门),参数量约为 Vanilla RNN 的4倍
- 计算开销高:每个时间步需要4次矩阵乘法 + 多次激活函数计算
- 结构复杂:两个状态(\(h_t\) 和 \(C_t\))需要同时维护,理解和调试都更困难
核心问题
LSTM 的3个门和独立的细胞状态是否都必要?能否用更简单的结构达到类似效果?
GRU 的核心思想
Cho et al. (2014) 在论文 "Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine Translation" 中提出了GRU,其核心简化策略:
- 合并遗忘门和输入门为一个更新门(Update Gate)\(z_t\):用 \(z_t\) 同时控制"遗忘多少旧信息"和"接受多少新信息",二者互补(一个控制量决定两件事)
- 去掉独立的细胞状态:只保留隐藏状态 \(h_t\),让它直接充当信息载体
- 引入重置门(Reset Gate)\(r_t\):控制生成候选隐藏状态时参考多少历史信息
结果:2个门 + 1个状态,参数量比LSTM少约25%,训练更快,性能相当。
GRU 架构详解
整体结构图
┌─────────────────────────────────────────────────┐
│ GRU Cell │
│ │
h_{t-1} ──┼──┬──────────────────┬───────────────┐ │
│ │ │ │ │
x_t ──────┼──┤ │ │ │
│ │ │ │ │
│ ▼ ▼ │ │
│ ┌──────┐ ┌──────┐ │ │
│ │Reset │ │Update│ │ │
│ │ Gate │ │ Gate │ │ │
│ │ r_t │ │ z_t │ │ │
│ │ σ │ │ σ │ │ │
│ └──┬───┘ └──┬───┘ │ │
│ │ │ │ │
│ ▼ │ │ │
│ r_t ⊙ h_{t-1} │ │ │
│ │ │ │ │
│ ▼ │ │ │
│ ┌────────┐ │ │ │
│ │Candidate│ │ │ │
│ │ h~_t │ │ │ │
│ │ tanh │ │ │ │
│ └───┬────┘ │ │ │
│ │ │ │ │
│ │ ▼ ▼ │
│ │ (1 - z_t) ⊙ h_{t-1} + z_t ⊙ h~_t │
│ │ │ │
│ │ ▼ │
│ │ h_t ──────────────────────────────┼──→ h_t
│ │ │
└─────┼───────────────────────────────────────────┘
│
与LSTM最大的区别:没有独立的细胞状态 \(C_t\),只有一个隐藏状态 \(h_t\) 承担所有职责。
两个门
重置门(Reset Gate)
- 输出范围:\((0, 1)\),逐元素控制
- 作用:决定在生成候选隐藏状态时,"忽略多少历史信息"
- \(r_t \approx 1\):完全保留历史信息(候选状态参考完整的 \(h_{t-1}\))
- \(r_t \approx 0\):忽略历史信息(候选状态几乎只由当前输入 \(x_t\) 决定)
直觉理解
重置门让GRU能够"忘记"不相关的历史。例如在处理文本时,当遇到句号开始新句子,重置门可以输出接近0的值,让模型"重新开始",不受前一句话的干扰。
更新门(Update Gate)
- 输出范围:\((0, 1)\),逐元素控制
- 作用:决定最终隐藏状态中,"保留多少旧状态 vs 接受多少新状态"
- \(z_t \approx 1\):几乎完全采用候选新状态(接受新信息)
- \(z_t \approx 0\):几乎完全保留旧状态(保持记忆不变)
注意符号约定
不同文献中 \(z_t\) 的含义可能相反。本文采用的约定是:\(z_t\) 越大,越倾向于采用新的候选状态。有些文献中 \(z_t\) 大表示保留旧状态。阅读论文时请注意区分。
候选隐藏状态
这里的关键是 \(r_t \odot h_{t-1}\):重置门先对历史隐藏状态进行"过滤",再与当前输入 \(x_t\) 拼接,经过线性变换和 \(\tanh\) 激活生成候选状态。
最终隐藏状态更新
这是GRU最核心的公式。更新门 \(z_t\) 同时控制了两件事:
- \((1 - z_t) \odot h_{t-1}\):保留旧状态的比例
- \(z_t \odot \tilde{h}_t\):接受新状态的比例
二者之和的权重恰好为1(互补),这是一个凸组合,保证了数值稳定性。
与LSTM的关键区别
LSTM的遗忘门 \(f_t\) 和输入门 \(i_t\) 是独立的两个门,可以同时全开或全关(\(f_t \approx 1\) 且 \(i_t \approx 1\)),这意味着LSTM可以同时保留旧信息并大量写入新信息。而GRU的 \((1-z_t)\) 和 \(z_t\) 是互补的,保留和更新之间必须做权衡。这是GRU简化带来的表达能力上的微小损失。
门控机制直觉理解
更新门 \(z_t\):LSTM两个门的合体
LSTM中,遗忘门 \(f_t\) 和输入门 \(i_t\) 分别控制"丢弃旧信息"和"写入新信息"。GRU的更新门 \(z_t\) 用一个参数同时控制了这两件事:
- LSTM:\(C_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_t\)(\(f_t\) 和 \(i_t\) 独立)
- GRU:\(h_t = (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t\)(互补约束)
重置门 \(r_t\):灵活的历史遗忘
重置门在候选状态生成阶段发挥作用,而不是在最终状态更新阶段。当 \(r_t \approx 0\) 时:
候选状态完全不参考历史,就像一个"全新的开始"。这使GRU能够在需要时彻底切换上下文。
与LSTM门的对应关系
| GRU | LSTM | 功能 |
|---|---|---|
| 更新门 \(z_t\) | 遗忘门 \(f_t\) + 输入门 \(i_t\) | 控制信息的保留与更新比例 |
| 重置门 \(r_t\) | (无直接对应) | 控制候选状态对历史的依赖程度 |
| (无) | 输出门 \(o_t\) | GRU直接输出 \(h_t\),不做额外过滤 |
| 隐藏状态 \(h_t\) | 细胞状态 \(C_t\) + 隐藏状态 \(h_t\) | GRU只有一个状态,身兼两职 |
前向传播:完整数值示例
设定:input_size = 2,hidden_size = 3,处理2个时间步。
初始化
假设权重矩阵(每个 \(W \in \mathbb{R}^{3 \times 5}\),因为 \([h_{t-1}, x_t] \in \mathbb{R}^{3+2=5}\)):
初始隐藏状态 \(h_0 = [0, 0, 0]\)。
时间步 1:\(x_1 = [1.0, 0.5]\)
拼接输入:\([h_0, x_1] = [0, 0, 0, 1.0, 0.5]\)
计算更新门:
计算重置门:
计算候选隐藏状态(因为 \(h_0 = \mathbf{0}\),\(r_1 \odot h_0 = \mathbf{0}\)):
计算最终隐藏状态:
时间步 2:\(x_2 = [-0.5, 0.8]\)
拼接输入:\([h_1, x_2] = [0.0, 0.168, 0.025, -0.5, 0.8]\)
计算更新门:
计算重置门:
计算候选隐藏状态(\(r_2 \odot h_1 = [0.452 \times 0.0,\ 0.559 \times 0.168,\ 0.561 \times 0.025] = [0.0, 0.094, 0.014]\)):
计算最终隐藏状态:
观察
即使在这个简单的2步示例中,也可以看到GRU的状态在不断更新:更新门决定了新旧信息的混合比例,重置门影响了候选状态对历史的依赖。整个过程只需要3组矩阵乘法(对比LSTM的4组),计算确实更精简。
反向传播(BPTT)与梯度流动
梯度通过更新门的线性路径
GRU的状态更新公式:
计算 \(\frac{\partial h_t}{\partial h_{t-1}}\) 时,由于 \(h_{t-1}\) 出现在多处(更新门、重置门、候选状态中都有),完整的梯度表达式较复杂。但最重要的一条梯度路径是直接路径:
这条路径来自 \((1 - z_t) \odot h_{t-1}\) 这一项,梯度只需乘以 \((1 - z_t)\),不经过任何矩阵乘法或挤压性激活函数。
为什么 GRU 也能缓解梯度消失
类比LSTM的细胞状态"高速公路":
| LSTM | GRU | |
|---|---|---|
| 信息高速公路 | \(C_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_t\) | \(h_t = (1-z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t\) |
| 直接梯度路径 | \(\frac{\partial C_t}{\partial C_{t-1}} = f_t\) | \(\frac{\partial h_t}{\partial h_{t-1}}\big\|_{\text{direct}} = 1 - z_t\) |
| 梯度跨 \(T\) 步传播 | \(\prod_{t} f_t\) | \(\prod_{t} (1 - z_t)\) |
| 保持梯度的条件 | \(f_t \approx 1\)(遗忘门接近1) | \(z_t \approx 0\)(更新门接近0) |
当模型学到某段信息需要长期保留时:
- LSTM:遗忘门 \(f_t \approx 1\),信息在细胞状态中几乎不变地传递
- GRU:更新门 \(z_t \approx 0\),隐藏状态 \(h_t \approx h_{t-1}\),信息直接"穿过"
两者的机制本质上是等价的:都提供了一条加性的、可控的梯度传播路径,避免了Vanilla RNN中梯度必须经过 \(W_{hh}\) 矩阵乘法的问题。
加性结构是关键
无论LSTM还是GRU,缓解梯度消失的核心都是加法结构:新状态 = 旧状态的加权 + 新信息的加权。加法的梯度是简单的系数,不涉及矩阵乘法的特征值问题。这与ResNet的残差连接 \(y = x + F(x)\) 思想完全一致。
GRU vs LSTM 详细对比
| 对比维度 | GRU | LSTM |
|---|---|---|
| 提出时间 | 2014 (Cho et al.) | 1997 (Hochreiter & Schmidhuber) |
| 门的数量 | 2个(更新门、重置门) | 3个(遗忘门、输入门、输出门) |
| 状态变量 | 仅 \(h_t\) | \(h_t\) 和 \(C_t\) |
| 参数量 | \(3 d_h(d_h + d) + 3d_h\) | \(4 d_h(d_h + d) + 4d_h\) |
| 参数比例 | 约为LSTM的 75% | 基准 (100%) |
| 训练速度 | 较快(少一组矩阵运算) | 较慢 |
| 信息保留/更新 | 互补约束(\(z_t\) 和 \(1-z_t\)) | 独立控制(\(f_t\) 和 \(i_t\) 分离) |
| 输出过滤 | 无(直接输出 \(h_t\)) | 有(输出门 \(o_t\) 过滤) |
| 表达能力 | 略弱(互补约束限制) | 略强(独立门控更灵活) |
| 大多数任务性能 | 与LSTM持平 | 与GRU持平 |
| 小数据集 | 可能更优(参数少,不易过拟合) | 可能过拟合 |
| 极长序列 | 略逊 | 略优(独立cell state更稳定) |
| 代码复杂度 | 更简单 | 更复杂 |
经验法则
实践中,GRU和LSTM的性能差异通常很小。Jozefowicz et al. (2015) 在大规模实验中发现,没有哪个架构在所有任务上都占优。一般建议:先试GRU(更快),如果效果不够好再换LSTM。对于特别长的序列或需要精细控制信息流的任务,LSTM可能更合适。
代码实现(PyTorch)
基本使用
import torch
import torch.nn as nn
# 参数设定
input_size = 10 # 输入特征维度
hidden_size = 20 # 隐藏状态维度
num_layers = 2 # GRU层数
batch_size = 3 # 批大小
seq_len = 5 # 序列长度
# 创建 GRU
gru = nn.GRU(
input_size=input_size,
hidden_size=hidden_size,
num_layers=num_layers,
batch_first=True, # 输入格式: (batch, seq, feature)
dropout=0.1, # 层间 dropout(仅在 num_layers > 1 时有效)
bidirectional=False # 是否双向
)
# 输入数据
x = torch.randn(batch_size, seq_len, input_size) # (3, 5, 10)
h0 = torch.zeros(num_layers, batch_size, hidden_size) # (2, 3, 20)
# 前向传播
output, h_n = gru(x, h0)
print(f"output shape: {output.shape}") # (3, 5, 20) - 每个时间步的输出
print(f"h_n shape: {h_n.shape}") # (2, 3, 20) - 最后时间步各层的隐藏状态
nn.GRU 参数说明
| 参数 | 类型 | 说明 |
|---|---|---|
input_size |
int | 输入 \(x_t\) 的特征维度 |
hidden_size |
int | 隐藏状态 \(h_t\) 的维度 |
num_layers |
int | 堆叠的GRU层数,默认1 |
bias |
bool | 是否使用偏置,默认True |
batch_first |
bool | True则输入为 (batch, seq, feature),默认False |
dropout |
float | 层间dropout比率(最后一层不加),默认0 |
bidirectional |
bool | 是否双向GRU,默认False |
输出说明
| 输出 | 形状 | 说明 |
|---|---|---|
output |
(batch, seq, hidden_size * num_directions) |
所有时间步最后一层的 \(h_t\) |
h_n |
(num_layers * num_directions, batch, hidden_size) |
最后一个时间步各层的 \(h_t\) |
基于 GRU 的文本分类模型
class GRUClassifier(nn.Module):
def __init__(self, vocab_size, embed_dim, hidden_size, num_classes,
num_layers=1, dropout=0.5):
super().__init__()
self.embedding = nn.Embedding(vocab_size, embed_dim)
self.gru = nn.GRU(
embed_dim, hidden_size, num_layers,
batch_first=True, dropout=dropout if num_layers > 1 else 0,
bidirectional=True
)
self.dropout = nn.Dropout(dropout)
# 双向GRU,隐藏状态维度翻倍
self.fc = nn.Linear(hidden_size * 2, num_classes)
def forward(self, x):
# x: (batch, seq_len) - token indices
embeds = self.dropout(self.embedding(x)) # (batch, seq, embed_dim)
output, h_n = self.gru(embeds) # output: (batch, seq, hidden*2)
# 取最后一个时间步的输出(双向拼接)
# 或者用平均池化: out = output.mean(dim=1)
out = output[:, -1, :] # (batch, hidden*2)
out = self.dropout(out)
logits = self.fc(out) # (batch, num_classes)
return logits
# 使用示例
model = GRUClassifier(
vocab_size=10000, embed_dim=128,
hidden_size=256, num_classes=2
)
x = torch.randint(0, 10000, (32, 50)) # batch=32, seq_len=50
logits = model(x) # (32, 2)
实践提示
使用双向GRU时,h_n 的形状为 (num_layers * 2, batch, hidden_size)。如果要取最后一层双向的隐藏状态并拼接,需要 torch.cat([h_n[-2], h_n[-1]], dim=-1),而不是简单地取 h_n[-1]。
思考与讨论
GRU 什么时候优于 LSTM?
- 小数据集:GRU参数少约25%,正则化效果天然更好,不易过拟合
- 计算资源受限:边缘设备、实时推理等场景,GRU的速度优势明显
- 序列不太长:当序列长度在几十到几百的范围内,GRU和LSTM差异极小,此时选择更简洁的GRU更合理
- 快速原型验证:GRU训练更快,适合快速迭代实验
为什么两者性能差不多?
核心原因在于两者解决梯度消失的机制本质相同——都是通过加性更新提供线性梯度路径。LSTM多出的输出门和独立细胞状态提供了更精细的控制能力,但在大多数实际任务中,这种额外的表达能力并不是性能的瓶颈。模型性能更多地受数据质量、超参数调优、正则化策略等因素影响。
GRU 在现代深度学习中的地位
随着Transformer架构的崛起,LSTM和GRU在NLP领域的使用都大幅减少。但GRU仍然活跃在以下场景:
- 时间序列预测:金融数据、传感器数据等,序列长度适中,GRU性价比高
- 语音处理:实时语音识别、语音合成中的流式处理
- 强化学习:部分策略网络使用GRU处理部分可观测环境的历史信息
- 边缘计算/嵌入式设备:模型体积和推理速度至关重要的场景
- 作为大模型的组件:一些混合架构中,GRU被用作局部的序列建模模块
历史意义
GRU的最大贡献不在于性能的提升,而在于证明了一个重要观点:LSTM的复杂性并非必要。这启发了后续研究者不断探索更简洁的序列建模方案,最终促进了Transformer等新架构的诞生。简洁性本身就是一种价值。
总结
GRU的方程汇总(完整的前向传播,共4个公式):
核心要点:
- GRU = LSTM 的简化版:2个门代替3个门,1个状态代替2个状态
- 更新门 \(z_t\) 同时承担了LSTM遗忘门和输入门的角色(互补约束)
- 重置门 \(r_t\) 控制候选状态对历史的依赖程度
- 加性更新结构提供线性梯度路径,缓解梯度消失
- 参数量约为LSTM的75%,训练更快,性能通常持平
下一步:Seq2Seq(编码器-解码器架构,GRU/LSTM均可作为基础组件)