KV Cache 与长上下文推理
KV Cache 是 LLM 推理加速的核心技术。随着模型处理上下文窗口从 2K 扩展到 128K 甚至百万 token,如何高效管理 KV Cache 以及支撑长上下文推理成为推理工程的关键挑战。
关于 KV Cache 的基本原理和 PagedAttention,参见 vLLM。
KV Cache 回顾
为什么需要 KV Cache
在 Transformer 的自回归生成中,每生成一个新 token 都需要对之前所有 token 做 Attention 计算。KV Cache 将已计算的 Key 和 Value 缓存起来,避免重复计算:
每次只需计算新 token 的 \(q_t, k_t, v_t\),然后将 \(k_t, v_t\) 追加到缓存中。
显存占用估算
对于一个有 \(L\) 层、\(H\) 个注意力头、每头维度 \(d\) 的模型,序列长度为 \(s\),KV Cache 的显存为:
| 模型 | 参数量 | 序列长度 | KV Cache (FP16) |
|---|---|---|---|
| LLaMA-2-7B | 7B | 4K | ~1 GB |
| LLaMA-2-13B | 13B | 4K | ~1.6 GB |
| LLaMA-3-70B | 70B | 8K | ~10 GB |
| LLaMA-3-70B | 70B | 128K | ~160 GB |
可以看到,随着上下文长度增长,KV Cache 的显存占用线性增长,在长上下文场景下甚至远超模型权重本身的显存。
长上下文的核心挑战
显存瓶颈
标准 Attention 的 KV Cache 与序列长度成线性关系,但多并发请求下显存总量爆炸。例如,70B 模型在 128K 上下文下,单个请求的 KV Cache 就需要 ~160 GB,这使得服务化部署极其困难。
注意力计算复杂度
标准 Self-Attention 的计算复杂度为 \(O(n^2)\),当序列长度达到 128K 时,Attention 计算本身也成为瓶颈。
位置编码外推
传统的绝对位置编码在超出训练长度时会失效,需要能够外推到更长序列的位置编码方案。
位置编码与长上下文支持
RoPE (Rotary Position Embedding)
RoPE 是当前主流 LLM(LLaMA、Qwen、Mistral 等)采用的位置编码方式。它将位置信息编码为旋转矩阵,使得注意力分数自然地依赖于 token 间的相对距离:
其中 \(\theta_j = 10000^{-2j/d}\) 是频率基底。
长上下文扩展方法:
- NTK-aware Scaling:调整频率基底 \(\theta' = \theta \cdot \alpha^{d/(d-2)}\),使低频分量被压缩以覆盖更长范围,同时高频分量保持不变
- YaRN (Yet another RoPE extensioN):结合 NTK Scaling 与注意力分布修正,在扩展上下文的同时保持注意力分布的形状
- Dynamic NTK:根据实际输入长度动态调整缩放因子,短序列不受影响
ALiBi (Attention with Linear Biases)
ALiBi 不使用位置编码,而是在 Attention Score 上加一个线性偏置惩罚:
其中 \(m\) 是每个 Attention Head 特有的斜率。距离越远,惩罚越大。这种设计天然支持长度外推,因为它不需要学习位置编码,模式在任意长度上都是一致的。
KV Cache 压缩技术
Multi-Query Attention (MQA) 与 Grouped-Query Attention (GQA)
传统 Multi-Head Attention 中,每个 Head 都有独立的 K、V。MQA 和 GQA 通过共享 K、V 来减少 KV Cache:
| 方式 | K/V Head 数 | KV Cache 压缩比 | 代表模型 |
|---|---|---|---|
| MHA (标准) | \(H\) | 1x | GPT-3 |
| GQA | \(G\) (\(1 < G < H\)) | \(H/G\) x | LLaMA-2-70B, LLaMA-3 |
| MQA | 1 | \(H\) x | PaLM, Falcon |
GQA 是目前的主流选择,在几乎不损失质量的情况下将 KV Cache 压缩数倍。例如 LLaMA-3 使用 8 组 GQA(\(H=32, G=8\)),KV Cache 缩小为原来的 1/4。
KV Cache 量化
将 KV Cache 从 FP16 量化为 INT8 或 INT4,可进一步压缩显存。研究表明 KV Cache 中 Key 的分布通常比 Value 更均匀,因此 Key 更适合低比特量化:
- KV Cache INT8 量化:显存减半,精度损失极小
- KV Cache INT4 量化:显存缩至 1/4,长上下文场景下部分任务有可感知的精度下降
vLLM、TensorRT-LLM 等主流框架已支持 FP8 KV Cache。
Token Eviction 与 Sparse Attention
不是所有 token 的 KV 都同等重要。部分方法通过识别"不重要"的 token 并驱逐其 KV 来压缩缓存:
- H2O (Heavy-Hitter Oracle):保留 Attention Score 累积最高的 token("重击者")和最近的若干 token,驱逐其余
- StreamingLLM:只保留前几个 "attention sink" token 和一个滑动窗口内的 token,支持理论上无限长的流式推理
- Sliding Window Attention:每个 token 只关注最近 \(W\) 个 token(Mistral 使用 \(W=4096\)),KV Cache 大小固定为 \(W\)
长上下文推理的工程实践
Prefill 与 Decode 分离
长上下文推理的一个关键优化是将 Prefill(处理 prompt)和 Decode(逐 token 生成)阶段分离:
- Prefill 阶段:计算密集,适合高并行。处理完整 prompt,一次性计算所有 token 的 KV Cache
- Decode 阶段:访存密集,每步只生成一个 token
分离调度(Disaggregated Serving)允许用不同硬件或不同批处理策略分别优化两个阶段。
Prefix Caching
当多个请求共享相同前缀(如系统提示词)时,可以缓存并复用前缀的 KV Cache:
请求A: [System Prompt] + [用户问题A]
请求B: [System Prompt] + [用户问题B]
↑
共享的 KV Cache,只需计算一次
vLLM 的 Automatic Prefix Caching 和 SGLang 的 RadixAttention 都实现了这一机制。
Chunked Prefill
对于超长 prompt,一次性 Prefill 可能导致显存溢出或长时间占用 GPU。Chunked Prefill 将 prompt 切分为较小的 chunk,分批计算:
- 避免单次 Prefill 的显存峰值
- 允许在 chunk 间穿插其他请求的 Decode 步骤,降低延迟
Flash Attention
Flash Attention 通过分块计算和减少 HBM 访问来加速 Attention,是长上下文推理的必备优化:
- Flash Attention 2:优化了并行策略和 warp 分配
- Flash Attention 3:针对 Hopper 架构(H100)优化,利用异步执行和 FP8 支持
Flash Attention 不改变 KV Cache 的大小,但大幅加速了 Attention 计算本身,使得更长的上下文在计算上变得可行。
主流框架的长上下文支持
| 框架 | Prefix Caching | KV Cache 量化 | Chunked Prefill | Flash Attention |
|---|---|---|---|---|
| vLLM | 自动前缀缓存 | FP8 | 支持 | FA2 |
| TensorRT-LLM | 支持 | FP8, INT8 | 支持 | FA2/FA3 |
| SGLang | RadixAttention | FP8 | 支持 | FA2 |
| DeepSpeed-MII | 支持 | 支持 | 支持 | FA2 |
参考
- Kwon et al., "Efficient Memory Management for Large Language Model Serving with PagedAttention", SOSP 2023
- Ainslie et al., "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints", EMNLP 2023
- Xiao et al., "Efficient Streaming Language Models with Attention Sinks", ICLR 2024
- Dao et al., "FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning", ICLR 2024
- Su et al., "RoFormer: Enhanced Transformer with Rotary Position Embedding", Neurocomputing 2024