Skip to content

RNN原理

思考:

  • 为什么 RNN 无法并行计算?
  • 对于一本书,RNN 真的能记住第一章写了什么吗?
  • CNN处理图像、RNN处理序列,它们的归纳偏置分别是什么?

从MLP/CNN到RNN

为什么需要新架构?

在CNN原理笔记中我们看到,CNN通过两个核心归纳偏置(局部性和平移不变性)解决了图像处理问题。但自然语言、语音、股票价格等数据是序列(Sequence)——元素之间有先后顺序,且长度不固定。

MLP和CNN处理序列时有本质缺陷:

架构 处理序列的问题
MLP 输入维度固定——"我喜欢猫"是3个词,"今天天气真好"是4个词,MLP无法处理不同长度的输入。强行padding浪费且丢失结构
CNN 卷积核只看固定大小的局部窗口(如3×3)。对于"虽然开头无聊,但结局精彩"这种长距离依赖("虽然...但..."跨越了很多词),CNN需要堆叠非常多层才能覆盖,效率极低

核心需求:我们需要一种架构,能够:

  1. 处理任意长度的输入序列
  2. 在处理当前元素时,能"记住"之前看过的内容
  3. 跨时间步共享参数(类似CNN跨空间共享卷积核)

RNN的核心归纳偏置

类比CNN笔记中的分析方法,RNN也有两个核心归纳偏置:

CNN RNN
局部性:像素与周围像素关系最密切 时间依赖性:当前时刻的输出主要取决于当前输入和历史信息
平移不变性:猫在左上角和右下角都是猫 时间平移不变性:处理序列的"规则"在每个时间步都相同(参数共享)

CNN的参数共享是空间上的(同一个卷积核在图像的不同位置使用),RNN的参数共享是时间上的(同一组权重在序列的不同时间步使用)。


数学原理

从MLP到RNN的数学推导

MLP处理序列的尝试

假设我们要处理一个长度为 \(T\) 的序列 \(x_1, x_2, \ldots, x_T\)。最朴素的想法:把所有时间步的输入拼接起来喂给MLP:

\[ y = f(W \cdot [x_1; x_2; \ldots; x_T] + b) \]

问题:\(W\) 的大小取决于 \(T\),不同长度的序列需要不同大小的网络。而且每个时间步的位置绑定了特定的权重,无法泛化。

引入"状态"的思想

如果我们不是一次性看完所有输入,而是逐步处理呢?每一步维护一个"状态"来总结历史信息:

\[ h_t = f(h_{t-1}, x_t) \]

这就是RNN的核心思想:\(h_t\)(隐藏状态)是对 \(x_1, x_2, \ldots, x_t\) 的压缩摘要。

参数共享

类比CNN笔记中的推导——CNN通过让权重不依赖于位置 \((i,j)\) 来实现空间参数共享。RNN通过让权重不依赖于时间步 \(t\) 来实现时间参数共享:

  • 不共享:\(h_t = f(W_t^{hh} \cdot h_{t-1} + W_t^{xh} \cdot x_t + b_t)\)(每个时间步有自己的权重)
  • 共享后:\(h_t = f(W^{hh} \cdot h_{t-1} + W^{xh} \cdot x_t + b)\)(所有时间步用同一组权重)

这使得RNN可以处理任意长度的序列,且参数量与序列长度无关。

Vanilla RNN 的完整公式

隐藏状态更新:

\[ h_t = \tanh(W_{hh} \cdot h_{t-1} + W_{xh} \cdot x_t + b_h) \]

输出计算:

\[ y_t = W_{hy} \cdot h_t + b_y \]

各符号的含义:

符号 维度 含义
\(x_t\) \(\mathbb{R}^d\) \(t\) 步的输入向量(如词嵌入)
\(h_t\) \(\mathbb{R}^{d_h}\) \(t\) 步的隐藏状态(历史信息的摘要)
\(h_{t-1}\) \(\mathbb{R}^{d_h}\) 上一步的隐藏状态
\(W_{xh}\) \(\mathbb{R}^{d_h \times d}\) 输入→隐藏 的权重矩阵
\(W_{hh}\) \(\mathbb{R}^{d_h \times d_h}\) 隐藏→隐藏 的权重矩阵(循环连接)
\(W_{hy}\) \(\mathbb{R}^{d_o \times d_h}\) 隐藏→输出 的权重矩阵
\(b_h, b_y\) 偏置项
\(\tanh\) 激活函数(将值压缩到 \([-1, 1]\)

参数量分析:假设 \(d = 300\)(词嵌入维度),\(d_h = 256\)(隐藏层维度),\(d_o = 10000\)(词表大小):

\[ |W_{xh}| + |W_{hh}| + |W_{hy}| = 256 \times 300 + 256 \times 256 + 10000 \times 256 \approx 2.7M \]

注意:无论序列有多长(10个词还是10000个词),参数量都是 2.7M。这就是参数共享的力量。


架构详解

折叠视图 vs 展开视图

RNN有两种等价的表示方式——折叠视图和展开视图:

RNN折叠与展开视图

图中每个符号的含义

符号 是什么 类比
\(x\)(绿色圆) 输入向量,如一个词的词嵌入 原材料
\(h\)(蓝色方块) 隐藏状态向量,如 \([0.48, -0.36, 0.72]\),是一组数字 加工后的半成品(既可以继续加工,也可以直接出货)
\(o\)(粉色圆) 输出向量,经过 \(W\) 变换后的最终结果 成品
\(U\) 输入→隐藏 的权重矩阵(对应前面公式中的 \(W_{xh}\) 处理原材料的工具
\(V\) 隐藏→隐藏 的权重矩阵(对应前面公式中的 \(W_{hh}\) 利用上次经验的工具
\(W\) 隐藏→输出 的权重矩阵(对应前面公式中的 \(W_{hy}\) 出货检验的工具

符号对应关系

不同教材用不同字母。本图用 \(U, V, W\),前面公式用 \(W_{xh}, W_{hh}, W_{hy}\),Colah's Blog 用 \(A\) 表示整个计算单元。指的是同一个东西,不要被字母搞混。

折叠视图(左半边)

左边是一个带自环的结构:

  • \(x\) 通过权重 \(U\) 进入 \(h\)
  • \(h\) 通过权重 \(W\) 输出 \(o\)
  • \(h\) 同时通过权重 \(V\) 回到自己(自环箭头)

这个自环就是"循环"神经网络名字的由来——\(h\) 的输出会被送回自身,作为下一个时间步的输入。

展开视图(右半边)——核心重点

把自环沿时间轴展开后,可以清晰看到信息如何一步步流动:

逐步阅读:

  1. 时刻 \(t-1\):输入 \(x_{t-1}\) 通过 \(U\) 进入,与前一步传来的隐藏状态结合 → 产出 \(h_{t-1}\)
  2. \(h_{t-1}\) 分两路: - 向上:通过 \(W\) 输出 \(o_{t-1}\)(该时刻的预测结果) - 向右:通过 \(V\) 传递给下一个时刻(这就是"记忆"的传递)
  3. 时刻 \(t\)\(h_{t-1}\) 从左边来(通过 \(V\)),\(x_t\) 从下面来(通过 \(U\)),两者共同决定 \(h_t\)
  4. \(h_t\) 再分两路:向上输出 \(o_t\),向右传给 \(h_{t+1}\)
  5. 时刻 \(t+1\):同理,以此类推...

对应的数学公式

图中每个蓝色方块 \(h\) 内部做的计算:

\[ h_t = \tanh(\underbrace{V \cdot h_{t-1}}_{\text{左边水平箭头传来的}} + \underbrace{U \cdot x_t}_{\text{下方箭头传来的}} + b) \]

图中每个粉色圆 \(o\) 的计算:

\[ o_t = \underbrace{W \cdot h_t}_{\text{向上箭头}} \]

关键理解\(h_t\) 不是神经元,而是一个向量(一组数字)。蓝色方块代表的是一次计算过程(矩阵乘法 + tanh),\(h_t\) 是这次计算的输出结果。\(h_t\) 同时扮演两个角色:(1) 向上经过 \(W\) 变成输出 \(o_t\);(2) 向右经过 \(V\) 成为下一步的输入。

所有时间步共享同一组参数 \(U, V, W\)——这就是RNN的参数共享,类似CNN的卷积核在不同空间位置共享。

展开视图是理解RNN的关键:展开后的RNN看起来就像一个非常深的前馈网络,只不过每一"层"共享相同的权重。这也是为什么RNN会遇到与深层网络类似的梯度问题。

前向传播的完整过程

RNN的前向传播就是普通的矩阵乘法,与MLP的全连接层没有本质区别。以处理序列 "我 喜欢 猫" 为例(假设 \(d=4, d_h=3\)):

初始化\(h_0 = [0, 0, 0]\)(零向量)

第1步:处理 "我"

(a) 拼接(或分别乘再相加,数学等价):

\[ z_1 = W_{hh} \cdot h_0 + W_{xh} \cdot x_1 + b_h \]

展开矩阵乘法:

\[ z_1 = \underbrace{\begin{bmatrix} w_{11} & w_{12} & w_{13} \\ w_{21} & w_{22} & w_{23} \\ w_{31} & w_{32} & w_{33} \end{bmatrix}}_{W_{hh} \in \mathbb{R}^{3 \times 3}} \cdot \underbrace{\begin{bmatrix} 0 \\ 0 \\ 0 \end{bmatrix}}_{h_0} + \underbrace{\begin{bmatrix} u_{11} & u_{12} & u_{13} & u_{14} \\ u_{21} & u_{22} & u_{23} & u_{24} \\ u_{31} & u_{32} & u_{33} & u_{34} \end{bmatrix}}_{W_{xh} \in \mathbb{R}^{3 \times 4}} \cdot \underbrace{\begin{bmatrix} 0.21 \\ -0.45 \\ 0.73 \\ 0.12 \end{bmatrix}}_{x_1} + b_h \]

因为 \(h_0 = \vec{0}\),第一项为零,所以:

\[ z_1 = W_{xh} \cdot x_1 + b_h = [0.52,\ -0.38,\ 0.91] \]

(b) 激活:

\[ h_1 = \tanh(z_1) = [\tanh(0.52),\ \tanh(-0.38),\ \tanh(0.91)] = [0.48,\ -0.36,\ 0.72] \]

这就是一步RNN前向传播的全部内容。 就是一次矩阵乘法 + tanh。

第2步:处理 "喜欢"

现在 \(h_1 \neq \vec{0}\)\(W_{hh}\) 开始起作用了:

\[ z_2 = W_{hh} \cdot \underbrace{[0.48, -0.36, 0.72]}_{h_1} + W_{xh} \cdot \underbrace{[0.65, 0.33, -0.18, 0.51]}_{x_2} + b_h \]

两个矩阵乘法的结果相加:\(W_{hh} \cdot h_1\) 带来了"我"的记忆,\(W_{xh} \cdot x_2\) 带来了"喜欢"的新信息。

\[ z_2 = [0.35, 0.62, -0.15] + [0.71, -0.22, 0.44] + b_h = [1.08, 0.42, 0.31] \]
\[ h_2 = \tanh([1.08, 0.42, 0.31]) = [0.79, 0.40, 0.30] \]

第3步:处理 "猫"——完全相同的过程,最终得到 \(h_3\)

总结:RNN的前向传播 = 在每个时间步重复执行同一个全连接层(两次矩阵乘法 + 一次tanh),上一步的输出作为下一步的额外输入。没有任何超出基础前向传播的操作。

隐藏状态的存储机制(训练 vs 推理)

训练时:每一步算出的 \(h_0, h_1, h_2, h_3\) 全部保存在内存中,不会被覆盖。因为反向传播需要用到所有中间值来计算梯度。所以 RNN 训练的内存开销与序列长度成正比——序列越长,要存的 \(h\) 越多,显存占用越大。

推理时:不需要反向传播,所以只保留最新的 \(h_t\) 即可,之前的 \(h\) 可以丢弃。这就是为什么 RNN 推理的内存开销是常数级的。

计算图与"为什么无法并行"

从展开图可以清晰看出:

\[ h_1 = f(h_0, x_1) \quad \rightarrow \quad h_2 = f(h_1, x_2) \quad \rightarrow \quad h_3 = f(h_2, x_3) \]

\(h_2\) 的计算必须等待 \(h_1\) 完成,\(h_3\) 必须等待 \(h_2\) 完成。这是一个严格的顺序依赖链——无法并行化。

相比之下,CNN的不同位置的卷积可以同时计算(没有依赖关系),Transformer的Self-Attention也可以一次性处理所有位置。这就是RNN训练速度慢的根本原因

架构 计算 \(n\) 个位置的输出 顺序操作数
RNN 必须一步步来 \(O(n)\)
CNN 所有位置并行 \(O(1)\)(单层)
Transformer 所有位置并行 \(O(1)\)

反向传播:BPTT

Backpropagation Through Time

RNN的训练使用沿时间的反向传播(BPTT, Backpropagation Through Time)。核心思想:把展开后的RNN当成一个普通的深层前馈网络,正常做反向传播。

为什么梯度能从 \(h_t\) 一路回传到 \(h_0\)

前向传播时,所有中间隐藏状态 \(h_0, h_1, \ldots, h_t\)保存在内存中(见前向传播章节的说明)。反向传播时,拿着这些存好的值,从后往前逐步计算梯度。\(h\) 并没有被"覆盖"——每个时刻的 \(h_t\) 都是独立存储的变量,不是同一个变量被反复赋值。

假设损失函数是每个时间步损失的总和:

\[ \mathcal{L} = \sum_{t=1}^{T} \mathcal{L}_t(y_t, \hat{y}_t) \]

我们需要计算 \(\frac{\partial \mathcal{L}}{\partial W_{hh}}\)。由于 \(W_{hh}\) 在每个时间步都被使用,梯度需要在所有时间步上累加:

\[ \frac{\partial \mathcal{L}}{\partial W_{hh}} = \sum_{t=1}^{T} \frac{\partial \mathcal{L}_t}{\partial W_{hh}} \]

对于某个时间步 \(t\) 的损失 \(\mathcal{L}_t\),它对 \(W_{hh}\) 的梯度需要通过 \(h_t, h_{t-1}, \ldots, h_1\) 一路回溯:

\[ \frac{\partial \mathcal{L}_t}{\partial W_{hh}} = \sum_{k=1}^{t} \frac{\partial \mathcal{L}_t}{\partial h_t} \cdot \frac{\partial h_t}{\partial h_k} \cdot \frac{\partial h_k}{\partial W_{hh}} \]

其中从 \(h_t\) 回溯到 \(h_k\) 的梯度是一个连乘:

\[ \frac{\partial h_t}{\partial h_k} = \prod_{i=k+1}^{t} \frac{\partial h_i}{\partial h_{i-1}} = \prod_{i=k+1}^{t} W_{hh}^T \cdot \text{diag}(\tanh'(z_i)) \]

梯度消失与梯度爆炸

上面那个连乘就是问题的根源。\(\tanh'(x)\) 的值域是 \((0, 1]\),所以每一项都小于等于1。

梯度消失(Vanishing Gradient):当 \(W_{hh}\) 的最大奇异值 < 1时:

\[ \left\| \prod_{i=k+1}^{t} W_{hh}^T \cdot \text{diag}(\tanh'(z_i)) \right\| \leq \prod_{i=k+1}^{t} \|W_{hh}\| \cdot \|\tanh'(z_i)\| \to 0 \]

连乘 \(t - k\) 项,每项都小于1 → 指数级衰减 → 早期时间步的梯度趋近于零

下图直观展示了 BPTT 的过程——黑色箭头是前向传播,红色箭头是梯度反向传播,可以看到梯度需要沿时间轴一路往回走:

BPTT反向传播过程

图中的 \(\theta_h\)\(\theta_x\)\(\theta_{\hat{y}}\) 分别对应三个权重矩阵 \(V\)(隐藏→隐藏)、\(U\)(输入→隐藏)、\(W\)(隐藏→输出)。可以看到,\(\theta_h\) 的梯度需要从 \(h_t\) 一路连乘回 \(h_{t-2}\)\(h_{t-3}\)...,序列越长,连乘次数越多,梯度越容易消失或爆炸。

直觉理解:训练模型处理 "虽然 开头 无聊 但 结局 精彩" 时,\(\mathcal{L}\) 来自最后一步"精彩"的预测。梯度要从"精彩"一路回传到"虽然",经过5次连乘后几乎为零。模型学不到"虽然...但..."的转折关系。

梯度爆炸(Exploding Gradient):当 \(W_{hh}\) 的最大奇异值 > 1时,连乘指数级增长 → 梯度变成天文数字 → 参数更新过大 → 训练崩溃。

梯度爆炸的解决方案:梯度裁剪(Gradient Clipping)

\[ \mathbf{g} \leftarrow \frac{\text{threshold}}{\|\mathbf{g}\|} \cdot \mathbf{g} \quad \text{if } \|\mathbf{g}\| > \text{threshold} \]

当梯度范数超过阈值时,等比例缩小。这是一个简单有效的工程方案。

梯度消失的解决方案:没有简单的工程trick可以解决——需要改变架构。这就催生了LSTM和GRU。


RNN的变体

不同的输入输出结构

RNN的灵活性在于可以适配多种输入输出模式。下图展示了五种经典结构(图源:Andrej Karpathy,红色=输入,绿色=RNN,蓝色=输出):

RNN五种输入输出结构

(1) One-to-One        (2) One-to-Many       (3) Many-to-One
    (普通网络)           (图像描述)             (情感分析)

     x                    x                  x₁ x₂ x₃ x₄
     ↓                    ↓                   ↓  ↓  ↓  ↓
   [NET]               [RNN]→[RNN]→[RNN]   [RNN]→[RNN]→[RNN]→[RNN]
     ↓                    ↓     ↓     ↓                       ↓
     y                   y₁    y₂    y₃                       y


(4) Many-to-Many      (5) Many-to-Many
    (同步,如NER)          (异步,如翻译)

 x₁ x₂ x₃ x₄         x₁ x₂ x₃    y₁ y₂ y₃ y₄
  ↓  ↓  ↓  ↓           ↓  ↓  ↓      ↓  ↓  ↓  ↓
[RNN]→[RNN]→[RNN]→[RNN] [Encoder]→c→[Decoder]
  ↓    ↓    ↓    ↓                    ↓  ↓  ↓  ↓
 y₁   y₂   y₃   y₄                  y₁ y₂ y₃ y₄

第(5)种就是Seq2Seq架构(详见Seq2Seq笔记)。

双向RNN(Bidirectional RNN)

单向RNN只能看到过去的信息。但很多任务需要同时看前后文:

  • NER:"苹果 公司" — 看到"公司"才知道"苹果"是组织名
  • 完形填空:"我 ___ 足球" — 需要同时看前后

BiRNN用两个独立的RNN分别从左到右和从右到左处理序列:

前向:  x₁ ──→ h⃗₁ ──→ h⃗₂ ──→ h⃗₃ ──→ h⃗₄
后向:  x₁ ←── h⃖₁ ←── h⃖₂ ←── h⃖₃ ←── h⃖₄

最终:  h₁ = [h⃗₁; h⃖₁]    h₂ = [h⃗₂; h⃖₂]    ...
       (拼接前向和后向的隐藏状态)
\[ h_t = [\overrightarrow{h_t};\ \overleftarrow{h_t}] \in \mathbb{R}^{2d_h} \]

两个方向的RNN参数独立,不共享。最终每个位置的表示包含了完整的上下文信息。

注意:BiRNN不能用于自回归生成(因为生成第 \(t\) 个词时还没有第 \(t+1\) 个词),主要用于理解任务(分类、标注、编码器等)。

深层RNN(Stacked RNN)

像CNN一样,RNN也可以堆叠多层来增加深度:

层3:  h₁⁽³⁾ ──→ h₂⁽³⁾ ──→ h₃⁽³⁾ ──→ y
       ↑           ↑           ↑
层2:  h₁⁽²⁾ ──→ h₂⁽²⁾ ──→ h₃⁽²⁾
       ↑           ↑           ↑
层1:  h₁⁽¹⁾ ──→ h₂⁽¹⁾ ──→ h₃⁽¹⁾
       ↑           ↑           ↑
      x₁          x₂          x₃

每一层的输入是下一层的隐藏状态。实践中RNN通常只堆2~4层(不像CNN可以堆100+层),因为RNN本身在时间维度上已经很"深"了,再加上层数上的深度,梯度消失会更严重。


RNN的关键局限性

局限 原因 后续解决方案
梯度消失 BPTT中的连乘 → 长距离依赖学不到 LSTM / GRU(门控机制)
无法并行 \(h_t\) 必须等 \(h_{t-1}\) Transformer(Self-Attention)
记忆有限 隐藏状态是固定大小的向量,序列越长信息损失越大 Attention机制
单向信息流 标准RNN只看过去 BiRNN

RNN真的能记住第一章吗? 不能。假设隐藏状态 \(d_h = 256\),那么整本书(可能几十万个词)的信息被压缩在256个浮点数里。理论上信息容量远远不够。而且由于梯度消失,训练时也很难让模型学会保留早期信息。这就是为什么需要LSTM(通过门控选择性记忆)和Attention(直接访问任意历史位置)。


与CNN的对比总结

CNN RNN
适用数据 网格结构(图像) 序列结构(文本、时间序列)
核心假设 局部性 + 空间平移不变性 时间依赖性 + 时间平移不变性
参数共享方式 卷积核在空间上共享 权重在时间步上共享
信息流 局部 → 全局(逐层扩大感受野) 顺序流动(逐步积累历史)
最大路径长度 \(O(\log n)\)(通过堆叠层) \(O(n)\)(必须逐步传递)
并行性 高(不同位置并行) 低(时间步串行)
深度 可以很深(100+层,ResNet) 通常浅(2~4层)

下一步:→ LSTM(通过门控机制解决梯度消失)→ Seq2Seq(编码器-解码器架构)


评论 #