跳转至

状态空间模型(State Space Models)

概述

状态空间模型(SSM)源自控制理论,将连续时间动态系统离散化后用于序列建模。从S4到Mamba,SSM提供了一种线性复杂度的Transformer替代方案,在长序列建模上表现出色。

graph TD
    A[状态空间模型] --> B[连续SSM]
    A --> C[离散化]
    A --> D[结构化SSM]
    A --> E[选择性SSM]

    B --> B1["dx/dt = Ax + Bu"]
    C --> C1[ZOH / 双线性]
    D --> D1[S4 HiPPO]
    D --> D2[S4D / S5]
    E --> E1[Mamba]
    E --> E2[Mamba-2 SSD]

    style E1 fill:#f96,stroke:#333

1. 连续状态空间模型

1.1 基本定义

连续时间SSM定义为:

\[ \dot{x}(t) = Ax(t) + Bu(t) \]
\[ y(t) = Cx(t) + Du(t) \]

其中:

  • \(u(t) \in \mathbb{R}\):输入信号
  • \(x(t) \in \mathbb{R}^N\):隐状态(\(N\) 维)
  • \(y(t) \in \mathbb{R}\):输出信号
  • \(A \in \mathbb{R}^{N \times N}\):状态矩阵(系统动力学)
  • \(B \in \mathbb{R}^{N \times 1}\):输入矩阵
  • \(C \in \mathbb{R}^{1 \times N}\):输出矩阵
  • \(D \in \mathbb{R}\):直通项(通常省略或设为0)

1.2 连续到离散

对于离散序列 \((u_0, u_1, \ldots, u_L)\),需要将连续SSM离散化。

零阶保持(ZOH)

\[ \bar{A} = \exp(\Delta A) \]
\[ \bar{B} = (\Delta A)^{-1}(\exp(\Delta A) - I) \cdot \Delta B \]

双线性变换

\[ \bar{A} = (I - \Delta A / 2)^{-1}(I + \Delta A / 2) \]
\[ \bar{B} = (I - \Delta A / 2)^{-1} \Delta B \]

其中 \(\Delta\) 是离散化步长。

1.3 离散SSM的递推

\[ x_k = \bar{A} x_{k-1} + \bar{B} u_k \]
\[ y_k = C x_k \]

这是一个线性递推,可以高效计算。

1.4 卷积视角

展开递推:

\[ y_k = C \bar{A}^k \bar{B} u_0 + C \bar{A}^{k-1} \bar{B} u_1 + \cdots + C \bar{B} u_k \]

定义卷积核 \(\bar{K} = (C\bar{B}, C\bar{A}\bar{B}, \ldots, C\bar{A}^{L-1}\bar{B})\)

\[ y = \bar{K} * u \]

双重计算模式

模式 复杂度 适用场景
递推模式 \(O(L)\) 时间,\(O(N)\) 内存 自回归推理
卷积模式 \(O(L \log L)\) 时间 并行训练

2. S4:结构化状态空间

2.1 HiPPO矩阵

S4(Structured State Spaces for Sequence Modeling, Gu et al., 2022)的核心创新是使用HiPPO(High-order Polynomial Projection Operators)初始化 \(A\) 矩阵。

HiPPO-LegS矩阵

\[ A_{nk} = -\begin{cases} (2n+1)^{1/2}(2k+1)^{1/2} & \text{if } n > k \\ n+1 & \text{if } n = k \\ 0 & \text{if } n < k \end{cases} \]

直觉:HiPPO矩阵使状态 \(x(t)\) 最优地压缩历史输入信号,相当于用正交多项式基函数逼近输入的滑动窗口。

2.2 对角化加速

直接计算 \(\bar{A}^k\) 复杂度高。S4利用 \(A\) 的结构性质:

  1. S4D:将 \(A\) 对角化,\(A = V \Lambda V^{-1}\)
  2. 对角SSM\(\bar{A}_{\text{diag}} = \exp(\Delta \Lambda)\)
  3. 复杂度从 \(O(N^2)\) 降到 \(O(N)\)

2.3 S4的长程建模能力

Long Range Arena基准上,S4大幅超越Transformer:

任务 Transformer S4
ListOps 36.37 58.35
Text 64.27 86.82
Retrieval 57.46 87.09
Image 42.44 88.65
PathFinder 71.40 94.20
Path-X (16K) FAIL 96.35
Average 53.66 86.09

3. Mamba:选择性状态空间

3.1 动机

标准SSM的参数 \((\bar{A}, \bar{B}, C)\) 与输入无关(线性时不变,LTI):

  • 优点:可以用卷积高效训练
  • 缺点:无法进行内容感知推理(如根据输入选择性记忆)

3.2 选择性SSM

Mamba(Gu & Dao, 2023)让 \(B, C, \Delta\) 依赖于输入:

\[ B_t = \text{Linear}_B(x_t), \quad C_t = \text{Linear}_C(x_t), \quad \Delta_t = \text{softplus}(\text{Linear}_\Delta(x_t)) \]

关键

  • \(\Delta_t\) 控制"遗忘门":\(\Delta\) 大 → 更多关注当前输入,\(\Delta\) 小 → 更多保留历史
  • \(B_t\) 控制"输入门":选择性地将信息写入状态
  • \(C_t\) 控制"输出门":选择性地从状态读取信息

与门控RNN的类比

\[ x_t = \bar{A}_t x_{t-1} + \bar{B}_t u_t \quad \longleftrightarrow \quad h_t = (1-z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t \]

3.3 硬件感知扫描算法

选择性SSM的参数随时间变化,无法使用卷积。Mamba使用硬件感知的并行扫描算法:

Parallel Scan(前缀和)

递推 \(x_k = a_k x_{k-1} + b_k\) 可以并行化为:

\[ (a_k, b_k) \bullet (a_{k-1}, b_{k-1}) = (a_k a_{k-1}, a_k b_{k-1} + b_k) \]

这个二元运算满足结合律,可以用并行前缀和在 \(O(\log L)\) 深度内完成。

GPU优化

  • 避免在HBM中存储完整的 \((\bar{A}, \bar{B})\) 序列
  • 在SRAM中进行离散化和扫描
  • 类似FlashAttention的核融合策略

3.4 Mamba架构

Mamba Block(替代Transformer Block):

输入 x
├── 线性投影 → 扩展维度
├── 1D深度卷积(kernel=4)
├── SiLU激活
├── 选择性SSM
├── SiLU门控(另一分支)
└── 线性投影 → 输出维度

关键设计

  • 无注意力机制,无MLP块
  • 1D因果卷积提供局部上下文
  • 展开因子(expansion factor)\(E=2\)

3.5 Mamba vs Transformer

特性 Transformer Mamba
训练复杂度 \(O(L^2 d)\) \(O(L d N)\)
推理复杂度 \(O(L)\) per token \(O(1)\) per token
KV缓存 随序列线性增长 固定大小 \(O(N d)\)
长序列能力 受限于 \(L^2\) 线性扩展
上下文学习 较弱
检索能力 较弱

4. Mamba-2:结构化状态空间对偶性

4.1 SSD框架

Mamba-2(Dao & Gu, 2024)揭示了SSM和注意力的对偶性

状态空间对偶性(SSD)

\[ y_t = \sum_{s=1}^{t} C_t^\top A_{t:s} B_s u_s \]

\(A\) 是标量时,可以写成矩阵形式:

\[ Y = (L \odot QK^\top) V \]

其中 \(L\) 是由 \(A\) 的累积积构成的下三角掩码矩阵。

这就是一种结构化的因果注意力!

4.2 分块算法

Mamba-2使用分块(chunking)策略:

  • 将序列分成大小为 \(c\) 的块
  • 块内:使用注意力形式(二次复杂度,但块小)
  • 块间:使用SSM递推(线性复杂度)
\[ \text{总复杂度} = O(L c + L \cdot N) \quad \text{其中} c \text{是块大小} \]

4.3 性能提升

  • 比Mamba-1快2-8倍
  • 支持更大的状态维度
  • 更好的硬件利用率(tensor core友好)

5. SSM的应用

5.1 语言建模

模型 架构 参数 特点
Mamba-1 纯SSM 130M-2.8B 匹配同规模Transformer
Mamba-2 SSD - 更快更好
Jamba Attention+Mamba+MoE 52B 混合架构
Zamba Attention+Mamba 7B 共享注意力层

5.2 视觉

  • Vision Mamba(Vim):Mamba用于图像分类
  • VMamba:2D选择性扫描
  • PlainMamba:简化的视觉Mamba

5.3 其他领域

  • 音频/语音处理
  • 基因组序列建模
  • 时间序列预测
  • 点云处理

6. SSM生态总结

graph LR
    A[HiPPO 2020] --> B[S4 2022]
    B --> C[S4D 2022]
    B --> D[H3 2023]
    D --> E[Mamba 2023]
    E --> F[Mamba-2 2024]

    B --> G[S5 2023]

    E --> H[Jamba 2024]
    E --> I[Vision Mamba 2024]
方法 年份 关键创新
HiPPO 2020 最优状态压缩
S4 2022 结构化SSM + 卷积模式
S4D 2022 对角化简化
H3 2023 SSM + 门控注意力
Mamba 2023 选择性SSM + 硬件感知扫描
Mamba-2 2024 SSD对偶性 + 分块算法

参考文献

  • Gu et al., "Efficiently Modeling Long Sequences with Structured State Spaces," ICLR 2022
  • Gu et al., "On the Parameterization and Initialization of Diagonal State Space Models," NeurIPS 2022
  • Gu & Dao, "Mamba: Linear-Time Sequence Modeling with Selective State Spaces," 2023
  • Dao & Gu, "Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality," ICML 2024

评论 #