Skip to content

LSTM

思考:

  • LSTM 的"细胞状态"和"隐藏状态"分别扮演什么角色?为什么需要两条信息通路?
  • 三个门(遗忘门、输入门、输出门)各自解决什么问题?如果去掉其中一个会怎样?
  • LSTM 为什么能缓解梯度消失?它能完全避免梯度爆炸吗?

背景与动机

RNN 的长期依赖问题

在 RNN 原理笔记中我们看到,Vanilla RNN 通过隐藏状态 \(h_t\) 在时间步之间传递信息。但这条唯一的信息通路存在严重缺陷:

\[ h_t = \tanh(W_{hh} h_{t-1} + W_{xh} x_t + b) \]

每经过一个时间步,信息都要通过 \(\tanh\)(值域压缩到 \([-1,1]\))和矩阵乘法 \(W_{hh}\)。在 BPTT 反向传播时,梯度需要经过连乘:

\[ \frac{\partial L}{\partial h_k} = \frac{\partial L}{\partial h_T} \prod_{t=k+1}^{T} W_{hh}^T \cdot \text{diag}(\tanh'(h_t)) \]

由于 \(\tanh' \in (0, 1]\),这个连乘项会随着时间步 \(T - k\) 的增大而指数衰减,导致早期时间步的梯度趋近于零——这就是梯度消失问题。

直觉理解:想象你在传话游戏中,每传一个人信息都会被"压缩"一次。传了50个人之后,原始信息几乎完全丢失。Vanilla RNN 就是这样——它无法记住几十步之前的信息。

LSTM 的核心思路

Long Short-Term Memory(长短期记忆网络)由 Hochreiter & Schmidhuber 于 1997 年提出,核心思想是:

核心设计思想

在隐藏状态 \(h_t\) 之外,增加一条独立的细胞状态(Cell State)\(C_t\) 作为"信息高速公路"。信息在这条高速公路上只经过逐元素乘法和加法(线性运算),不经过挤压性的激活函数,因此梯度可以几乎无损地长距离传播。

同时,引入三个门控机制来精细控制信息的流动:

  • 遗忘门(Forget Gate):决定丢弃哪些旧信息
  • 输入门(Input Gate):决定写入哪些新信息
  • 输出门(Output Gate):决定输出哪些信息

这三个门都是通过 sigmoid 函数输出 0~1 之间的值,实现"软开关"控制。


LSTM 架构详解

完整架构总览

下图展示了一个 LSTM Cell 的完整内部结构(图源:GeeksForGeeks):

LSTM Cell 完整架构

图中符号说明

符号 含义
Sig Sigmoid 激活函数,输出 \(\in (0, 1)\)
tanh Tanh 激活函数,输出 \(\in (-1, 1)\)
\(\otimes\) 逐元素乘法(Hadamard Product)
\(\oplus\) 逐元素加法
直线箭头 向量连接与数据流向

LSTM Cell 的输入与输出

  • 输入:当前时间步的输入 \(x_t\)、上一步的隐藏状态 \(h_{t-1}\)、上一步的细胞状态 \(C_{t-1}\)
  • 输出:当前的隐藏状态 \(h_t\)、当前的细胞状态 \(C_t\)

注意 LSTM 有两条信息通路

  1. 细胞状态 \(C_t\)(图顶部水平线):长期记忆通道,信息通过线性运算传递,梯度可以长距离传播
  2. 隐藏状态 \(h_t\)(图底部水平线):短期记忆通道,暴露给外部的输出

第一步:遗忘门(Forget Gate)——"丢弃什么旧记忆?"

遗忘门

遗忘门决定从细胞状态中丢弃多少旧信息。它查看上一步的隐藏状态 \(h_{t-1}\) 和当前输入 \(x_t\),输出一个 0~1 之间的向量:

\[ f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f) \]

其中 \([h_{t-1}, x_t]\) 表示将两个向量拼接(concatenate)。

\(f_t\) 的每个元素控制 \(C_{t-1}\) 对应维度的保留比例:

  • \(f_t[i] \approx 1\)完全保留\(i\) 维的记忆("记住")
  • \(f_t[i] \approx 0\)完全遗忘\(i\) 维的记忆("忘掉")

例子:在语言模型中,当遇到新的主语"她"时,遗忘门可能会决定忘掉之前主语"他"对应的性别信息。


第二步:输入门(Input Gate)——"写入什么新记忆?"

输入门

输入门决定向细胞状态中写入什么新信息。这分两个子步骤:

子步骤 A:决定"写入多少"

\[ i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i) \]

Sigmoid 输出 0~1 的值,控制每个维度的写入强度。

子步骤 B:生成"要写入的内容"

\[ \tilde{C}_t = \tanh(W_C \cdot [h_{t-1}, x_t] + b_C) \]

Tanh 输出 \(-1\)\(1\) 的候选值向量,代表可能写入的新信息。

最终写入的新信息为 \(i_t \odot \tilde{C}_t\)(逐元素相乘),即用 \(i_t\) 过滤 \(\tilde{C}_t\)


第三步:更新细胞状态——核心等式

有了遗忘门和输入门的结果,现在更新细胞状态:

\[ \boxed{C_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_t} \]

这是 LSTM 最关键的等式。它的含义非常直观:

\[ \text{新记忆} = \text{遗忘系数} \times \text{旧记忆} + \text{写入系数} \times \text{候选新内容} \]

注意这里只涉及逐元素乘法和加法——没有矩阵乘法,没有激活函数的挤压。这条"高速公路"让信息和梯度可以长距离传输,这是 LSTM 解决梯度消失的关键。


第四步:输出门(Output Gate)——"输出什么?"

输出门

输出门决定从当前细胞状态中输出什么信息作为隐藏状态:

\[ o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o) \]
\[ h_t = o_t \odot \tanh(C_t) \]

细胞状态先经过 \(\tanh\) 映射到 \([-1, 1]\),再由输出门过滤。\(h_t\) 就是暴露给外部的隐藏状态,用于:

  1. 传递给下一个时间步的 LSTM Cell
  2. 作为当前时间步的输出(如果是 many-to-many 任务)
  3. 作为最终输出(如果是 many-to-one 任务的最后一步)

公式汇总

步骤 公式 激活函数 作用
遗忘门 \(f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f)\) Sigmoid 控制旧记忆的保留比例
输入门 \(i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i)\) Sigmoid 控制新信息的写入强度
候选值 \(\tilde{C}_t = \tanh(W_C \cdot [h_{t-1}, x_t] + b_C)\) Tanh 生成候选新信息
细胞更新 \(C_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_t\) 融合旧记忆与新信息
输出门 \(o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o)\) Sigmoid 控制输出哪些信息
隐藏状态 \(h_t = o_t \odot \tanh(C_t)\) Tanh 生成当前时间步的输出

参数量:四组权重矩阵 \(W_f, W_i, W_C, W_o \in \mathbb{R}^{d_h \times (d_h + d)}\) 和四组偏置。总参数量约为 \(4 \times d_h \times (d_h + d)\),是 Vanilla RNN 的 4 倍。


前向传播:完整数值示例

设定:\(d = 4\)(输入维度),\(d_h = 3\)(隐藏状态维度)。

处理序列 "我 喜欢 猫" 的第一个词 "我":

输入\(x_1 = [0.21, -0.45, 0.73, 0.12]\)\(h_0 = [0, 0, 0]\)\(C_0 = [0, 0, 0]\)

Step 1:拼接

\[ [h_0; x_1] = [0, 0, 0, 0.21, -0.45, 0.73, 0.12] \in \mathbb{R}^{7} \]

Step 2:四次矩阵乘法(四个门各一次)

每个门的计算都是:7维输入 \(\rightarrow\) \(W \in \mathbb{R}^{3 \times 7}\) 矩阵乘法 \(\rightarrow\) 加偏置 \(\rightarrow\) 激活函数 \(\rightarrow\) 3维输出。

\[ f_1 = \sigma(W_f \cdot [h_0; x_1] + b_f) = [0.82,\ 0.15,\ 0.91] \]
\[ i_1 = \sigma(W_i \cdot [h_0; x_1] + b_i) = [0.31,\ 0.72,\ 0.08] \]
\[ \tilde{C}_1 = \tanh(W_C \cdot [h_0; x_1] + b_C) = [0.45,\ -0.38,\ 0.79] \]
\[ o_1 = \sigma(W_o \cdot [h_0; x_1] + b_o) = [0.62,\ 0.41,\ 0.73] \]

Step 3:更新细胞状态

\[ C_1 = f_1 \odot C_0 + i_1 \odot \tilde{C}_1 \]
\[ = [0.82, 0.15, 0.91] \odot [0, 0, 0] + [0.31, 0.72, 0.08] \odot [0.45, -0.38, 0.79] \]
\[ = [0, 0, 0] + [0.14, -0.27, 0.06] = [0.14, -0.27, 0.06] \]

Step 4:计算隐藏状态

\[ h_1 = o_1 \odot \tanh(C_1) = [0.62, 0.41, 0.73] \odot \tanh([0.14, -0.27, 0.06]) \]
\[ = [0.62, 0.41, 0.73] \odot [0.14, -0.26, 0.06] = [0.09, -0.11, 0.04] \]

然后处理 "喜欢" 时,用 \(h_1\)\(C_1\) 作为输入,完全相同的流程再走一遍。


为什么 LSTM 能解决梯度消失?

关键在于细胞状态更新的加法结构

\[ C_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_t \]

梯度沿细胞状态回传时:

\[ \frac{\partial C_t}{\partial C_{t-1}} = f_t \]

对比 Vanilla RNN

Vanilla RNN LSTM(沿细胞状态)
梯度传播 \(\frac{\partial h_t}{\partial h_{t-1}} = W_{hh}^T \cdot \text{diag}(\tanh')\) \(\frac{\partial C_t}{\partial C_{t-1}} = f_t\)
连乘 \(\prod W_{hh}^T \cdot \text{diag}(\tanh')\) \(\rightarrow\) 不可控 \(\prod f_t\) \(\rightarrow\) 可学习控制
问题 \(W_{hh}\) 的特征值决定梯度命运 \(f_t\) 是 sigmoid 输出,可以学到接近 1

当模型学到某条信息需要长期保留时,遗忘门会输出接近 1 的值(\(f_t \approx 1\)),梯度可以几乎无损地穿过任意长的时间步。这就像在高速公路上畅通无阻,而 Vanilla RNN 则像走一条每隔几米就有收费站的乡间小路。

LSTM 并不能完全避免梯度爆炸

LSTM 通过加法结构缓解了梯度消失,但梯度爆炸仍然可能发生(特别是当遗忘门接近 1 且多个时间步的梯度累加时)。实践中需要配合梯度裁剪(Gradient Clipping)来防止梯度爆炸。


与 GRU 的对比

GRU(Gated Recurrent Unit, Cho et al., 2014)是 LSTM 的简化版本,合并了遗忘门和输入门为一个更新门,并去掉了独立的细胞状态。详见 GRU 笔记。

LSTM GRU
门的数量 3个(遗忘、输入、输出) 2个(更新、重置)
状态 \(h_t\)\(C_t\) 两个 只有 \(h_t\) 一个
参数量 \(4 d_h(d_h + d)\) \(3 d_h(d_h + d)\)(少约25%)
性能 极长序列略优 大多数任务持平
训练速度 较慢 较快

实践中两者差异不大。如果数据集小或序列不太长,GRU 可能更合适;如果序列很长且计算资源充足,LSTM 通常是更安全的选择。


实际应用

LSTM 在 Transformer 出现之前(2017年以前),是深度学习处理序列数据的绝对主力。其典型应用包括:

  • 语言建模与机器翻译:GPT 系列的前身就是基于 LSTM 的语言模型
  • 语音识别:将语音信号序列转化为文字
  • 时间序列预测:股价预测、天气预报等,LSTM 擅长捕捉长周期模式(如季节性趋势)
  • 异常检测:在时间序列中识别异常模式
  • 视频分析:结合 CNN 提取帧特征,LSTM 建模时序关系

实践建议

超参数选择

参数 常见范围 说明
隐藏状态维度 \(d_h\) 128 ~ 512 任务越复杂越大,但太大容易过拟合
层数 1 ~ 3 超过3层收益递减,需配合残差连接
Dropout 0.2 ~ 0.5 作用在层间,不作用在时间步间
学习率 1e-3 ~ 1e-2 Adam 优化器通常用 1e-3
梯度裁剪 1.0 ~ 5.0 防止梯度爆炸

常见技巧

  1. 初始化遗忘门偏置为正值(如 \(b_f = 1\)):让模型一开始倾向于"记住"而不是"遗忘",有助于训练稳定性
  2. 使用双向 LSTM:除非是生成任务,否则 BiLSTM 几乎总是优于单向
  3. 最后一步 vs 平均池化:情感分析等任务中,对所有时间步的隐藏状态取平均(Mean Pooling)有时比只用最后一步效果更好
  4. 预训练词嵌入:用 Word2Vec 或 GloVe 初始化嵌入层,比随机初始化收敛更快

LSTM 的历史地位

时期 地位
1997-2013 提出后长期被忽视(算力不足、数据不够)
2013-2017 随着深度学习兴起成为序列建模的绝对主力,统治 NLP、语音、时间序列
2017-至今 被 Transformer 逐步取代(NLP 领域几乎完全替换),但在小数据、长时间序列、边缘设备上仍有应用

LSTM 最大的贡献不仅是一个具体架构,更是门控(Gating)思想的开创——这个思想后来影响了 GRU、Highway Network、残差连接,乃至 Transformer 中的各种门控变体。

下一步:\(\rightarrow\) Seq2Seq(用两个 LSTM 组成编码器-解码器,解决序列到序列问题)


评论 #