Transformer Variants and Efficient Attention
Overview
The standard Transformer's self-attention mechanism has \(O(n^2)\) complexity, limiting its ability to handle long sequences. This chapter systematically covers efficient attention variants including IO-aware optimization (FlashAttention), linear attention, sparse attention, and the evolution of positional encoding.
```mermaid
graph TD
    A[Attention Variants] --> B[IO-Aware Optimization]
    A --> C[Linear/Low-Rank Approximation]
    A --> D[Sparse Attention]
    A --> E[Positional Encoding]
    A --> F[KV Cache Optimization]
    B --> B1[FlashAttention]
    B --> B2[FlashAttention-2/3]
    C --> C1[Linformer]
    C --> C2[Performer]
    C --> C3[Linear Attention]
    D --> D1[Longformer]
    D --> D2[BigBird]
    D --> D3[Sparse Transformer]
    E --> E1[RoPE]
    E --> E2[ALiBi]
    E --> E3[Sinusoidal]
    F --> F1[MQA]
    F --> F2[GQA]
    F --> F3[MLA]
```
1. Standard Attention Review
1.1 Scaled Dot-Product Attention
- Time complexity: \(O(n^2 d)\)
- Space complexity: \(O(n^2 + nd)\)
where \(n\) is the sequence length and \(d\) is the feature dimension.
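For reference, the operation these complexities describe is the standard scaled dot-product attention:
\[
\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d}}\right)V, \qquad Q, K, V \in \mathbb{R}^{n \times d}
\]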
1.2 Multi-Head Attention (MHA)
Per-head dimension: \(d_k = d_v = d_{\text{model}} / h\)
Parameter count: \(4 \times d_{\text{model}}^2\) (Q, K, V, O projection matrices, ignoring biases)
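As a quick illustration (hypothetical numbers, not tied to any model discussed here): with \(d_{\text{model}} = 4096\) and \(h = 32\), each head has \(d_k = d_v = 128\), and the four projection matrices contribute \(4 \times 4096^2 \approx 67\text{M}\) parameters per layer.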
2. FlashAttention: IO-Aware Exact Attention
2.1 Motivation
The bottleneck of standard attention is not computation but memory IO:
- GPUs have massive compute but limited HBM bandwidth
- The \(n^2\) attention matrix requires HBM reads/writes
- Wall-clock time is dominated by HBM access
2.2 FlashAttention v1
Core Idea: Tiling (block-wise computation in SRAM) + Recomputation (recompute the attention matrix during the backward pass instead of storing it)
Tiling Algorithm:
- Partition \(Q, K, V\) into blocks fitting in SRAM
- Compute local attention in SRAM
- Use online softmax trick to incrementally accumulate results
- Never write the \(n \times n\) attention matrix to HBM
Online Softmax: maintain a running row maximum \(m\) and normalizer \(\ell\) per query; when each new block of scores arrives, update \(m\), rescale the previously accumulated output and \(\ell\) by \(e^{m_{\text{old}} - m_{\text{new}}}\), and add the new block's contribution. This yields the exact softmax in a single streaming pass without ever materializing a full row of scores (see the sketch below).
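A minimal NumPy sketch of this accumulation for a single query (illustrative only; the real kernel operates on SRAM tiles of \(Q\), \(K\), \(V\) and fuses these steps):

```python
import numpy as np

def online_softmax_attention(q, K_blocks, V_blocks):
    """Streaming attention for one query q, processing K/V one block at a time."""
    m = -np.inf                             # running row maximum
    l = 0.0                                 # running softmax normalizer
    acc = np.zeros(V_blocks[0].shape[-1])   # unnormalized output accumulator
    for K_b, V_b in zip(K_blocks, V_blocks):
        s = K_b @ q / np.sqrt(q.shape[-1])  # scores for this block
        m_new = max(m, s.max())
        scale = np.exp(m - m_new)           # rescale previously accumulated results
        p = np.exp(s - m_new)               # this block's unnormalized probabilities
        l = l * scale + p.sum()
        acc = acc * scale + p @ V_b
        m = m_new
    return acc / l                          # equals exact softmax(qK^T/sqrt(d)) V
```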
Complexity Improvement:
| Metric | Standard Attention | FlashAttention |
|---|---|---|
| HBM reads/writes | \(O(nd + n^2)\) | \(O(n^2 d^2 / M)\) |
| FLOPS | \(O(n^2 d)\) | \(O(n^2 d)\) (unchanged) |
| Extra memory | \(O(n^2)\) | \(O(n)\) |
where \(M\) is the SRAM size.
2.3 FlashAttention v2 & v3
FlashAttention-2:
- Better work partitioning across warps and thread blocks, with fewer non-matmul FLOPs
- Parallelization along sequence length dimension
- Reaches 50-73% of theoretical peak FLOPs/s on A100
FlashAttention-3:
- Optimized for Hopper architecture (H100)
- Leverages asynchronous copies and WGMMA instructions
- FP8 support
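In practice you rarely call the FlashAttention kernels directly. For example, PyTorch's fused attention entry point dispatches to a FlashAttention backend when hardware, dtype, and shapes allow it (a sketch, not tied to a specific PyTorch version):

```python
import torch
import torch.nn.functional as F

# (batch, heads, seq_len, head_dim); half precision on GPU favors the fused kernels
q = torch.randn(2, 8, 4096, 64, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# FlashAttention / memory-efficient backends are selected automatically when eligible
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```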
3. Linear Attention and Low-Rank Approximations
3.1 Linformer
Idea: The attention matrix is approximately low-rank, so \(K\) and \(V\) can be compressed with low-dimensional projections along the sequence axis.
Project \(K\) and \(V\) from \(n \times d\) to \(k \times d\) (\(k \ll n\)):
\[
\text{Attention}(Q, K, V) \approx \text{softmax}\!\left(\frac{Q\,(E_K K)^\top}{\sqrt{d}}\right)(E_V V)
\]
where \(E_K, E_V \in \mathbb{R}^{k \times n}\) are learned projection matrices.
Complexity: \(O(nk)\) — linear in \(n\)
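A minimal single-head sketch of the projection trick, with hypothetical learned matrices `E_k`, `E_v` standing in for \(E_K, E_V\):

```python
import torch

def linformer_attention(q, k, v, E_k, E_v):
    # q, k, v: (batch, n, d);  E_k, E_v: (k_proj, n) learned projections along the sequence axis
    k_proj = E_k @ k                                              # (batch, k_proj, d)
    v_proj = E_v @ v                                              # (batch, k_proj, d)
    scores = q @ k_proj.transpose(-2, -1) / q.shape[-1] ** 0.5    # (batch, n, k_proj), not n x n
    return scores.softmax(dim=-1) @ v_proj                        # (batch, n, d)
```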
3.2 Performer
Idea: Approximate the softmax kernel with random features so the explicit \(n \times n\) attention matrix is never formed:
\[
\exp(q^\top k) \approx \phi(q)^\top \phi(k)
\]
where \(\phi\) is a random feature map (e.g., FAVOR+ with orthogonal random features).
Key Formula: by associativity of matrix multiplication,
\[
\text{Attention}(Q, K, V) \approx D^{-1}\, \phi(Q)\left(\phi(K)^\top V\right), \qquad D = \mathrm{diag}\!\left(\phi(Q)\,\phi(K)^\top \mathbf{1}_n\right)
\]
Compute \(\phi(K)^\top V\) (an \(r \times d\) matrix, where \(r\) is the number of random features) first, then multiply by \(\phi(Q)\).
Complexity: \(O(n d^2)\) — linear in \(n\)
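A generic kernelized (linear) attention sketch; `phi` stands in for whatever feature map is chosen (FAVOR+ random features for Performer, or a simple positive map such as `elu(x) + 1` for plain linear attention). Non-causal case shown for brevity:

```python
import torch

def kernelized_attention(q, k, v, phi):
    # q, k: (batch, n, d);  v: (batch, n, d_v);  phi maps features to non-negative values
    q_f, k_f = phi(q), phi(k)                                       # (batch, n, r)
    kv = k_f.transpose(-2, -1) @ v                                  # (batch, r, d_v), computed once
    z = q_f @ k_f.sum(dim=-2, keepdim=True).transpose(-2, -1)       # (batch, n, 1) normalizer
    return (q_f @ kv) / (z + 1e-6)                                  # never forms an n x n matrix

# Example feature map for linear attention (not the Performer FAVOR+ map):
phi = lambda x: torch.nn.functional.elu(x) + 1
```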
3.3 Method Comparison
| Method | Complexity | Exact? | Requires Retraining? |
|---|---|---|---|
| Standard Attention | \(O(n^2 d)\) | Exact | - |
| FlashAttention | \(O(n^2 d)\) | Exact | No (drop-in replacement) |
| Linformer | \(O(nkd)\) | Approximate | Yes |
| Performer | \(O(nd^2)\) | Approximate | Yes |
4. Sparse Attention
4.1 Longformer
Hybrid Attention Patterns:
- Sliding window attention: Each token attends to \(w\) local tokens
- Dilated sliding window: Like dilated convolutions, increasing receptive field
- Global attention: Special tokens (e.g., [CLS]) attend to all tokens
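A sketch of the kind of boolean mask these patterns induce (a dense mask is built here only for clarity; real implementations use custom kernels and never materialize an \(n \times n\) mask):

```python
import torch

def longformer_style_mask(n, window, global_idx):
    # True = attention allowed: sliding window of width `window`, plus symmetric global attention
    i = torch.arange(n)
    mask = (i[:, None] - i[None, :]).abs() <= window // 2   # local band
    mask[global_idx, :] = True                               # global tokens attend everywhere
    mask[:, global_idx] = True                               # every token attends to global tokens
    return mask

mask = longformer_style_mask(n=16, window=4, global_idx=[0])  # e.g., [CLS] at position 0
```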
4.2 BigBird
Combines three sparse patterns:
- Local attention (sliding window)
- Global attention (selected global tokens)
- Random attention (random connections)
Theoretical Guarantee: BigBird is Turing complete (random + global attention ensures this).
4.3 Sparse Transformer
Fixed Sparse Patterns:
- Strided pattern: Row attention + column attention
- Complexity: \(O(n\sqrt{n})\)
5. Positional Encoding
5.1 Why Positional Encoding Is Necessary
Transformer self-attention is permutation invariant, so positional information must be injected through positional encoding.
5.2 Rotary Position Embedding (RoPE)
RoPE (Su et al., 2021) is currently the most popular positional encoding in LLMs.
Core Idea: Encode relative positions through rotation matrices.
For a query at position \(m\) and a key at position \(n\), position-dependent rotations are applied so that their inner product depends only on the relative offset \(m - n\):
\[
\langle f_q(x_m, m),\, f_k(x_n, n)\rangle = g(x_m, x_n, m - n)
\]
2D Rotation (every two dimensions form a group): each pair \((x_{2i}, x_{2i+1})\) at position \(m\) is rotated by the angle \(m\theta_i\),
\[
\begin{pmatrix} x'_{2i} \\ x'_{2i+1} \end{pmatrix}
=
\begin{pmatrix} \cos m\theta_i & -\sin m\theta_i \\ \sin m\theta_i & \cos m\theta_i \end{pmatrix}
\begin{pmatrix} x_{2i} \\ x_{2i+1} \end{pmatrix}
\]
where \(\theta_i = 10000^{-2i/d}\).
Advantages:
- Naturally encodes relative positions
- Reasonable extrapolation (with NTK-aware scaling techniques)
- Computationally efficient (element-wise operations)
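A compact sketch of RoPE in the interleaved-pair convention, applied identically to queries and keys before the dot product:

```python
import torch

def rope_angles(seq_len, dim, base=10000.0):
    # theta_i = base^(-2i/d); one angle per (even, odd) dimension pair and position
    inv_freq = base ** (-torch.arange(0, dim, 2).float() / dim)
    return torch.outer(torch.arange(seq_len).float(), inv_freq)     # (seq_len, dim/2)

def apply_rope(x, angles):
    # x: (..., seq_len, dim); rotate each (even, odd) pair by its position-dependent angle
    x1, x2 = x[..., 0::2], x[..., 1::2]
    cos, sin = angles.cos(), angles.sin()
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out
```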
5.3 ALiBi (Attention with Linear Biases)
ALiBi (Press et al., 2022) adds a distance-proportional linear bias to the attention scores instead of using positional encoding:
\[
\text{score}(q_i, k_j) = q_i^\top k_j - m\,(i - j), \qquad j \le i
\]
where \(m\) is a per-head slope (a geometric sequence, e.g., \(\tfrac{1}{2}, \tfrac{1}{4}, \dots, \tfrac{1}{256}\) for 8 heads).
Advantages:
- No additional model parameters
- Natural length extrapolation
- Minimal computational overhead
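A sketch of the per-head slopes and the additive bias matrix (for power-of-two head counts; the paper uses an interpolation rule otherwise):

```python
import torch

def alibi_bias(n_heads, seq_len):
    # Geometric slopes 2^(-8/h), 2^(-16/h), ... for h heads
    slopes = torch.tensor([2.0 ** (-8.0 * (i + 1) / n_heads) for i in range(n_heads)])
    pos = torch.arange(seq_len)
    dist = (pos[None, :] - pos[:, None]).clamp(max=0).float()   # j - i, zeroed for future positions
    return slopes[:, None, None] * dist[None, :, :]             # (heads, seq, seq), added to scores
```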
5.4 Positional Encoding Comparison
| Method | Type | Extrapolation | Overhead | Used By |
|---|---|---|---|---|
| Sinusoidal | Absolute | Poor | \(O(1)\) | Original Transformer |
| Learned | Absolute | Poor | \(O(1)\) | BERT, GPT-2 |
| RoPE | Relative | Moderate (extensible) | \(O(1)\) | LLaMA, Qwen, Gemma |
| ALiBi | Relative bias | Good | \(O(1)\) | BLOOM, MPT |
| T5 bias | Relative | Moderate | \(O(1)\) | T5 |
6. KV Cache Optimization
6.1 Background
During autoregressive generation, each new token must interact with all historical KV pairs. KV cache grows linearly with sequence length, becoming a memory bottleneck.
Standard MHA KV cache size: \(2 \times n_{\text{layers}} \times n_{\text{heads}} \times n_{\text{seq}} \times d_{\text{head}}\)
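For illustration, assuming a LLaMA-7B-like configuration (32 layers, 32 heads, \(d_{\text{head}} = 128\)) at a sequence length of 4096 in FP16, the cache for a single sequence is \(2 \times 32 \times 32 \times 4096 \times 128 \times 2\ \text{bytes} \approx 2\ \text{GiB}\), before any batching.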
6.2 Multi-Query Attention (MQA)
MQA (Shazeer, 2019): All query heads share one set of K and V.
KV cache reduction: \(h\times\)
Trade-off: Slight quality degradation
6.3 Grouped-Query Attention (GQA)
GQA (Ainslie et al., 2023): A compromise where query heads are grouped, with each group sharing one KV set.
- \(g=1\): Degenerates to MQA
- \(g=h\): Degenerates to standard MHA
Typical setting: \(h=64, g=8\) (LLaMA 2 70B uses 64 query heads and 8 KV heads)
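A sketch of how the shared KV heads are used at attention time (only the \(g\) KV heads are cached; they are expanded to match the query heads just before the score computation):

```python
import torch

def expand_kv_heads(k, v, n_query_heads):
    # k, v: (batch, n_kv_heads, seq_len, d_head); each KV head serves n_query_heads // n_kv_heads query heads
    rep = n_query_heads // k.shape[1]
    return k.repeat_interleave(rep, dim=1), v.repeat_interleave(rep, dim=1)
```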
6.4 Multi-Head Latent Attention (MLA)
MLA (DeepSeek-V2): Compresses KV into a low-dimensional latent space.
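In the DeepSeek-V2 paper's notation (a simplified view that omits the decoupled RoPE key path), the hidden state \(h_t\) is down-projected to a latent vector, and keys and values are reconstructed from it by up-projections:
\[
c_t^{KV} = W^{DKV} h_t, \qquad k_t^{C} = W^{UK} c_t^{KV}, \qquad v_t^{C} = W^{UV} c_t^{KV}
\]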
Only \(c_t^{KV}\) is cached, dramatically reducing KV cache size.
6.5 Summary Comparison
| Method | KV Heads | Cache Size | Quality | Representative Models |
|---|---|---|---|---|
| MHA | \(h\) | \(2nhd\) | Best | GPT-3 |
| MQA | 1 | \(2nd\) | Slightly worse | PaLM |
| GQA | \(g\) | \(2ngd\) | Near MHA | LLaMA 2, Gemma |
| MLA | - | \(\approx n d_c\) | Good | DeepSeek-V2 |
7. Other Attention Variants
7.1 Sliding Window Attention
Used by Mistral, which restricts each token to attending only to the previous \(w\) positions (\(w = 4096\) in Mistral 7B); information from more distant tokens still reaches later layers indirectly, since the receptive field grows with depth.
7.2 Ring Attention
Distributes long sequences across multiple devices, each processing a portion of KV, communicating KV blocks in a ring pattern.
7.3 Differential Attention
Computes attention as the difference between two softmax attention maps, which cancels common-mode noise and reduces the attention assigned to irrelevant tokens.
8. Practical Selection Guide
```mermaid
graph TD
    A[Choose Attention Strategy] --> B{Sequence Length?}
    B -->|<8K| C[Standard MHA + FlashAttention]
    B -->|8K-128K| D[GQA + FlashAttention + RoPE]
    B -->|>128K| E{Precision Requirements?}
    E -->|High| F[Ring Attention + GQA]
    E -->|Approximate OK| G[Sparse/Linear Attention]
    D --> H{Inference Optimization?}
    H -->|Yes| I[MQA/GQA + KV Cache Compression]
    H -->|No| J[Standard GQA]
```
9. Summary
| Optimization Direction | Method | Core Idea |
|---|---|---|
| IO Optimization | FlashAttention | Tiled computation, reduced HBM access |
| Low-Rank Approximation | Linformer, Performer | Compress K/V or kernel approximation |
| Sparse Patterns | Longformer, BigBird | Local + global + random |
| Positional Encoding | RoPE, ALiBi | Relative position, length extrapolation |
| KV Cache | MQA, GQA, MLA | Reduce KV heads or compress |
Current Mainstream Configuration (2024-2025):
- Positional encoding: RoPE (+ YaRN/NTK for length extension)
- Attention: GQA + FlashAttention-2/3
- Inference: PagedAttention + quantized KV cache
References
- Dao et al., "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness," NeurIPS 2022
- Wang et al., "Linformer: Self-Attention with Linear Complexity," 2020
- Choromanski et al., "Rethinking Attention with Performers," ICLR 2021
- Beltagy et al., "Longformer: The Long-Document Transformer," 2020
- Su et al., "RoFormer: Enhanced Transformer with Rotary Position Embedding," 2021
- Press et al., "Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation," ICLR 2022
- Ainslie et al., "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints," EMNLP 2023