跳转至

长序列建模

概述

处理超长序列(数万到数百万token)是序列模型的核心挑战。本章介绍Transformer之外的长序列建模方法,包括xLSTM、RWKV、Hyena等,以及它们与Transformer和SSM的权衡。


1. xLSTM:扩展的LSTM

1.1 动机

经典LSTM的局限:存储容量固定、无法并行化、缩放困难。

xLSTM(Beck et al., 2024)通过两个改进突破这些限制。

1.2 sLSTM(标量LSTM)

引入指数门控:

\[ f_t = \exp(\tilde{f}_t), \quad i_t = \exp(\tilde{i}_t) \]

使用归一化器稳定计算:

\[ n_t = f_t n_{t-1} + i_t \]
\[ h_t = o_t \odot \frac{c_t}{n_t} \]

1.3 mLSTM(矩阵LSTM)

将标量cell state扩展为矩阵

\[ C_t = f_t C_{t-1} + i_t v_t k_t^\top \]
\[ h_t = o_t \odot \frac{C_t q_t}{\max(|n_t^\top q_t|, 1)} \]

其中 \(n_t = f_t n_{t-1} + i_t k_t\)

关键:矩阵cell state大幅提升存储容量,类似线性注意力。

1.4 xLSTM架构

交替使用sLSTM和mLSTM块:

  • sLSTM块:sLSTM + 门控残差 + LayerNorm
  • mLSTM块:mLSTM + 前馈 + LayerNorm
  • mLSTM可以完全并行化训练

2. RWKV:线性注意力RNN

2.1 核心思想

RWKV(Peng et al., 2023)结合了Transformer的训练并行性和RNN的推理效率。

2.2 WKV机制

RWKV的核心是WKV(Weighted Key-Value)操作:

\[ \text{wkv}_t = \frac{\sum_{i=1}^{t-1} e^{-(t-1-i)w + k_i} v_i + e^{u+k_t} v_t}{\sum_{i=1}^{t-1} e^{-(t-1-i)w + k_i} + e^{u+k_t}} \]

其中 \(w\) 是位置衰减权重,\(u\) 是当前token的bonus。

递推形式

\[ a_t = e^{-w} a_{t-1} + e^{k_t} v_t \]
\[ b_t = e^{-w} b_{t-1} + e^{k_t} \]
\[ \text{wkv}_t = \frac{a_{t-1} + e^{u+k_t} v_t}{b_{t-1} + e^{u+k_t}} \]

2.3 RWKV架构

Time-Mixing(类似注意力):

\[ r_t = W_r \cdot (\mu_r x_t + (1-\mu_r) x_{t-1}) \]

类似地定义 \(k_t, v_t\),然后:

\[ o_t = \sigma(r_t) \odot \text{wkv}_t \]

Channel-Mixing(类似FFN):

\[ o_t = \sigma(r_t) \odot (W_v \cdot \max(k_t, 0)^2) \]

2.4 RWKV版本演进

版本 关键改进
RWKV-4 基础WKV
RWKV-5 (Eagle) 矩阵值状态
RWKV-6 (Finch) 数据依赖的衰减
RWKV-7 进一步改进

3. Hyena:隐式长卷积

3.1 核心思想

Hyena(Poli et al., 2023)用隐式参数化的长卷积替代注意力。

3.2 Hyena算子

Hyena通过交替的逐元素乘法和长卷积实现高阶交互:

\[ y = x_1 \odot (h_1 * (x_2 \odot (h_2 * \cdots))) \]

其中 \(h_i\) 是隐式参数化的卷积核(用MLP生成):

\[ h(t) = \text{Window}(t) \cdot \text{MLP}_\phi(\text{PositionalEncoding}(t)) \]

3.3 计算复杂度

  • 使用FFT实现长卷积:\(O(L \log L)\)
  • 没有二次复杂度
  • 总复杂度:\(O(NL \log L)\)\(N\) 是Hyena阶数

4. 线性注意力

4.1 标准注意力到线性注意力

标准注意力:

\[ \text{Attn}(Q, K, V) = \text{softmax}(QK^\top) V \]

线性注意力(Katharopoulos et al., 2020):

\[ \text{LinAttn}(Q, K, V) = \phi(Q)(\phi(K)^\top V) \]

通过先算 \(\phi(K)^\top V\)\(d \times d\)),实现 \(O(Ld^2)\) 复杂度。

4.2 线性注意力的递推形式

\[ S_t = S_{t-1} + \phi(k_t) v_t^\top \]
\[ y_t = \phi(q_t)^\top S_t \]

这就是一个线性RNN!\(S_t\)\(d \times d\) 的状态矩阵。

4.3 线性注意力的问题

  • 没有softmax归一化 → 性能不如标准注意力
  • 特征映射 \(\phi\) 的选择影响很大
  • RetNet、GLA等方法在此基础上改进

5. 其他方法

5.1 RetNet

RetNet(Sun et al., 2023):保留网络,支持三种计算模式。

  • 并行模式:类似注意力(训练)
  • 递推模式:类似RNN(推理)
  • 分块模式:混合(长序列训练)

核心:带衰减的注意力:

\[ \text{Retention}(Q, K, V) = (QK^\top \odot D) V \]

其中 \(D_{ij} = \gamma^{i-j}\)\(i \geq j\)),\(\gamma\) 是衰减因子。

5.2 GLA(Gated Linear Attention)

\[ S_t = G_t \odot S_{t-1} + k_t v_t^\top \]

添加数据依赖的门控 \(G_t\),提升选择性。


6. Long Range Arena基准

6.1 任务描述

任务 序列长度 测试能力
ListOps 2K 层次推理
Text 4K 文本分类
Retrieval 8K 文档匹配
Image 1K 图像分类(展平为序列)
PathFinder 1K 空间推理
Path-X 16K 超长空间推理

6.2 方法对比

方法 平均 训练模式 推理
Transformer 53.66 并行 \(O(L)\)/token
S4 86.09 并行(卷积) \(O(1)\)/token
Mamba ~87 并行(扫描) \(O(1)\)/token
Hyena ~85 并行(FFT) \(O(\log L)\)/token
RWKV ~83 并行 \(O(1)\)/token

7. SSM vs Transformer权衡

维度 Transformer SSM/线性模型
长序列 \(O(L^2)\) 限制 \(O(L)\)\(O(L\log L)\)
上下文学习 较弱(固定状态瓶颈)
精确检索 强(全局注意力) 弱(信息被压缩)
训练并行性 好(卷积/扫描)
推理效率 KV缓存增长 固定状态
硬件友好 矩阵乘法 需要专门实现
生态成熟度 发展中

共识趋势:混合架构(如Jamba、Zamba)结合两者优势。


8. 总结

核心要点

  1. xLSTM复兴了LSTM,用矩阵状态和指数门控扩展能力
  2. RWKV实现了RNN的训练并行性,接近Transformer性能
  3. Hyena用隐式长卷积替代注意力
  4. 线性注意力是SSM和注意力的桥梁
  5. 混合架构是当前趋势

参考文献

  • Beck et al., "xLSTM: Extended Long Short-Term Memory," 2024
  • Peng et al., "RWKV: Reinventing RNNs for the Transformer Era," EMNLP 2023
  • Poli et al., "Hyena Hierarchy: Towards Larger Convolutional Language Models," ICML 2023
  • Katharopoulos et al., "Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention," ICML 2020
  • Sun et al., "Retentive Network: A Successor to Transformer for Large Language Models," 2023

评论 #