Skip to content

GRU

Gated Recurrent Unit(门控循环单元),由 Cho et al. 于2014年提出,是LSTM的简化变体。它将LSTM的3个门简化为2个门,并去掉了独立的细胞状态,用更少的参数实现了与LSTM相当的性能,成为RNN家族中另一个重要的门控架构。


背景与动机

LSTM 的问题

LSTM通过引入遗忘门、输入门、输出门和独立的细胞状态 \(C_t\),成功解决了梯度消失问题。但它也带来了新的代价:

  • 参数量大:4组权重矩阵(遗忘门、输入门、候选值、输出门),参数量约为 Vanilla RNN 的4倍
  • 计算开销高:每个时间步需要4次矩阵乘法 + 多次激活函数计算
  • 结构复杂:两个状态(\(h_t\)\(C_t\))需要同时维护,理解和调试都更困难

核心问题

LSTM 的3个门和独立的细胞状态是否都必要?能否用更简单的结构达到类似效果?

GRU 的核心思想

Cho et al. (2014) 在论文 "Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine Translation" 中提出了GRU,其核心简化策略:

  1. 合并遗忘门和输入门为一个更新门(Update Gate)\(z_t\):用 \(z_t\) 同时控制"遗忘多少旧信息"和"接受多少新信息",二者互补(一个控制量决定两件事)
  2. 去掉独立的细胞状态:只保留隐藏状态 \(h_t\),让它直接充当信息载体
  3. 引入重置门(Reset Gate)\(r_t\):控制生成候选隐藏状态时参考多少历史信息

结果:2个门 + 1个状态,参数量比LSTM少约25%,训练更快,性能相当。


GRU 架构详解

整体结构图

            ┌─────────────────────────────────────────────────┐
            │                   GRU Cell                       │
            │                                                  │
  h_{t-1} ──┼──┬──────────────────┬───────────────┐            │
            │  │                  │               │            │
  x_t ──────┼──┤                  │               │            │
            │  │                  │               │            │
            │  ▼                  ▼               │            │
            │ ┌──────┐      ┌──────┐              │            │
            │ │Reset │      │Update│              │            │
            │ │ Gate │      │ Gate │              │            │
            │ │ r_t  │      │ z_t  │              │            │
            │ │  σ   │      │  σ   │              │            │
            │ └──┬───┘      └──┬───┘              │            │
            │    │             │                  │            │
            │    ▼             │                  │            │
            │  r_t ⊙ h_{t-1}  │                  │            │
            │    │             │                  │            │
            │    ▼             │                  │            │
            │ ┌────────┐      │                  │            │
            │ │Candidate│     │                  │            │
            │ │  h~_t   │     │                  │            │
            │ │  tanh   │     │                  │            │
            │ └───┬────┘      │                  │            │
            │     │           │                  │            │
            │     │           ▼                  ▼            │
            │     │     (1 - z_t) ⊙ h_{t-1} + z_t ⊙ h~_t    │
            │     │           │                               │
            │     │           ▼                               │
            │     │         h_t ──────────────────────────────┼──→ h_t
            │     │                                           │
            └─────┼───────────────────────────────────────────┘
                  │

与LSTM最大的区别:没有独立的细胞状态 \(C_t\),只有一个隐藏状态 \(h_t\) 承担所有职责。

两个门

重置门(Reset Gate)

\[ r_t = \sigma(W_r \cdot [h_{t-1}, x_t] + b_r) \]
  • 输出范围:\((0, 1)\),逐元素控制
  • 作用:决定在生成候选隐藏状态时,"忽略多少历史信息"
  • \(r_t \approx 1\):完全保留历史信息(候选状态参考完整的 \(h_{t-1}\)
  • \(r_t \approx 0\):忽略历史信息(候选状态几乎只由当前输入 \(x_t\) 决定)

直觉理解

重置门让GRU能够"忘记"不相关的历史。例如在处理文本时,当遇到句号开始新句子,重置门可以输出接近0的值,让模型"重新开始",不受前一句话的干扰。

更新门(Update Gate)

\[ z_t = \sigma(W_z \cdot [h_{t-1}, x_t] + b_z) \]
  • 输出范围:\((0, 1)\),逐元素控制
  • 作用:决定最终隐藏状态中,"保留多少旧状态 vs 接受多少新状态"
  • \(z_t \approx 1\):几乎完全采用候选新状态(接受新信息)
  • \(z_t \approx 0\):几乎完全保留旧状态(保持记忆不变)

注意符号约定

不同文献中 \(z_t\) 的含义可能相反。本文采用的约定是:\(z_t\) 越大,越倾向于采用的候选状态。有些文献中 \(z_t\) 大表示保留旧状态。阅读论文时请注意区分。

候选隐藏状态

\[ \tilde{h}_t = \tanh(W_h \cdot [r_t \odot h_{t-1}, x_t] + b_h) \]

这里的关键是 \(r_t \odot h_{t-1}\):重置门先对历史隐藏状态进行"过滤",再与当前输入 \(x_t\) 拼接,经过线性变换和 \(\tanh\) 激活生成候选状态。

最终隐藏状态更新

\[ h_t = (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t \]

这是GRU最核心的公式。更新门 \(z_t\) 同时控制了两件事:

  • \((1 - z_t) \odot h_{t-1}\):保留旧状态的比例
  • \(z_t \odot \tilde{h}_t\):接受新状态的比例

二者之和的权重恰好为1(互补),这是一个凸组合,保证了数值稳定性。

与LSTM的关键区别

LSTM的遗忘门 \(f_t\) 和输入门 \(i_t\)独立的两个门,可以同时全开或全关(\(f_t \approx 1\)\(i_t \approx 1\)),这意味着LSTM可以同时保留旧信息并大量写入新信息。而GRU的 \((1-z_t)\)\(z_t\) 是互补的,保留和更新之间必须做权衡。这是GRU简化带来的表达能力上的微小损失。


门控机制直觉理解

更新门 \(z_t\):LSTM两个门的合体

LSTM中,遗忘门 \(f_t\) 和输入门 \(i_t\) 分别控制"丢弃旧信息"和"写入新信息"。GRU的更新门 \(z_t\) 用一个参数同时控制了这两件事:

  • LSTM:\(C_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_t\)\(f_t\)\(i_t\) 独立)
  • GRU:\(h_t = (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t\)(互补约束)

重置门 \(r_t\):灵活的历史遗忘

重置门在候选状态生成阶段发挥作用,而不是在最终状态更新阶段。当 \(r_t \approx 0\) 时:

\[ \tilde{h}_t = \tanh(W_h \cdot [\mathbf{0}, x_t] + b_h) \]

候选状态完全不参考历史,就像一个"全新的开始"。这使GRU能够在需要时彻底切换上下文。

与LSTM门的对应关系

GRU LSTM 功能
更新门 \(z_t\) 遗忘门 \(f_t\) + 输入门 \(i_t\) 控制信息的保留与更新比例
重置门 \(r_t\) (无直接对应) 控制候选状态对历史的依赖程度
(无) 输出门 \(o_t\) GRU直接输出 \(h_t\),不做额外过滤
隐藏状态 \(h_t\) 细胞状态 \(C_t\) + 隐藏状态 \(h_t\) GRU只有一个状态,身兼两职

前向传播:完整数值示例

设定:input_size = 2hidden_size = 3,处理2个时间步。

初始化

假设权重矩阵(每个 \(W \in \mathbb{R}^{3 \times 5}\),因为 \([h_{t-1}, x_t] \in \mathbb{R}^{3+2=5}\)):

\[ W_z = \begin{bmatrix} 0.1 & 0.2 & -0.1 & 0.3 & 0.1 \\ -0.2 & 0.1 & 0.3 & -0.1 & 0.2 \\ 0.3 & -0.1 & 0.2 & 0.1 & -0.2 \end{bmatrix}, \quad b_z = [0, 0, 0] \]
\[ W_r = \begin{bmatrix} 0.2 & -0.1 & 0.1 & 0.2 & -0.1 \\ 0.1 & 0.3 & -0.2 & 0.1 & 0.3 \\ -0.1 & 0.2 & 0.1 & -0.1 & 0.2 \end{bmatrix}, \quad b_r = [0, 0, 0] \]
\[ W_h = \begin{bmatrix} -0.1 & 0.2 & 0.3 & 0.1 & -0.2 \\ 0.2 & -0.1 & 0.1 & 0.3 & 0.1 \\ 0.1 & 0.1 & -0.2 & -0.1 & 0.3 \end{bmatrix}, \quad b_h = [0, 0, 0] \]

初始隐藏状态 \(h_0 = [0, 0, 0]\)

时间步 1:\(x_1 = [1.0, 0.5]\)

拼接输入\([h_0, x_1] = [0, 0, 0, 1.0, 0.5]\)

计算更新门

\[ z_1 = \sigma(W_z \cdot [h_0, x_1] + b_z) = \sigma([0.3 \times 1.0 + 0.1 \times 0.5,\ -0.1 \times 1.0 + 0.2 \times 0.5,\ 0.1 \times 1.0 + (-0.2) \times 0.5]) \]
\[ = \sigma([0.35, 0.0, 0.0]) = [0.587, 0.500, 0.500] \]

计算重置门

\[ r_1 = \sigma(W_r \cdot [h_0, x_1] + b_r) = \sigma([0.2 \times 1.0 + (-0.1) \times 0.5,\ 0.1 \times 1.0 + 0.3 \times 0.5,\ -0.1 \times 1.0 + 0.2 \times 0.5]) \]
\[ = \sigma([0.15, 0.25, 0.0]) = [0.537, 0.562, 0.500] \]

计算候选隐藏状态(因为 \(h_0 = \mathbf{0}\)\(r_1 \odot h_0 = \mathbf{0}\)):

\[ \tilde{h}_1 = \tanh(W_h \cdot [\mathbf{0}, x_1] + b_h) = \tanh([0.1 \times 1.0 + (-0.2) \times 0.5,\ 0.3 \times 1.0 + 0.1 \times 0.5,\ -0.1 \times 1.0 + 0.3 \times 0.5]) \]
\[ = \tanh([0.0, 0.35, 0.05]) = [0.0, 0.336, 0.050] \]

计算最终隐藏状态

\[ h_1 = (1 - z_1) \odot h_0 + z_1 \odot \tilde{h}_1 = [0, 0, 0] + [0.587, 0.500, 0.500] \odot [0.0, 0.336, 0.050] \]
\[ = [0.0, 0.168, 0.025] \]

时间步 2:\(x_2 = [-0.5, 0.8]\)

拼接输入\([h_1, x_2] = [0.0, 0.168, 0.025, -0.5, 0.8]\)

计算更新门

\[ z_2 = \sigma(W_z \cdot [h_1, x_2] + b_z) \]
\[ = \sigma([0.0 + 0.034 - 0.003 - 0.15 + 0.08,\ 0.0 + 0.017 + 0.008 + 0.05 + 0.16,\ 0.0 - 0.017 + 0.005 - 0.05 - 0.16]) \]
\[ = \sigma([-0.039, 0.235, -0.222]) = [0.490, 0.559, 0.445] \]

计算重置门

\[ r_2 = \sigma(W_r \cdot [h_1, x_2] + b_r) \]
\[ = \sigma([0.0 - 0.017 + 0.003 - 0.1 - 0.08,\ 0.0 + 0.050 - 0.005 - 0.05 + 0.24,\ 0.0 + 0.034 + 0.003 + 0.05 + 0.16]) \]
\[ = \sigma([-0.194, 0.235, 0.247]) = [0.452, 0.559, 0.561] \]

计算候选隐藏状态\(r_2 \odot h_1 = [0.452 \times 0.0,\ 0.559 \times 0.168,\ 0.561 \times 0.025] = [0.0, 0.094, 0.014]\)):

\[ \tilde{h}_2 = \tanh(W_h \cdot [0.0, 0.094, 0.014, -0.5, 0.8] + b_h) \]
\[ = \tanh([0.0 + 0.019 + 0.004 - 0.05 - 0.16,\ 0.0 - 0.009 + 0.001 - 0.15 + 0.08,\ 0.0 + 0.009 - 0.003 + 0.05 + 0.24]) \]
\[ = \tanh([-0.187, -0.078, 0.296]) = [-0.185, -0.078, 0.287] \]

计算最终隐藏状态

\[ h_2 = (1 - z_2) \odot h_1 + z_2 \odot \tilde{h}_2 \]
\[ = [0.510 \times 0.0,\ 0.441 \times 0.168,\ 0.555 \times 0.025] + [0.490 \times (-0.185),\ 0.559 \times (-0.078),\ 0.445 \times 0.287] \]
\[ = [0.0, 0.074, 0.014] + [-0.091, -0.044, 0.128] = [-0.091, 0.030, 0.142] \]

观察

即使在这个简单的2步示例中,也可以看到GRU的状态在不断更新:更新门决定了新旧信息的混合比例,重置门影响了候选状态对历史的依赖。整个过程只需要3组矩阵乘法(对比LSTM的4组),计算确实更精简。


反向传播(BPTT)与梯度流动

梯度通过更新门的线性路径

GRU的状态更新公式:

\[ h_t = (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t \]

计算 \(\frac{\partial h_t}{\partial h_{t-1}}\) 时,由于 \(h_{t-1}\) 出现在多处(更新门、重置门、候选状态中都有),完整的梯度表达式较复杂。但最重要的一条梯度路径是直接路径

\[ \frac{\partial h_t}{\partial h_{t-1}} \bigg|_{\text{direct}} = \text{diag}(1 - z_t) \]

这条路径来自 \((1 - z_t) \odot h_{t-1}\) 这一项,梯度只需乘以 \((1 - z_t)\),不经过任何矩阵乘法或挤压性激活函数。

为什么 GRU 也能缓解梯度消失

类比LSTM的细胞状态"高速公路":

LSTM GRU
信息高速公路 \(C_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_t\) \(h_t = (1-z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t\)
直接梯度路径 \(\frac{\partial C_t}{\partial C_{t-1}} = f_t\) \(\frac{\partial h_t}{\partial h_{t-1}}\big\|_{\text{direct}} = 1 - z_t\)
梯度跨 \(T\) 步传播 \(\prod_{t} f_t\) \(\prod_{t} (1 - z_t)\)
保持梯度的条件 \(f_t \approx 1\)(遗忘门接近1) \(z_t \approx 0\)(更新门接近0)

当模型学到某段信息需要长期保留时:

  • LSTM:遗忘门 \(f_t \approx 1\),信息在细胞状态中几乎不变地传递
  • GRU:更新门 \(z_t \approx 0\),隐藏状态 \(h_t \approx h_{t-1}\),信息直接"穿过"

两者的机制本质上是等价的:都提供了一条加性的、可控的梯度传播路径,避免了Vanilla RNN中梯度必须经过 \(W_{hh}\) 矩阵乘法的问题。

加性结构是关键

无论LSTM还是GRU,缓解梯度消失的核心都是加法结构:新状态 = 旧状态的加权 + 新信息的加权。加法的梯度是简单的系数,不涉及矩阵乘法的特征值问题。这与ResNet的残差连接 \(y = x + F(x)\) 思想完全一致。


GRU vs LSTM 详细对比

对比维度 GRU LSTM
提出时间 2014 (Cho et al.) 1997 (Hochreiter & Schmidhuber)
门的数量 2个(更新门、重置门) 3个(遗忘门、输入门、输出门)
状态变量 \(h_t\) \(h_t\)\(C_t\)
参数量 \(3 d_h(d_h + d) + 3d_h\) \(4 d_h(d_h + d) + 4d_h\)
参数比例 约为LSTM的 75% 基准 (100%)
训练速度 较快(少一组矩阵运算) 较慢
信息保留/更新 互补约束(\(z_t\)\(1-z_t\) 独立控制(\(f_t\)\(i_t\) 分离)
输出过滤 无(直接输出 \(h_t\) 有(输出门 \(o_t\) 过滤)
表达能力 略弱(互补约束限制) 略强(独立门控更灵活)
大多数任务性能 与LSTM持平 与GRU持平
小数据集 可能更优(参数少,不易过拟合) 可能过拟合
极长序列 略逊 略优(独立cell state更稳定)
代码复杂度 更简单 更复杂

经验法则

实践中,GRU和LSTM的性能差异通常很小。Jozefowicz et al. (2015) 在大规模实验中发现,没有哪个架构在所有任务上都占优。一般建议:先试GRU(更快),如果效果不够好再换LSTM。对于特别长的序列或需要精细控制信息流的任务,LSTM可能更合适。


代码实现(PyTorch)

基本使用

import torch
import torch.nn as nn

# 参数设定
input_size = 10     # 输入特征维度
hidden_size = 20    # 隐藏状态维度
num_layers = 2      # GRU层数
batch_size = 3      # 批大小
seq_len = 5         # 序列长度

# 创建 GRU
gru = nn.GRU(
    input_size=input_size,
    hidden_size=hidden_size,
    num_layers=num_layers,
    batch_first=True,      # 输入格式: (batch, seq, feature)
    dropout=0.1,           # 层间 dropout(仅在 num_layers > 1 时有效)
    bidirectional=False     # 是否双向
)

# 输入数据
x = torch.randn(batch_size, seq_len, input_size)    # (3, 5, 10)
h0 = torch.zeros(num_layers, batch_size, hidden_size)  # (2, 3, 20)

# 前向传播
output, h_n = gru(x, h0)

print(f"output shape: {output.shape}")   # (3, 5, 20) - 每个时间步的输出
print(f"h_n shape: {h_n.shape}")         # (2, 3, 20) - 最后时间步各层的隐藏状态

nn.GRU 参数说明

参数 类型 说明
input_size int 输入 \(x_t\) 的特征维度
hidden_size int 隐藏状态 \(h_t\) 的维度
num_layers int 堆叠的GRU层数,默认1
bias bool 是否使用偏置,默认True
batch_first bool True则输入为 (batch, seq, feature),默认False
dropout float 层间dropout比率(最后一层不加),默认0
bidirectional bool 是否双向GRU,默认False

输出说明

输出 形状 说明
output (batch, seq, hidden_size * num_directions) 所有时间步最后一层的 \(h_t\)
h_n (num_layers * num_directions, batch, hidden_size) 最后一个时间步各层的 \(h_t\)

基于 GRU 的文本分类模型

class GRUClassifier(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_size, num_classes,
                 num_layers=1, dropout=0.5):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.gru = nn.GRU(
            embed_dim, hidden_size, num_layers,
            batch_first=True, dropout=dropout if num_layers > 1 else 0,
            bidirectional=True
        )
        self.dropout = nn.Dropout(dropout)
        # 双向GRU,隐藏状态维度翻倍
        self.fc = nn.Linear(hidden_size * 2, num_classes)

    def forward(self, x):
        # x: (batch, seq_len) - token indices
        embeds = self.dropout(self.embedding(x))     # (batch, seq, embed_dim)
        output, h_n = self.gru(embeds)               # output: (batch, seq, hidden*2)

        # 取最后一个时间步的输出(双向拼接)
        # 或者用平均池化: out = output.mean(dim=1)
        out = output[:, -1, :]                        # (batch, hidden*2)
        out = self.dropout(out)
        logits = self.fc(out)                         # (batch, num_classes)
        return logits

# 使用示例
model = GRUClassifier(
    vocab_size=10000, embed_dim=128,
    hidden_size=256, num_classes=2
)
x = torch.randint(0, 10000, (32, 50))  # batch=32, seq_len=50
logits = model(x)                        # (32, 2)

实践提示

使用双向GRU时,h_n 的形状为 (num_layers * 2, batch, hidden_size)。如果要取最后一层双向的隐藏状态并拼接,需要 torch.cat([h_n[-2], h_n[-1]], dim=-1),而不是简单地取 h_n[-1]


思考与讨论

GRU 什么时候优于 LSTM?

  1. 小数据集:GRU参数少约25%,正则化效果天然更好,不易过拟合
  2. 计算资源受限:边缘设备、实时推理等场景,GRU的速度优势明显
  3. 序列不太长:当序列长度在几十到几百的范围内,GRU和LSTM差异极小,此时选择更简洁的GRU更合理
  4. 快速原型验证:GRU训练更快,适合快速迭代实验

为什么两者性能差不多?

核心原因在于两者解决梯度消失的机制本质相同——都是通过加性更新提供线性梯度路径。LSTM多出的输出门和独立细胞状态提供了更精细的控制能力,但在大多数实际任务中,这种额外的表达能力并不是性能的瓶颈。模型性能更多地受数据质量、超参数调优、正则化策略等因素影响。

GRU 在现代深度学习中的地位

随着Transformer架构的崛起,LSTM和GRU在NLP领域的使用都大幅减少。但GRU仍然活跃在以下场景:

  • 时间序列预测:金融数据、传感器数据等,序列长度适中,GRU性价比高
  • 语音处理:实时语音识别、语音合成中的流式处理
  • 强化学习:部分策略网络使用GRU处理部分可观测环境的历史信息
  • 边缘计算/嵌入式设备:模型体积和推理速度至关重要的场景
  • 作为大模型的组件:一些混合架构中,GRU被用作局部的序列建模模块

历史意义

GRU的最大贡献不在于性能的提升,而在于证明了一个重要观点:LSTM的复杂性并非必要。这启发了后续研究者不断探索更简洁的序列建模方案,最终促进了Transformer等新架构的诞生。简洁性本身就是一种价值。


总结

GRU的方程汇总(完整的前向传播,共4个公式):

\[ r_t = \sigma(W_r \cdot [h_{t-1}, x_t] + b_r) \]
\[ z_t = \sigma(W_z \cdot [h_{t-1}, x_t] + b_z) \]
\[ \tilde{h}_t = \tanh(W_h \cdot [r_t \odot h_{t-1}, x_t] + b_h) \]
\[ h_t = (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t \]

核心要点

  • GRU = LSTM 的简化版:2个门代替3个门,1个状态代替2个状态
  • 更新门 \(z_t\) 同时承担了LSTM遗忘门和输入门的角色(互补约束)
  • 重置门 \(r_t\) 控制候选状态对历史的依赖程度
  • 加性更新结构提供线性梯度路径,缓解梯度消失
  • 参数量约为LSTM的75%,训练更快,性能通常持平

下一步:Seq2Seq(编码器-解码器架构,GRU/LSTM均可作为基础组件)


评论 #