
Transformer Variants and Efficient Attention

Overview

The standard Transformer's self-attention mechanism has \(O(n^2)\) complexity, limiting its ability to handle long sequences. This chapter systematically covers efficient attention variants including IO-aware optimization (FlashAttention), linear attention, sparse attention, and the evolution of positional encoding.

```mermaid
graph TD
    A[Attention Variants] --> B[IO-Aware Optimization]
    A --> C[Linear/Low-Rank Approximation]
    A --> D[Sparse Attention]
    A --> E[Positional Encoding]
    A --> F[KV Cache Optimization]

    B --> B1[FlashAttention]
    B --> B2[FlashAttention-2/3]

    C --> C1[Linformer]
    C --> C2[Performer]
    C --> C3[Linear Attention]

    D --> D1[Longformer]
    D --> D2[BigBird]
    D --> D3[Sparse Transformer]

    E --> E1[RoPE]
    E --> E2[ALiBi]
    E --> E3[Sinusoidal]

    F --> F1[MQA]
    F --> F2[GQA]
    F --> F3[MLA]
```

1. Standard Attention Review

1.1 Scaled Dot-Product Attention

\[ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right) V \]
  • Time complexity: \(O(n^2 d)\)
  • Space complexity: \(O(n^2 + nd)\)

where \(n\) is the sequence length and \(d\) is the feature dimension.
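As a concrete reference point, the formula above can be written out directly. This is a minimal NumPy sketch of the naive \(O(n^2)\) computation, not the optimized kernels discussed later:

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the row max before exponentiating for numerical stability.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention(Q, K, V):
    """Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V."""
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)     # (n, n) attention logits
    weights = softmax(scores, axis=-1)  # each row sums to 1
    return weights @ V                  # (n, d_v)

rng = np.random.default_rng(0)
n, d = 4, 8
Q, K, V = rng.normal(size=(3, n, d))
out = attention(Q, K, V)                # (4, 8)
```

The explicit `(n, n)` `scores` matrix is exactly the object that both FlashAttention (by tiling) and the approximate methods below (by factorization) avoid materializing.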

1.2 Multi-Head Attention (MHA)

\[ \text{MHA}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h) W^O \]
\[ \text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V) \]

Per-head dimension: \(d_k = d_v = d_{\text{model}} / h\)

Parameter count: \(4 \times d^2\) (Q, K, V, O projection matrices)


2. FlashAttention: IO-Aware Exact Attention

2.1 Motivation

The bottleneck of standard attention is not computation but memory IO:

  • GPUs have massive compute but limited HBM bandwidth
  • The \(n^2\) attention matrix requires HBM reads/writes
  • Wall-clock time is dominated by HBM access

2.2 FlashAttention v1

Core Idea: Tiling (block-wise computation) + Recomputation (rematerialize attention in the backward pass instead of storing the full \(n \times n\) matrix)

Tiling Algorithm:

  1. Partition \(Q, K, V\) into blocks fitting in SRAM
  2. Compute local attention in SRAM
  3. Use online softmax trick to incrementally accumulate results
  4. Never write the \(n \times n\) attention matrix to HBM

Online Softmax:

\[ m_{\text{new}} = \max(m_{\text{old}}, \max(\mathbf{x}_{\text{block}})) \]
\[ \ell_{\text{new}} = e^{m_{\text{old}} - m_{\text{new}}} \ell_{\text{old}} + \sum_j e^{x_j - m_{\text{new}}} \]
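The two recurrences above can be verified numerically. A small sketch for a single score vector processed in blocks (the per-row case of the tiled algorithm):

```python
import numpy as np

def online_softmax_normalizer(x, block_size=3):
    """Blockwise softmax statistics: running max m and normalizer l."""
    m = -np.inf   # running max (exp(-inf) = 0, so the first rescale is a no-op)
    l = 0.0       # running sum of exp(x_j - m)
    for start in range(0, len(x), block_size):
        block = x[start:start + block_size]
        m_new = max(m, block.max())
        # Rescale the old normalizer to the new max, then add the block's terms.
        l = np.exp(m - m_new) * l + np.exp(block - m_new).sum()
        m = m_new
    return m, l

x = np.array([1.0, 3.0, -2.0, 0.5, 4.0, 2.0, -1.0])
m, l = online_softmax_normalizer(x)
direct = np.exp(x - x.max()).sum()
# The blockwise result matches the single-pass computation exactly.
assert np.isclose(l, direct) and m == x.max()
```

FlashAttention applies the same rescaling to the running output accumulator, which is why the final result is exact rather than approximate.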

Complexity Improvement:

| Metric | Standard Attention | FlashAttention |
| --- | --- | --- |
| HBM reads/writes | \(O(n^2 d + n^2)\) | \(O(n^2 d^2 / M)\) |
| FLOPs | \(O(n^2 d)\) | \(O(n^2 d)\) (unchanged) |
| Extra memory | \(O(n^2)\) | \(O(n)\) |

where \(M\) is the SRAM size.

2.3 FlashAttention v2 & v3

FlashAttention-2:

  • Optimized thread block work partitioning (reduced non-matmul operations)
  • Parallelization along sequence length dimension
  • Achieves 50-73% of theoretical peak FLOPs on an A100

FlashAttention-3:

  • Optimized for Hopper architecture (H100)
  • Leverages asynchronous copies and WGMMA instructions
  • FP8 support

3. Linear Attention and Low-Rank Approximations

3.1 Linformer

Idea: The attention matrix is low-rank and can be approximated with low-dimensional projections.

Project \(K\) and \(V\) from \(n \times d\) to \(k \times d\) (\(k \ll n\)):

\[ \text{Linformer}(Q, K, V) = \text{softmax}\left(\frac{Q(E_K K)^\top}{\sqrt{d_k}}\right)(E_V V) \]

where \(E_K, E_V \in \mathbb{R}^{k \times n}\) are learned projection matrices.

Complexity: \(O(nk)\) — linear in \(n\)
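A toy sketch of the projection, shape-checking the formula above (with random rather than learned \(E_K, E_V\), purely for illustration):

```python
import numpy as np

rng = np.random.default_rng(0)
n, d, k = 1024, 64, 128          # sequence length, head dim, projected length

Q = rng.normal(size=(n, d))
K = rng.normal(size=(n, d))
V = rng.normal(size=(n, d))
# E_K, E_V are learned in the real model; random here just to show the shapes.
E_K = rng.normal(size=(k, n)) / np.sqrt(n)
E_V = rng.normal(size=(k, n)) / np.sqrt(n)

# Attention over a length-k projected context instead of length n.
scores = Q @ (E_K @ K).T / np.sqrt(d)    # (n, k) instead of (n, n)
scores -= scores.max(axis=-1, keepdims=True)
weights = np.exp(scores)
weights /= weights.sum(axis=-1, keepdims=True)
out = weights @ (E_V @ V)                # (n, d)
```

The score matrix is `(n, k)` rather than `(n, n)`, which is where the linear-in-\(n\) cost comes from.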

3.2 Performer

Idea: Approximate the softmax kernel with random features to avoid explicit \(n \times n\) matrix computation.

\[ \text{softmax}(QK^\top) \approx \phi(Q) \phi(K)^\top \]

where \(\phi\) is a random feature map (e.g., FAVOR+ with orthogonal random features).

Key Formula:

\[ \text{Attention} \approx \phi(Q) (\phi(K)^\top V) \]

Compute \(\phi(K)^\top V\) first (an \(r \times d\) matrix, where \(r\) is the number of random features, typically \(r = O(d)\)), then multiply by \(\phi(Q)\).

Complexity: \(O(n d^2)\) — linear in \(n\)
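The associativity trick works for any positive feature map. The sketch below uses the simple \(\mathrm{elu}(x) + 1\) map from the linear-attention literature rather than Performer's FAVOR+ random features, but the \(O(nd^2)\) reordering is identical:

```python
import numpy as np

def phi(x):
    # Simple positive feature map, elu(x) + 1. Performer's FAVOR+ uses
    # random features instead; the associativity argument is the same.
    return np.where(x > 0, x + 1.0, np.exp(x))

rng = np.random.default_rng(0)
n, d = 512, 32
Q, K, V = rng.normal(size=(3, n, d))

fQ, fK = phi(Q), phi(K)
# Reassociate: (phi(Q) phi(K)^T) V == phi(Q) (phi(K)^T V).
kv = fK.T @ V                       # (d, d) — computed once, O(n d^2)
z = fQ @ fK.sum(axis=0)             # per-row normalizer, shape (n,)
out_linear = (fQ @ kv) / z[:, None]

# Reference: the quadratic-order evaluation with the explicit (n, n) matrix.
A = fQ @ fK.T
out_quadratic = (A @ V) / A.sum(axis=-1, keepdims=True)
assert np.allclose(out_linear, out_quadratic)
```

Both orderings give the same output; only the linear one avoids ever forming the \(n \times n\) matrix `A`.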

3.3 Method Comparison

| Method | Complexity | Exact? | Requires Retraining? |
| --- | --- | --- | --- |
| Standard Attention | \(O(n^2 d)\) | Exact | — |
| FlashAttention | \(O(n^2 d)\) | Exact | No (drop-in replacement) |
| Linformer | \(O(nkd)\) | Approximate | Yes |
| Performer | \(O(nd^2)\) | Approximate | Yes |

4. Sparse Attention

4.1 Longformer

Hybrid Attention Patterns:

  1. Sliding window attention: Each token attends to \(w\) local tokens
  2. Dilated sliding window: Like dilated convolutions, increasing receptive field
  3. Global attention: Special tokens (e.g., [CLS]) attend to all tokens
\[ \text{Complexity} = O(n \times w) \quad \text{(linear in } n\text{)} \]
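A minimal sketch of the sliding-window pattern (the mask only; dilation and global tokens are omitted for brevity):

```python
import numpy as np

def sliding_window_mask(n, w):
    """Boolean mask: token i may attend to token j iff |i - j| <= w // 2."""
    i = np.arange(n)[:, None]
    j = np.arange(n)[None, :]
    return np.abs(i - j) <= w // 2

mask = sliding_window_mask(8, 4)
# Each row has at most w + 1 allowed positions, so the total number of
# attended pairs is O(n * w) rather than O(n^2).
assert mask.sum(axis=1).max() == 5
```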

4.2 BigBird

Combines three sparse patterns:

  1. Local attention (sliding window)
  2. Global attention (selected global tokens)
  3. Random attention (random connections)

Theoretical Guarantee: BigBird is Turing complete (random + global attention ensures this).

4.3 Sparse Transformer

Fixed Sparse Patterns:

  • Strided pattern: Row attention + column attention
  • Complexity: \(O(n\sqrt{n})\)

5. Positional Encoding

5.1 Why Positional Encoding Is Necessary

Transformer self-attention is permutation invariant, so positional information must be injected through positional encoding.

5.2 Rotary Position Embedding (RoPE)

RoPE (Su et al., 2021) is currently the most popular positional encoding in LLMs.

Core Idea: Encode relative positions through rotation matrices.

For query at position \(m\) and key at position \(n\):

\[ \langle f_q(x_m, m), f_k(x_n, n) \rangle = g(x_m, x_n, m-n) \]

2D Rotation (every two dimensions form a group):

\[ R_\Theta(m) = \begin{bmatrix} \cos m\theta_1 & -\sin m\theta_1 & & \\ \sin m\theta_1 & \cos m\theta_1 & & \\ & & \ddots & \\ & & & \cos m\theta_{d/2} & -\sin m\theta_{d/2} \\ & & & \sin m\theta_{d/2} & \cos m\theta_{d/2} \end{bmatrix} \]

where \(\theta_i = 10000^{-2i/d}\).

Advantages:

  • Naturally encodes relative positions
  • Reasonable extrapolation (with NTK-aware scaling techniques)
  • Computationally efficient (element-wise operations)
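A compact NumPy sketch of RoPE, rotating consecutive dimension pairs and verifying that the query-key dot product depends only on the relative offset \(m - n\):

```python
import numpy as np

def rope(x, positions, base=10000.0):
    """Apply rotary embedding to x of shape (n, d), d even.

    Dimension pair (2i, 2i+1) is rotated by angle m * theta_i,
    with theta_i = base ** (-2i / d)."""
    n, d = x.shape
    theta = base ** (-np.arange(0, d, 2) / d)        # (d/2,)
    angles = positions[:, None] * theta[None, :]     # (n, d/2)
    cos, sin = np.cos(angles), np.sin(angles)
    x1, x2 = x[:, 0::2], x[:, 1::2]
    out = np.empty_like(x)
    out[:, 0::2] = x1 * cos - x2 * sin
    out[:, 1::2] = x1 * sin + x2 * cos
    return out

rng = np.random.default_rng(0)
d = 8
q = rng.normal(size=(1, d))
k = rng.normal(size=(1, d))
# Same offset (m - n = 2) at different absolute positions -> same score.
s1 = rope(q, np.array([5]))[0] @ rope(k, np.array([3]))[0]
s2 = rope(q, np.array([12]))[0] @ rope(k, np.array([10]))[0]
assert np.isclose(s1, s2)
```

This is the property \(\langle f_q(x_m, m), f_k(x_n, n) \rangle = g(x_m, x_n, m-n)\) stated above, checked numerically.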

5.3 ALiBi (Attention with Linear Biases)

ALiBi (Press et al., 2022) adds a linear bias to attention scores instead of using positional encoding:

\[ \text{softmax}(q_i^\top k_j - m \cdot |i - j|) \]

where \(m\) is a per-head slope (geometric sequence).
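The bias is cheap to materialize once per head. A sketch using the standard geometric slopes \(2^{-8i/h}\) and the symmetric-distance form given above (the original ALiBi applies it causally):

```python
import numpy as np

def alibi_bias(n_heads, seq_len):
    """Per-head additive bias -m * |i - j|, shape (n_heads, n, n)."""
    # Geometric slope sequence: for 8 heads, m = 1/2, 1/4, ..., 1/256.
    slopes = 2.0 ** (-8.0 * np.arange(1, n_heads + 1) / n_heads)
    i = np.arange(seq_len)[:, None]
    j = np.arange(seq_len)[None, :]
    distance = np.abs(i - j)
    return -slopes[:, None, None] * distance   # added to attention logits

bias = alibi_bias(8, 16)
assert bias.shape == (8, 16, 16)
assert np.isclose(bias[0, 0, 1], -0.5)   # head 0 has slope 1/2
```

Because the bias is a fixed function of distance, it introduces no learned parameters, matching the advantages listed below.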

Advantages:

  • No additional model parameters
  • Natural length extrapolation
  • Minimal computational overhead

5.4 Positional Encoding Comparison

| Method | Type | Extrapolation | Overhead | Used By |
| --- | --- | --- | --- | --- |
| Sinusoidal | Absolute | Poor | \(O(1)\) | Original Transformer |
| Learned | Absolute | Poor | \(O(1)\) | BERT, GPT-2 |
| RoPE | Relative | Moderate (extensible) | \(O(1)\) | LLaMA, Qwen, Gemma |
| ALiBi | Relative bias | Good | \(O(1)\) | BLOOM, MPT |
| T5 bias | Relative | Moderate | \(O(1)\) | T5 |

6. KV Cache Optimization

6.1 Background

During autoregressive generation, each new token must interact with all historical KV pairs. KV cache grows linearly with sequence length, becoming a memory bottleneck.

Standard MHA KV cache size: \(2 \times n_{\text{layers}} \times n_{\text{heads}} \times n_{\text{seq}} \times d_{\text{head}}\)
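Plugging in concrete numbers makes the bottleneck tangible. The figures below assume an fp16 LLaMA-2-7B-style configuration (32 layers, 32 KV heads, head dim 128), chosen purely for illustration:

```python
def kv_cache_bytes(n_layers, n_kv_heads, seq_len, d_head, bytes_per_elem=2):
    """KV cache size: 2 (K and V) * layers * KV heads * seq * head dim * bytes."""
    return 2 * n_layers * n_kv_heads * seq_len * d_head * bytes_per_elem

# fp16 (2 bytes/element), 4K context, MHA (every head has its own KV).
size = kv_cache_bytes(n_layers=32, n_kv_heads=32, seq_len=4096, d_head=128)
print(size / 2**30, "GiB")   # 2.0 GiB per sequence at 4K context
```

At a 128K context the same model needs 64 GiB of KV cache per sequence, which is why the head-sharing and compression schemes below matter.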

6.2 Multi-Query Attention (MQA)

MQA (Shazeer, 2019): All query heads share one set of K and V.

\[ \text{MQA}: h \text{ Q heads}, 1 \text{ K head}, 1 \text{ V head} \]

KV cache reduction: \(h\times\)

Trade-off: Slight quality degradation

6.3 Grouped-Query Attention (GQA)

GQA (Ainslie et al., 2023): A compromise where query heads are grouped, with each group sharing one KV set.

\[ \text{GQA-}g: h \text{ Q heads}, g \text{ K heads}, g \text{ V heads} \]
  • \(g=1\): Degenerates to MQA
  • \(g=h\): Degenerates to standard MHA

Typical setting: \(h=64, g=8\) (LLaMA 2 70B)
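A shape-level sketch of GQA, where each group of \(h/g\) query heads reads the same KV head (toy sizes, random tensors):

```python
import numpy as np

rng = np.random.default_rng(0)
n, h, g, d_head = 16, 8, 2, 4     # toy sizes: 8 query heads, 2 KV groups

Q = rng.normal(size=(h, n, d_head))
K = rng.normal(size=(g, n, d_head))   # only g KV heads are stored/cached
V = rng.normal(size=(g, n, d_head))

# Query head i uses the KV head of its group: group index = i // (h // g).
K_full = np.repeat(K, h // g, axis=0)   # (h, n, d_head), views of g tensors
V_full = np.repeat(V, h // g, axis=0)

scores = Q @ K_full.transpose(0, 2, 1) / np.sqrt(d_head)
scores -= scores.max(axis=-1, keepdims=True)
w = np.exp(scores)
w /= w.sum(axis=-1, keepdims=True)
out = w @ V_full                         # (h, n, d_head)
```

Only the `g` KV heads need to be cached; the `repeat` happens at compute time, giving the \(h/g\) cache reduction with full query-head diversity.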

6.4 Multi-Head Latent Attention (MLA)

MLA (DeepSeek-V2): Compresses KV into a low-dimensional latent space.

\[ k_t = W^{UK} c_t^{KV}, \quad v_t = W^{UV} c_t^{KV} \]
\[ c_t^{KV} = W^{DKV} x_t \in \mathbb{R}^{d_c}, \quad d_c \ll d \]

Only \(c_t^{KV}\) is cached, dramatically reducing KV cache size.
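A shape-level sketch of the compression (toy sizes, random projections; DeepSeek-V2's decoupled RoPE handling and other details are omitted):

```python
import numpy as np

rng = np.random.default_rng(0)
d, d_c, n = 64, 8, 16              # d_c << d: latent dim much smaller

W_DKV = rng.normal(size=(d_c, d))  # down-projection, shared by K and V
W_UK = rng.normal(size=(d, d_c))   # up-projection to keys
W_UV = rng.normal(size=(d, d_c))   # up-projection to values

X = rng.normal(size=(n, d))        # token hidden states
C = X @ W_DKV.T                    # (n, d_c) — the only tensor cached
K = C @ W_UK.T                     # reconstructed on the fly at attention time
V = C @ W_UV.T

# Cache shrinks from 2 * n * d elements (separate K and V) to n * d_c.
assert C.shape == (16, 8) and K.shape == V.shape == (16, 64)
```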

6.5 Summary Comparison

| Method | KV Heads | Cache Size | Quality | Representative Models |
| --- | --- | --- | --- | --- |
| MHA | \(h\) | \(2nhd\) | Best | GPT-3 |
| MQA | 1 | \(2nd\) | Slightly worse | PaLM |
| GQA | \(g\) | \(2ngd\) | Near MHA | LLaMA 2, Gemma |
| MLA | — | \(2nd_c\) | Good | DeepSeek-V2 |

7. Other Attention Variants

7.1 Sliding Window Attention

Used by Mistral, limiting attention to a fixed window:

\[ \text{Attention}(q_i) = \text{softmax}\left(\frac{q_i k_{[i-w:i]}^\top}{\sqrt{d}}\right) v_{[i-w:i]} \]

7.2 Ring Attention

Distributes long sequences across multiple devices, each processing a portion of KV, communicating KV blocks in a ring pattern.

7.3 Differential Attention

\[ \text{DiffAttn}(X) = (\text{softmax}(Q_1 K_1^\top) - \lambda \cdot \text{softmax}(Q_2 K_2^\top)) V \]

Reduces attention noise on irrelevant tokens.


8. Practical Selection Guide

```mermaid
graph TD
    A[Choose Attention Strategy] --> B{Sequence Length?}
    B -->|<8K| C[Standard MHA + FlashAttention]
    B -->|8K-128K| D[GQA + FlashAttention + RoPE]
    B -->|>128K| E{Precision Requirements?}
    E -->|High| F[Ring Attention + GQA]
    E -->|Approximate OK| G[Sparse/Linear Attention]

    D --> H{Inference Optimization?}
    H -->|Yes| I[MQA/GQA + KV Cache Compression]
    H -->|No| J[Standard GQA]
```

9. Summary

| Optimization Direction | Method | Core Idea |
| --- | --- | --- |
| IO optimization | FlashAttention | Tiled computation, reduced HBM access |
| Low-rank approximation | Linformer, Performer | Compress K/V or kernel approximation |
| Sparse patterns | Longformer, BigBird | Local + global + random |
| Positional encoding | RoPE, ALiBi | Relative position, length extrapolation |
| KV cache | MQA, GQA, MLA | Reduce KV heads or compress |

Current Mainstream Configuration (2024-2025):

  • Positional encoding: RoPE (+ YaRN/NTK for length extension)
  • Attention: GQA + FlashAttention-2/3
  • Inference: PagedAttention + quantized KV cache

References

  • Dao et al., "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness," NeurIPS 2022
  • Wang et al., "Linformer: Self-Attention with Linear Complexity," 2020
  • Choromanski et al., "Rethinking Attention with Performers," ICLR 2021
  • Beltagy et al., "Longformer: The Long-Document Transformer," 2020
  • Su et al., "RoFormer: Enhanced Transformer with Rotary Position Embedding," 2021
  • Press et al., "Train Short, Test Long: Attention with Linear Biases," ICLR 2022
  • Shazeer, "Fast Transformer Decoding: One Write-Head Is All You Need," 2019
  • Ainslie et al., "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints," EMNLP 2023
  • DeepSeek-AI, "DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model," 2024
