Transformer变体与高效注意力
概述
标准Transformer的自注意力机制复杂度为 \(O(n^2)\),限制了其处理长序列的能力。本章系统梳理各种高效注意力变体,包括IO感知优化(FlashAttention)、线性注意力、稀疏注意力,以及位置编码的演进。
graph TD
A[注意力变体] --> B[IO感知优化]
A --> C[线性/低秩近似]
A --> D[稀疏注意力]
A --> E[位置编码]
A --> F[KV缓存优化]
B --> B1[FlashAttention]
B --> B2[FlashAttention-2/3]
C --> C1[Linformer]
C --> C2[Performer]
C --> C3[Linear Attention]
D --> D1[Longformer]
D --> D2[BigBird]
D --> D3[Sparse Transformer]
E --> E1[RoPE]
E --> E2[ALiBi]
E --> E3[Sinusoidal]
F --> F1[MQA]
F --> F2[GQA]
F --> F3[MLA]
1. 标准注意力回顾
1.1 缩放点积注意力
- 时间复杂度:\(O(n^2 d)\)
- 空间复杂度:\(O(n^2 + nd)\)
其中 \(n\) 是序列长度,\(d\) 是特征维度。
1.2 多头注意力(MHA)
每个头的维度:\(d_k = d_v = d_{\text{model}} / h\)
参数量:\(4 \times d^2\)(Q, K, V, O四个投影矩阵)
2. FlashAttention:IO感知的精确注意力
2.1 动机
标准注意力的瓶颈不是计算,而是内存IO:
- GPU有大量计算单元但有限的HBM带宽
- \(n^2\) 的注意力矩阵需要写入/读取HBM
- 实际运行时间被HBM访问主导
2.2 FlashAttention v1
核心思想:Tiling(分块计算)+ Recomputation(反向传播时重计算)
Tiling算法:
- 将 \(Q, K, V\) 分成块(block),大小适配SRAM
- 在SRAM中计算局部注意力
- 使用在线softmax技巧逐块累加结果
- 不将 \(n \times n\) 注意力矩阵写入HBM
在线Softmax:
复杂度改进:
| 指标 | 标准注意力 | FlashAttention |
|---|---|---|
| HBM读写 | \(O(n^2 d + n^2)\) | \(O(n^2 d^2 / M)\) |
| FLOPS | \(O(n^2 d)\) | \(O(n^2 d)\)(不变) |
| 额外内存 | \(O(n^2)\) | \(O(n)\) |
其中 \(M\) 是SRAM大小。
2.3 FlashAttention v2 & v3
FlashAttention-2:
- 优化线程块分工(减少非矩阵乘法运算)
- 并行化序列长度维度
- 在A100上达到理论FLOPS的50-73%
FlashAttention-3:
- 针对Hopper架构(H100)优化
- 利用异步拷贝和WGMMA指令
- FP8支持
3. 线性注意力与低秩近似
3.1 Linformer
思想:注意力矩阵是低秩的,可以用低维投影近似。
将 \(K\) 和 \(V\) 从 \(n \times d\) 投影到 \(k \times d\)(\(k \ll n\)):
其中 \(E_K, E_V \in \mathbb{R}^{k \times n}\) 是学习的投影矩阵。
复杂度:\(O(nk)\) — 线性于 \(n\)
3.2 Performer
思想:用随机特征近似softmax核,避免显式计算 \(n \times n\) 矩阵。
其中 \(\phi\) 是随机特征映射(如正交随机特征FAVOR+)。
关键公式:
先计算 \(\phi(K)^\top V\)(\(d \times d\) 矩阵),再乘以 \(\phi(Q)\)。
复杂度:\(O(n d^2)\) — 线性于 \(n\)
3.3 方法对比
| 方法 | 复杂度 | 精确? | 训练需要? |
|---|---|---|---|
| 标准注意力 | \(O(n^2 d)\) | 精确 | - |
| FlashAttention | \(O(n^2 d)\) | 精确 | 否(直接替换) |
| Linformer | \(O(nkd)\) | 近似 | 是 |
| Performer | \(O(nd^2)\) | 近似 | 是 |
4. 稀疏注意力
4.1 Longformer
混合注意力模式:
- 滑动窗口注意力:每个token关注局部窗口 \(w\) 个token
- 膨胀滑动窗口:类似空洞卷积,增大感受野
- 全局注意力:特殊token(如[CLS])关注所有token
4.2 BigBird
结合三种稀疏模式:
- 局部注意力(滑动窗口)
- 全局注意力(选定的全局token)
- 随机注意力(随机连接)
理论保证:BigBird是图灵完备的(随机注意力+全局注意力保证了这一点)。
4.3 Sparse Transformer
固定稀疏模式:
- Strided pattern:行注意力 + 列注意力
- 复杂度:\(O(n\sqrt{n})\)
5. 位置编码
5.1 位置编码的必要性
Transformer的自注意力是置换不变的(permutation invariant),必须通过位置编码注入位置信息。
5.2 旋转位置编码(RoPE)
RoPE(Su et al., 2021)是目前LLM中最流行的位置编码。
核心思想:通过旋转矩阵编码相对位置。
对于位置 \(m\) 的查询向量 \(q\) 和位置 \(n\) 的键向量 \(k\):
2D旋转(每两个维度一组):
其中 \(\theta_i = 10000^{-2i/d}\)。
优点:
- 自然编码相对位置
- 外推能力较好(配合NTK-aware scaling等技术)
- 计算高效(逐元素操作)
5.3 ALiBi(Attention with Linear Biases)
ALiBi(Press et al., 2022)不使用位置编码,而是给注意力分数加一个线性偏置:
其中 \(m\) 是每个头不同的斜率(几何序列)。
优点:
- 不增加模型参数
- 天然的长度外推能力
- 计算开销极小
5.4 位置编码对比
| 方法 | 类型 | 外推能力 | 复杂度 | 使用模型 |
|---|---|---|---|---|
| Sinusoidal | 绝对 | 差 | \(O(1)\) | 原始Transformer |
| 可学习 | 绝对 | 差 | \(O(1)\) | BERT, GPT-2 |
| RoPE | 相对 | 中等(可扩展) | \(O(1)\) | LLaMA, Qwen, Gemma |
| ALiBi | 相对偏置 | 好 | \(O(1)\) | BLOOM, MPT |
| T5 bias | 相对 | 中等 | \(O(1)\) | T5 |
6. KV缓存优化
6.1 背景
自回归生成时,每个新token需要与所有历史KV交互。KV缓存随序列增长线性增大,成为内存瓶颈。
标准MHA的KV缓存大小:\(2 \times n_{\text{layers}} \times n_{\text{heads}} \times n_{\text{seq}} \times d_{\text{head}}\)
6.2 多查询注意力(MQA)
MQA(Shazeer, 2019):所有查询头共享一组 K和V。
KV缓存缩减:\(h\) 倍
代价:质量略有下降
6.3 分组查询注意力(GQA)
GQA(Ainslie et al., 2023):折中方案,将查询头分组,每组共享一组KV。
- \(g=1\):退化为MQA
- \(g=h\):退化为标准MHA
典型设置:\(h=32, g=8\)(LLaMA 2 70B)
6.4 Multi-Head Latent Attention(MLA)
MLA(DeepSeek-V2):将KV压缩到低维潜在空间。
只缓存 \(c_t^{KV}\),大幅减少KV缓存。
6.5 对比总结
| 方法 | KV头数 | 缓存大小 | 质量 | 代表模型 |
|---|---|---|---|---|
| MHA | \(h\) | \(2nhd\) | 最好 | GPT-3 |
| MQA | 1 | \(2nd\) | 稍差 | PaLM |
| GQA | \(g\) | \(2ngd\) | 接近MHA | LLaMA 2, Gemma |
| MLA | - | \(2nd_c\) | 好 | DeepSeek-V2 |
7. 其他注意力变体
7.1 滑动窗口注意力
Mistral采用的方法,限制注意力范围到固定窗口:
7.2 Ring Attention
将长序列分布到多个设备,每个设备处理一部分KV,通过环形通信传递KV块。
7.3 Differential Attention(Diff Attention)
减少对无关token的注意力噪声。
8. 实践选择指南
graph TD
A[选择注意力方案] --> B{序列长度?}
B -->|<8K| C[标准MHA + FlashAttention]
B -->|8K-128K| D[GQA + FlashAttention + RoPE]
B -->|>128K| E{精度要求?}
E -->|高| F[Ring Attention + GQA]
E -->|可近似| G[稀疏注意力/线性注意力]
D --> H{推理优化?}
H -->|是| I[MQA/GQA + KV缓存压缩]
H -->|否| J[标准GQA]
9. 总结
| 优化方向 | 方法 | 核心思路 |
|---|---|---|
| IO优化 | FlashAttention | 分块计算,减少HBM访问 |
| 低秩近似 | Linformer, Performer | 压缩K/V或核近似 |
| 稀疏模式 | Longformer, BigBird | 局部+全局+随机 |
| 位置编码 | RoPE, ALiBi | 相对位置,长度外推 |
| KV缓存 | MQA, GQA, MLA | 减少KV头数或压缩 |
当前主流配置(2024-2025):
- 位置编码:RoPE(+ YaRN/NTK扩展长度)
- 注意力:GQA + FlashAttention-2/3
- 推理:PagedAttention + 量化KV缓存
参考文献
- Dao et al., "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness," NeurIPS 2022
- Wang et al., "Linformer: Self-Attention with Linear Complexity," 2020
- Choromanski et al., "Rethinking Attention with Performers," ICLR 2021
- Beltagy et al., "Longformer: The Long-Document Transformer," 2020
- Su et al., "RoFormer: Enhanced Transformer with Rotary Position Embedding," 2021
- Press et al., "Train Short, Test Long: Attention with Linear Biases," ICLR 2022
- Ainslie et al., "GQA: Training Generalized Multi-Query Transformer Models," EMNLP 2023