DiT (Diffusion Transformer)
论文: Scalable Diffusion Models with Transformers (Peebles & Xie, 2023)
DiT 提出用 Transformer 替代 U-Net 作为 Diffusion 模型的去噪骨干网络,证明了 Transformer 在图像生成任务上同样具有优秀的 scaling 特性。这一工作深刻影响了后续的 Sora、SD3、FLUX 等模型的架构选择。
1. 背景与动机
1.1 Diffusion 模型的成功
自 DDPM (Ho et al., 2020) 以来,扩散模型在图像生成领域取得了巨大成功。其核心思想是:
- 前向过程:逐步向图像添加高斯噪声,直到变成纯噪声
- 反向过程:训练一个神经网络学习逐步去噪,从纯噪声还原出图像
数学上,前向过程定义为:
而反向过程则由神经网络 \(\epsilon_\theta(x_t, t)\) 参数化,预测加入的噪声。
Stable Diffusion 进一步将扩散过程放到 潜空间 (latent space) 中进行,即先用 VAE 将图像编码为低维 latent,再在 latent 上做扩散,大幅降低计算成本。
1.2 U-Net 的局限性
在 DiT 之前,几乎所有主流 Diffusion 模型(DDPM、Stable Diffusion、DALL-E 2、Imagen 等)都使用 U-Net 作为去噪网络。U-Net 存在以下局限:
- 架构特异性强:U-Net 是专门为像素级预测设计的架构(最早用于医学图像分割),其 encoder-decoder + skip connection 的结构并非通用设计
- 难以 scale:U-Net 的 scaling 行为不够清晰,不像 Transformer 那样有明确的"加大模型 = 提升性能"的 scaling law
- 缺乏统一性:NLP 已经统一到 Transformer 架构,而视觉/生成领域仍在使用各种特定架构
1.3 Transformer 的 scaling 优势
ViT (Dosovitskiy et al., 2020) 已经证明:纯 Transformer 架构在视觉任务上可以匹敌甚至超越 CNN,前提是有足够的数据和计算。Transformer 的核心优势在于:
- 清晰的 scaling behavior:模型参数量、数据量、计算量的增加能可预测地带来性能提升
- 架构通用性:同一架构可以处理文本、图像、音频、视频等多种模态
- 成熟的工程生态:大量针对 Transformer 的优化技术(FlashAttention、模型并行等)
1.4 DiT 的核心问题
DiT 论文提出一个直接的问题:
核心问题
能否用 Transformer 替代 U-Net,作为 Diffusion 模型的去噪骨干网络 (backbone)?如果可以,其 scaling 表现如何?
答案是肯定的。DiT 证明了 Transformer-based 的 Diffusion 模型不仅可行,而且具有极好的 scaling 特性 -- 更大的模型 + 更多的计算 = 更好的生成质量 (更低的 FID)。
2. DiT 架构详解
2.1 整体架构
DiT 的整体流程可以用以下架构图概括:
DiT 整体架构
================================================================
输入 输出
---- ----
Noisy Latent z_t 条件信息 Predicted
(32 x 32 x 4) (timestep t, class y) Noise / x_0
| | ^
v v |
+---------+ +-----------+ +------------+
| Patchify| | Embedding | | Unpatchify |
| + Pos | | (t_emb + | | (Linear + |
| Embed | | y_emb) | | Reshape) |
+---------+ +-----------+ +------------+
| | ^
v v |
+----------+-----------+ |
| |
v |
+--------------------+ |
| DiT Block #1 | |
+--------------------+ |
| |
v |
+--------------------+ |
| DiT Block #2 | |
... |
+--------------------+ |
| |
v |
+--------------------+ |
| DiT Block #N |-------------------------+
+--------------------+
================================================================
2.2 Patchify:将 Latent 切成 Patches
与 ViT 处理图像的方式一致,DiT 首先将输入的 noisy latent \(z_t\) 切成互不重叠的 patches:
- 输入 latent 尺寸:\(H \times W \times C\)(在 ImageNet 256x256 实验中为 \(32 \times 32 \times 4\))
- Patch 大小 \(p\):论文实验了 \(p = 2, 4, 8\)
- Patch 数量:\(T = \frac{H \times W}{p^2}\)
- 每个 patch 通过线性投影映射为 \(d\) 维向量
例如,当 \(p = 2\) 时,\(32 \times 32\) 的 latent 被切成 \(16 \times 16 = 256\) 个 patches。
Patch 大小的影响
更小的 patch 意味着更多的 token,计算量更大但信息保留更完整。论文发现 \(p = 2\) 效果最好,因为它保留了最多的空间信息。模型命名规则为 DiT-{Size}/{Patch},如 DiT-XL/2 表示 XL 尺寸模型 + patch 大小为 2。
Patchify 之后,加上标准的位置编码 (positional embedding),使模型能够感知空间位置信息。
2.3 条件嵌入
DiT 需要注入两种条件信息:
时间步嵌入 (Timestep Embedding):
- 使用与 DDPM 相同的 sinusoidal 位置编码将标量 \(t\) 映射为向量
- 再通过 MLP 投影到隐藏维度:\(t_{\text{emb}} = \text{MLP}(\text{sinusoidal}(t))\)
类别嵌入 (Class Label Embedding):
- 使用可学习的 embedding table 将类别标签 \(y\) 映射为向量
- \(y_{\text{emb}} = \text{Embedding}(y)\)
最终条件向量:\(c = t_{\text{emb}} + y_{\text{emb}}\)
2.4 DiT Block:Transformer Block + 条件注入
DiT Block 是标准 Transformer Block 加上条件信息注入机制。论文系统地对比了四种条件注入方式,这是论文的核心实验之一。
方式一:In-Context Conditioning
将 \(t_{\text{emb}}\) 和 \(y_{\text{emb}}\) 视为两个额外的 token,与 patch tokens 拼接后一起输入标准 Transformer block:
- 优点:实现简单,无需修改 Transformer 架构
- 缺点:条件信息与 patch tokens 的交互依赖 self-attention 自行学习,效率较低
方式二:Cross-Attention
在每个 Transformer block 中加入一个 cross-attention 层,patch tokens 作为 Query,条件信息作为 Key/Value:
- 优点:这是 Stable Diffusion U-Net 中注入文本条件的方式,已被验证有效
- 缺点:增加额外参数和计算量
方式三:Adaptive LayerNorm (adaLN)
用条件向量 \(c\) 回归 LayerNorm 的 scale (\(\gamma\)) 和 shift (\(\beta\)) 参数:
- 优点:参数高效,不增加额外 attention 层
- 缺点:条件信息只通过归一化层注入,表达力可能受限
方式四:adaLN-Zero(最优方案)
在 adaLN 基础上,额外引入一个 scale 参数 \(\alpha\),并将所有 \(\gamma, \beta, \alpha\) 初始化为使每个 DiT block 在初始时等价于恒等函数。这是 DiT 论文最终采用的方案。
核心结论
论文实验结果表明,四种条件注入方式的性能排序为:adaLN-Zero > adaLN > Cross-Attention > In-Context。adaLN-Zero 在所有模型尺寸上都取得了最优 FID。
2.5 Unpatchify:从 Tokens 还原 Latent
DiT 的最后一步是将 Transformer 输出的 token 序列还原为与输入相同尺寸的 latent:
- 最终 LayerNorm(由 adaLN-Zero 调制)
- 线性投影:将每个 token 从 \(d\) 维映射到 \(p \times p \times 2C\) 维(预测噪声和对角协方差)
- Reshape:将 token 序列重新排列为空间结构 \(H \times W \times 2C\)
3. adaLN-Zero 详解
adaLN-Zero 是 DiT 的关键创新,值得深入理解。
3.1 从标准 LayerNorm 到 adaLN-Zero
标准 LayerNorm:
其中 \(\gamma\) 和 \(\beta\) 是可学习参数,与输入无关。
Adaptive LayerNorm (adaLN):
\(\gamma\) 和 \(\beta\) 由条件向量 \(c\) 动态生成,实现了条件依赖的归一化。
adaLN-Zero:
在 adaLN 基础上,每个残差连接前额外引入一个 scale 参数 \(\alpha\):
其中 \(\alpha\) 同样由条件 \(c\) 回归得到。
3.2 零初始化的关键
adaLN-Zero 的"Zero"体现在初始化策略上:
- MLP 回归 \(\alpha\) 的最后一层权重初始化为 全零
- 这意味着在训练初始时,\(\alpha = 0\)
- 因此每个 DiT block 的输出为:\(y = x + 0 \cdot \text{Block}(\cdot) = x\)
adaLN-Zero 初始化时的等效行为:
输入 x ──────────────────────────────── 输出 x
| ^
v | (alpha = 0, 被屏蔽)
+----------+ +----------+ |
| adaLN |--->| Attention |--x--+
| | | or FFN | ^
+----------+ +----------+ |
^ alpha = 0
|
条件 c ──> MLP ──> (gamma, beta, alpha)
全部初始化为合理值
alpha 初始化为 0
为什么零初始化有效?
这一策略借鉴了 ResNet 中的零初始化残差思想。核心好处是:
- 训练稳定性:模型初始时等价于恒等映射,梯度可以无损地流过整个网络,避免深层网络的训练不稳定问题
- 渐进式学习:每个 block 从"什么都不做"开始,逐渐学习有意义的变换,训练过程更加平滑
- 可以堆叠更深:由于初始时不会破坏信号,即使堆叠很多 block 也不会导致训练崩溃
3.3 adaLN-Zero 的完整计算流程
一个 DiT Block 内部的完整计算如下:
DiT Block (adaLN-Zero) 内部结构
================================================================
条件向量 c ──> MLP ──> (gamma_1, beta_1, alpha_1,
gamma_2, beta_2, alpha_2)
输入 x
|
v
adaLN(x, gamma_1, beta_1)
|
v
Multi-Head Self-Attention
|
v
x alpha_1 (逐元素乘以 scale)
|
v
+ x (残差连接) ──> x'
|
v
adaLN(x', gamma_2, beta_2)
|
v
Pointwise FeedForward (MLP)
|
v
x alpha_2 (逐元素乘以 scale)
|
v
+ x' (残差连接) ──> 输出
================================================================
每个 block 从条件 \(c\) 回归出 6 个向量:attention 分支的 \((\gamma_1, \beta_1, \alpha_1)\) 和 FFN 分支的 \((\gamma_2, \beta_2, \alpha_2)\)。
4. Latent DiT 工作流
DiT 与 Latent Diffusion Model (LDM) 的结合方式非常直接 -- 将 U-Net 替换为 DiT,其余保持不变。
4.1 训练流程
训练阶段
================================================================
真实图像 x_0 时间步 t 类别标签 y
| | |
v | |
+----------+ | |
| VAE | v v
| Encoder | sinusoidal + Embedding
+----------+ MLP = t_emb Table = y_emb
| | |
v +-------+--------+
Latent z_0 |
| c = t_emb + y_emb
v |
加噪声 (q(z_t|z_0)) |
z_t = sqrt(a_t)*z_0 |
+ sqrt(1-a_t)*eps |
| |
v v
+-------------------------------+
| DiT(z_t, c) |
| 预测噪声 eps_theta 或 x_0 |
+-------------------------------+
|
v
Loss = ||eps - eps_theta||^2
================================================================
4.2 推理流程
推理阶段
================================================================
随机噪声 z_T ~ N(0, I) 条件信息 (t, y)
| |
v v
+-------------------------------+
| DiT(z_t, c) | x N 步迭代
| 预测噪声 eps_theta |<---+
+-------------------------------+ |
| |
v |
去噪一步: z_{t-1} ----------------+
|
v (当 t=0)
干净 latent z_0
|
v
+----------+
| VAE |
| Decoder |
+----------+
|
v
生成图像 x_0
================================================================
4.3 与 Stable Diffusion 的对比
架构对比
DiT 与 Stable Diffusion 的区别仅在去噪网络部分:
| 组件 | Stable Diffusion | Latent DiT |
|---|---|---|
| 图像编码器 | VAE Encoder | VAE Encoder(相同) |
| 潜空间 | 32x32x4 latent | 32x32x4 latent(相同) |
| 去噪网络 | U-Net | DiT (Transformer) |
| 噪声调度 | DDPM / DDIM | DDPM / DDIM(相同) |
| 图像解码器 | VAE Decoder | VAE Decoder(相同) |
可以说,DiT 是 Stable Diffusion 的"drop-in replacement" -- 换了一个去噪网络,其他一切不变。
5. Scaling 实验结果
DiT 论文的一大贡献是系统地研究了 Transformer-based Diffusion 模型的 scaling 行为。
5.1 模型配置
论文设计了四个不同大小的模型,命名规则为 DiT-{Size}/{Patch}:
| 模型 | 隐藏维度 \(d\) | 层数 \(N\) | Attention Heads | 参数量 (p=2) | Gflops (p=2) |
|---|---|---|---|---|---|
| DiT-S | 384 | 12 | 6 | 33M | 6.06 |
| DiT-B | 768 | 12 | 12 | 130M | 23.0 |
| DiT-L | 1024 | 24 | 16 | 458M | 80.7 |
| DiT-XL | 1152 | 28 | 16 | 675M | 119 |
5.2 ImageNet 256x256 生成结果
在 class-conditional ImageNet 256x256 生成任务上(使用 classifier-free guidance):
| 模型 | FID-50K | IS |
|---|---|---|
| DiT-S/2 | 68.4 | 27.1 |
| DiT-B/2 | 43.5 | 44.1 |
| DiT-L/2 | 23.3 | 76.3 |
| DiT-XL/2 | 9.62 | 121.5 |
| DiT-XL/2 (长时间训练) | 2.27 | 278.2 |
关键发现
- 模型越大,FID 越低:从 DiT-S 到 DiT-XL,FID 单调下降,且下降幅度显著
- Patch 越小,效果越好:\(p = 2\) 优于 \(p = 4\) 优于 \(p = 8\),因为更小的 patch 保留更多空间信息
- Scaling 曲线平滑:FID 随计算量 (Gflops) 的增加呈现非常平滑的下降趋势
- adaLN-Zero 在所有尺寸上最优:条件注入方式的优劣排序在不同模型大小上保持一致
5.3 Scaling 曲线
FID-50K vs. 计算量 (Gflops) [示意]
================================================================
FID
^
70 | * DiT-S/2
|
60 |
|
50 |
| * DiT-B/2
40 |
|
30 |
| * DiT-L/2
20 |
|
10 | * DiT-XL/2
|
2 | * DiT-XL/2 (长训练)
+-----+-----+------+------+------> Gflops
6 23 81 119
================================================================
这条曲线是 DiT 论文最核心的结果之一:它表明 Diffusion Transformer 具有与 LLM 类似的 scaling law -- 可预测地通过增加计算获得性能提升。
5.4 与其他模型的对比
DiT-XL/2 在充分训练后,FID 达到 2.27,与当时最好的 Diffusion 模型(如 ADM、LDM)相当或更优:
| 模型 | FID-50K | 骨干网络 |
|---|---|---|
| ADM (Dhariwal & Nichol) | 10.94 | U-Net |
| ADM + Classifier Guidance | 4.59 | U-Net |
| LDM-4 (Rombach et al.) | 10.56 | U-Net |
| DiT-XL/2 | 2.27 | Transformer |
6. DiT 的后续影响
DiT 的出现标志着 Diffusion 模型从"U-Net 时代"向"Transformer 时代"的转变。其后续影响深远。
6.1 Sora (OpenAI, 2024)
OpenAI 的 Sora 视频生成模型据传基于 DiT 架构的扩展,将 DiT 从 2D 图像扩展到 3D 时空 (spacetime) patches。Sora 的技术报告中明确引用了 DiT,并描述了将视频视为"spacetime patches"序列的思路。
6.2 SD3 / FLUX: MM-DiT
Stable Diffusion 3 和 FLUX 系列采用了 MM-DiT (Multi-Modal DiT) 架构:
- 双流 (dual-stream) 设计:图像 tokens 和文本 tokens 各有独立的 Transformer 流
- 在 attention 层进行交互:两个流的 tokens 拼接后做 joint attention
- 比原始 DiT 的 cross-attention 更加灵活
6.3 PixArt-alpha
PixArt-alpha 关注 DiT 的高效训练:
- 使用预训练的 T5 文本编码器提供文本条件
- 分阶段训练策略,降低训练成本
- 证明 DiT 架构可以在较少计算资源下也能训练出高质量模型
6.4 SiT (Scalable Interpolant Transformers)
SiT 保留 DiT 的 Transformer 架构,但将底层的 DDPM 框架替换为 stochastic interpolant / flow matching 框架,获得了进一步的性能提升。
6.5 其他
- Hunyuan-DiT (腾讯):中文理解能力强的文生图 DiT
- Open-Sora:开源的基于 DiT 的视频生成方案
- Latte:将 DiT 扩展到视频生成的早期工作
7. DiT vs U-Net 对比
| 维度 | U-Net | DiT (Transformer) |
|---|---|---|
| 架构来源 | 医学图像分割 (2015) | 视觉识别 ViT (2020) |
| 核心结构 | Encoder-Decoder + Skip Connection | 堆叠 Transformer Blocks |
| Scaling 行为 | 不够清晰,难以预测 | 平滑且可预测,类似 LLM |
| 条件注入 | Cross-Attention + Time Embed | adaLN-Zero(更高效) |
| 多分辨率特征 | 天然支持(U 形结构) | 需要额外设计 |
| 计算效率 | 中等 | 高(可利用 FlashAttention 等) |
| 工程生态 | 成熟但特定 | 可复用 LLM 训练基础设施 |
| 参数效率 | 较低 | 较高(参数利用率更高) |
| 视频扩展 | 需要 3D 卷积改造 | 自然扩展到时空 tokens |
| 多模态扩展 | 困难 | 自然(token 拼接即可) |
| 归纳偏置 | 强(局部性、多尺度) | 弱(数据驱动) |
| 小数据表现 | 较好 | 需要更多数据 |
趋势判断
从 2023 年开始,新提出的 SOTA 生成模型几乎全部转向 Transformer 架构。U-Net 在 Diffusion 领域的主导地位已经让位于 DiT 及其变体。这一趋势与 ViT 取代 CNN 成为视觉骨干的趋势一脉相承。
8. 思考与讨论
8.1 为什么 Transformer 能替代 U-Net?
表面上看,U-Net 的多尺度特征和 skip connection 似乎是图像生成的必要结构。但 DiT 的成功揭示了几个深层原因:
- Attention 可以学到多尺度特征:虽然 Transformer 没有显式的多尺度结构,但 self-attention 可以自适应地关注不同距离的 token,隐式地实现多尺度建模
- 足够的模型容量可以弥补归纳偏置的缺失:与 ViT vs CNN 的故事一致,当模型足够大、数据足够多时,强归纳偏置不再是必要条件
- Latent space 降低了分辨率需求:在 latent space(如 32x32)而非 pixel space(如 256x256)中工作,大幅降低了 token 数量,使 Transformer 的 \(O(n^2)\) 复杂度可以接受
8.2 DiT 对视频生成的意义
DiT 对视频生成的推动意义尤为深远:
- 自然的时空扩展:将 2D patch 扩展为 3D spacetime patch,Transformer 天然支持处理变长序列
- 统一的训练框架:同一架构可以同时处理图像和视频,只需调整 patch 的切割方式
- Scaling 的可行性:视频生成需要更大的模型,而 DiT 的 scaling law 保证了加大模型是有回报的
Sora 的出现正是这一路线的最佳验证。
8.3 DiT 是否是"统一架构"的又一例证?
从更宏观的视角看,DiT 是 Transformer 作为"统一架构"征服又一个领域的案例:
Transformer 的领域扩张路线
================================================================
2017 NLP (原始 Transformer)
|
2018 NLP 预训练 (BERT, GPT)
|
2020 视觉识别 (ViT)
|
2021 多模态理解 (CLIP)
|
2023 图像生成 (DiT) <--- 我们在这里
|
2024 视频生成 (Sora)
|
2024+ 音频、3D、机器人控制...
================================================================
这一趋势背后的深层逻辑是:Transformer 的通用性和 scaling 能力使其成为"计算的通用基底"。不同领域的任务差异通过 tokenization 方式(文本 token、image patch、video patch、audio patch)来处理,而核心计算引擎保持一致。
值得警惕的问题
统一架构并不意味着最优架构。Transformer 的 \(O(n^2)\) attention 复杂度在超长序列(如高分辨率视频)上仍然是瓶颈。线性 attention、state space models (Mamba) 等替代方案仍在发展中。DiT 的成功不应被过度解读为"Transformer 永远是最优选择",而应理解为"在当前的计算范式和数据规模下,Transformer 是最具 scaling 潜力的架构"。
9. 总结
DiT 的核心贡献可以归纳为三点:
- 验证了 Transformer 可以替代 U-Net 作为 Diffusion 模型的去噪骨干,且性能相当或更优
- 提出了 adaLN-Zero 这一高效的条件注入方式,通过零初始化实现稳定训练
- 系统地研究了 scaling 行为,证明 Diffusion Transformer 具有与 LLM 类似的可预测 scaling law
DiT 的影响远超论文本身。它开启了生成式 AI 从"特定架构"向"统一架构"的转变,为 Sora、SD3 等后续工作奠定了基础,也为"Transformer 是通用计算架构"这一观点提供了又一个有力证据。