跳转至

Transformer变体与高效注意力

概述

标准Transformer的自注意力机制复杂度为 \(O(n^2)\),限制了其处理长序列的能力。本章系统梳理各种高效注意力变体,包括IO感知优化(FlashAttention)、线性注意力、稀疏注意力,以及位置编码的演进。

graph TD
    A[注意力变体] --> B[IO感知优化]
    A --> C[线性/低秩近似]
    A --> D[稀疏注意力]
    A --> E[位置编码]
    A --> F[KV缓存优化]

    B --> B1[FlashAttention]
    B --> B2[FlashAttention-2/3]

    C --> C1[Linformer]
    C --> C2[Performer]
    C --> C3[Linear Attention]

    D --> D1[Longformer]
    D --> D2[BigBird]
    D --> D3[Sparse Transformer]

    E --> E1[RoPE]
    E --> E2[ALiBi]
    E --> E3[Sinusoidal]

    F --> F1[MQA]
    F --> F2[GQA]
    F --> F3[MLA]

1. 标准注意力回顾

1.1 缩放点积注意力

\[ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right) V \]
  • 时间复杂度:\(O(n^2 d)\)
  • 空间复杂度:\(O(n^2 + nd)\)

其中 \(n\) 是序列长度,\(d\) 是特征维度。

1.2 多头注意力(MHA)

\[ \text{MHA}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h) W^O \]
\[ \text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V) \]

每个头的维度:\(d_k = d_v = d_{\text{model}} / h\)

参数量\(4 \times d^2\)(Q, K, V, O四个投影矩阵)


2. FlashAttention:IO感知的精确注意力

2.1 动机

标准注意力的瓶颈不是计算,而是内存IO

  • GPU有大量计算单元但有限的HBM带宽
  • \(n^2\) 的注意力矩阵需要写入/读取HBM
  • 实际运行时间被HBM访问主导

2.2 FlashAttention v1

核心思想:Tiling(分块计算)+ Recomputation(反向传播时重计算)

Tiling算法

  1. \(Q, K, V\) 分成块(block),大小适配SRAM
  2. 在SRAM中计算局部注意力
  3. 使用在线softmax技巧逐块累加结果
  4. 不将 \(n \times n\) 注意力矩阵写入HBM

在线Softmax

\[ m_{\text{new}} = \max(m_{\text{old}}, \max(\mathbf{x}_{\text{block}})) \]
\[ \ell_{\text{new}} = e^{m_{\text{old}} - m_{\text{new}}} \ell_{\text{old}} + \sum_j e^{x_j - m_{\text{new}}} \]

复杂度改进

指标 标准注意力 FlashAttention
HBM读写 \(O(n^2 d + n^2)\) \(O(n^2 d^2 / M)\)
FLOPS \(O(n^2 d)\) \(O(n^2 d)\)(不变)
额外内存 \(O(n^2)\) \(O(n)\)

其中 \(M\) 是SRAM大小。

2.3 FlashAttention v2 & v3

FlashAttention-2

  • 优化线程块分工(减少非矩阵乘法运算)
  • 并行化序列长度维度
  • 在A100上达到理论FLOPS的50-73%

FlashAttention-3

  • 针对Hopper架构(H100)优化
  • 利用异步拷贝和WGMMA指令
  • FP8支持

3. 线性注意力与低秩近似

3.1 Linformer

思想:注意力矩阵是低秩的,可以用低维投影近似。

\(K\)\(V\)\(n \times d\) 投影到 \(k \times d\)\(k \ll n\)):

\[ \text{Linformer}(Q, K, V) = \text{softmax}\left(\frac{Q(E_K K)^\top}{\sqrt{d_k}}\right)(E_V V) \]

其中 \(E_K, E_V \in \mathbb{R}^{k \times n}\) 是学习的投影矩阵。

复杂度\(O(nk)\) — 线性于 \(n\)

3.2 Performer

思想:用随机特征近似softmax核,避免显式计算 \(n \times n\) 矩阵。

\[ \text{softmax}(QK^\top) \approx \phi(Q) \phi(K)^\top \]

其中 \(\phi\) 是随机特征映射(如正交随机特征FAVOR+)。

关键公式

\[ \text{Attention} \approx \phi(Q) (\phi(K)^\top V) \]

先计算 \(\phi(K)^\top V\)\(d \times d\) 矩阵),再乘以 \(\phi(Q)\)

复杂度\(O(n d^2)\) — 线性于 \(n\)

3.3 方法对比

方法 复杂度 精确? 训练需要?
标准注意力 \(O(n^2 d)\) 精确 -
FlashAttention \(O(n^2 d)\) 精确 否(直接替换)
Linformer \(O(nkd)\) 近似
Performer \(O(nd^2)\) 近似

4. 稀疏注意力

4.1 Longformer

混合注意力模式

  1. 滑动窗口注意力:每个token关注局部窗口 \(w\) 个token
  2. 膨胀滑动窗口:类似空洞卷积,增大感受野
  3. 全局注意力:特殊token(如[CLS])关注所有token
\[ \text{复杂度} = O(n \times w) \quad \text{(线性于}n\text{)} \]

4.2 BigBird

结合三种稀疏模式:

  1. 局部注意力(滑动窗口)
  2. 全局注意力(选定的全局token)
  3. 随机注意力(随机连接)

理论保证:BigBird是图灵完备的(随机注意力+全局注意力保证了这一点)。

4.3 Sparse Transformer

固定稀疏模式

  • Strided pattern:行注意力 + 列注意力
  • 复杂度:\(O(n\sqrt{n})\)

5. 位置编码

5.1 位置编码的必要性

Transformer的自注意力是置换不变的(permutation invariant),必须通过位置编码注入位置信息。

5.2 旋转位置编码(RoPE)

RoPE(Su et al., 2021)是目前LLM中最流行的位置编码。

核心思想:通过旋转矩阵编码相对位置。

对于位置 \(m\) 的查询向量 \(q\) 和位置 \(n\) 的键向量 \(k\)

\[ \langle f_q(x_m, m), f_k(x_n, n) \rangle = g(x_m, x_n, m-n) \]

2D旋转(每两个维度一组):

\[ R_\Theta(m) = \begin{bmatrix} \cos m\theta_1 & -\sin m\theta_1 & & \\ \sin m\theta_1 & \cos m\theta_1 & & \\ & & \ddots & \\ & & & \cos m\theta_{d/2} & -\sin m\theta_{d/2} \\ & & & \sin m\theta_{d/2} & \cos m\theta_{d/2} \end{bmatrix} \]

其中 \(\theta_i = 10000^{-2i/d}\)

优点

  • 自然编码相对位置
  • 外推能力较好(配合NTK-aware scaling等技术)
  • 计算高效(逐元素操作)

5.3 ALiBi(Attention with Linear Biases)

ALiBi(Press et al., 2022)不使用位置编码,而是给注意力分数加一个线性偏置:

\[ \text{softmax}(q_i^\top k_j - m \cdot |i - j|) \]

其中 \(m\) 是每个头不同的斜率(几何序列)。

优点

  • 不增加模型参数
  • 天然的长度外推能力
  • 计算开销极小

5.4 位置编码对比

方法 类型 外推能力 复杂度 使用模型
Sinusoidal 绝对 \(O(1)\) 原始Transformer
可学习 绝对 \(O(1)\) BERT, GPT-2
RoPE 相对 中等(可扩展) \(O(1)\) LLaMA, Qwen, Gemma
ALiBi 相对偏置 \(O(1)\) BLOOM, MPT
T5 bias 相对 中等 \(O(1)\) T5

6. KV缓存优化

6.1 背景

自回归生成时,每个新token需要与所有历史KV交互。KV缓存随序列增长线性增大,成为内存瓶颈。

标准MHA的KV缓存大小:\(2 \times n_{\text{layers}} \times n_{\text{heads}} \times n_{\text{seq}} \times d_{\text{head}}\)

6.2 多查询注意力(MQA)

MQA(Shazeer, 2019):所有查询头共享一组 K和V。

\[ \text{MQA}: h \text{ 个Q头}, 1 \text{ 个K头}, 1 \text{ 个V头} \]

KV缓存缩减\(h\)

代价:质量略有下降

6.3 分组查询注意力(GQA)

GQA(Ainslie et al., 2023):折中方案,将查询头分组,每组共享一组KV。

\[ \text{GQA-}g: h \text{ 个Q头}, g \text{ 个K头}, g \text{ 个V头} \]
  • \(g=1\):退化为MQA
  • \(g=h\):退化为标准MHA

典型设置\(h=32, g=8\)(LLaMA 2 70B)

6.4 Multi-Head Latent Attention(MLA)

MLA(DeepSeek-V2):将KV压缩到低维潜在空间。

\[ k_t = W^{UK} c_t^{KV}, \quad v_t = W^{UV} c_t^{KV} \]
\[ c_t^{KV} = W^{DKV} x_t \in \mathbb{R}^{d_c}, \quad d_c \ll d \]

只缓存 \(c_t^{KV}\),大幅减少KV缓存。

6.5 对比总结

方法 KV头数 缓存大小 质量 代表模型
MHA \(h\) \(2nhd\) 最好 GPT-3
MQA 1 \(2nd\) 稍差 PaLM
GQA \(g\) \(2ngd\) 接近MHA LLaMA 2, Gemma
MLA - \(2nd_c\) DeepSeek-V2

7. 其他注意力变体

7.1 滑动窗口注意力

Mistral采用的方法,限制注意力范围到固定窗口:

\[ \text{Attention}(q_i) = \text{softmax}\left(\frac{q_i k_{[i-w:i]}^\top}{\sqrt{d}}\right) v_{[i-w:i]} \]

7.2 Ring Attention

将长序列分布到多个设备,每个设备处理一部分KV,通过环形通信传递KV块。

7.3 Differential Attention(Diff Attention)

\[ \text{DiffAttn}(X) = (\text{softmax}(Q_1 K_1^\top) - \lambda \cdot \text{softmax}(Q_2 K_2^\top)) V \]

减少对无关token的注意力噪声。


8. 实践选择指南

graph TD
    A[选择注意力方案] --> B{序列长度?}
    B -->|<8K| C[标准MHA + FlashAttention]
    B -->|8K-128K| D[GQA + FlashAttention + RoPE]
    B -->|>128K| E{精度要求?}
    E -->|高| F[Ring Attention + GQA]
    E -->|可近似| G[稀疏注意力/线性注意力]

    D --> H{推理优化?}
    H -->|是| I[MQA/GQA + KV缓存压缩]
    H -->|否| J[标准GQA]

9. 总结

优化方向 方法 核心思路
IO优化 FlashAttention 分块计算,减少HBM访问
低秩近似 Linformer, Performer 压缩K/V或核近似
稀疏模式 Longformer, BigBird 局部+全局+随机
位置编码 RoPE, ALiBi 相对位置,长度外推
KV缓存 MQA, GQA, MLA 减少KV头数或压缩

当前主流配置(2024-2025):

  • 位置编码:RoPE(+ YaRN/NTK扩展长度)
  • 注意力:GQA + FlashAttention-2/3
  • 推理:PagedAttention + 量化KV缓存

参考文献

  • Dao et al., "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness," NeurIPS 2022
  • Wang et al., "Linformer: Self-Attention with Linear Complexity," 2020
  • Choromanski et al., "Rethinking Attention with Performers," ICLR 2021
  • Beltagy et al., "Longformer: The Long-Document Transformer," 2020
  • Su et al., "RoFormer: Enhanced Transformer with Rotary Position Embedding," 2021
  • Press et al., "Train Short, Test Long: Attention with Linear Biases," ICLR 2022
  • Ainslie et al., "GQA: Training Generalized Multi-Query Transformer Models," EMNLP 2023

评论 #