Skip to content

DiT (Diffusion Transformer)

论文: Scalable Diffusion Models with Transformers (Peebles & Xie, 2023)

DiT 提出用 Transformer 替代 U-Net 作为 Diffusion 模型的去噪骨干网络,证明了 Transformer 在图像生成任务上同样具有优秀的 scaling 特性。这一工作深刻影响了后续的 Sora、SD3、FLUX 等模型的架构选择。


1. 背景与动机

1.1 Diffusion 模型的成功

自 DDPM (Ho et al., 2020) 以来,扩散模型在图像生成领域取得了巨大成功。其核心思想是:

  1. 前向过程:逐步向图像添加高斯噪声,直到变成纯噪声
  2. 反向过程:训练一个神经网络学习逐步去噪,从纯噪声还原出图像

数学上,前向过程定义为:

\[q(x_t | x_{t-1}) = \mathcal{N}(x_t; \sqrt{1 - \beta_t} x_{t-1}, \beta_t \mathbf{I})\]

而反向过程则由神经网络 \(\epsilon_\theta(x_t, t)\) 参数化,预测加入的噪声。

Stable Diffusion 进一步将扩散过程放到 潜空间 (latent space) 中进行,即先用 VAE 将图像编码为低维 latent,再在 latent 上做扩散,大幅降低计算成本。

1.2 U-Net 的局限性

在 DiT 之前,几乎所有主流 Diffusion 模型(DDPM、Stable Diffusion、DALL-E 2、Imagen 等)都使用 U-Net 作为去噪网络。U-Net 存在以下局限:

  • 架构特异性强:U-Net 是专门为像素级预测设计的架构(最早用于医学图像分割),其 encoder-decoder + skip connection 的结构并非通用设计
  • 难以 scale:U-Net 的 scaling 行为不够清晰,不像 Transformer 那样有明确的"加大模型 = 提升性能"的 scaling law
  • 缺乏统一性:NLP 已经统一到 Transformer 架构,而视觉/生成领域仍在使用各种特定架构

1.3 Transformer 的 scaling 优势

ViT (Dosovitskiy et al., 2020) 已经证明:纯 Transformer 架构在视觉任务上可以匹敌甚至超越 CNN,前提是有足够的数据和计算。Transformer 的核心优势在于:

  • 清晰的 scaling behavior:模型参数量、数据量、计算量的增加能可预测地带来性能提升
  • 架构通用性:同一架构可以处理文本、图像、音频、视频等多种模态
  • 成熟的工程生态:大量针对 Transformer 的优化技术(FlashAttention、模型并行等)

1.4 DiT 的核心问题

DiT 论文提出一个直接的问题:

核心问题

能否用 Transformer 替代 U-Net,作为 Diffusion 模型的去噪骨干网络 (backbone)?如果可以,其 scaling 表现如何?

答案是肯定的。DiT 证明了 Transformer-based 的 Diffusion 模型不仅可行,而且具有极好的 scaling 特性 -- 更大的模型 + 更多的计算 = 更好的生成质量 (更低的 FID)。


2. DiT 架构详解

2.1 整体架构

DiT 的整体流程可以用以下架构图概括:

                          DiT 整体架构
================================================================

  输入                                              输出
  ----                                              ----
  Noisy Latent z_t          条件信息                 Predicted
  (32 x 32 x 4)         (timestep t, class y)       Noise / x_0
       |                      |                      ^
       v                      v                      |
  +---------+          +-----------+           +------------+
  | Patchify|          | Embedding |           | Unpatchify |
  | + Pos   |          | (t_emb +  |           | (Linear +  |
  | Embed   |          |  y_emb)   |           |  Reshape)  |
  +---------+          +-----------+           +------------+
       |                      |                      ^
       v                      v                      |
       +----------+-----------+                      |
                  |                                   |
                  v                                   |
       +--------------------+                         |
       |   DiT Block #1     |                         |
       +--------------------+                         |
                  |                                   |
                  v                                   |
       +--------------------+                         |
       |   DiT Block #2     |                         |
                 ...                                  |
       +--------------------+                         |
                  |                                   |
                  v                                   |
       +--------------------+                         |
       |   DiT Block #N     |-------------------------+
       +--------------------+

================================================================

2.2 Patchify:将 Latent 切成 Patches

与 ViT 处理图像的方式一致,DiT 首先将输入的 noisy latent \(z_t\) 切成互不重叠的 patches:

  • 输入 latent 尺寸:\(H \times W \times C\)(在 ImageNet 256x256 实验中为 \(32 \times 32 \times 4\)
  • Patch 大小 \(p\):论文实验了 \(p = 2, 4, 8\)
  • Patch 数量:\(T = \frac{H \times W}{p^2}\)
  • 每个 patch 通过线性投影映射为 \(d\) 维向量

例如,当 \(p = 2\) 时,\(32 \times 32\) 的 latent 被切成 \(16 \times 16 = 256\) 个 patches。

Patch 大小的影响

更小的 patch 意味着更多的 token,计算量更大但信息保留更完整。论文发现 \(p = 2\) 效果最好,因为它保留了最多的空间信息。模型命名规则为 DiT-{Size}/{Patch},如 DiT-XL/2 表示 XL 尺寸模型 + patch 大小为 2。

Patchify 之后,加上标准的位置编码 (positional embedding),使模型能够感知空间位置信息。

2.3 条件嵌入

DiT 需要注入两种条件信息:

时间步嵌入 (Timestep Embedding)

  • 使用与 DDPM 相同的 sinusoidal 位置编码将标量 \(t\) 映射为向量
  • 再通过 MLP 投影到隐藏维度:\(t_{\text{emb}} = \text{MLP}(\text{sinusoidal}(t))\)

类别嵌入 (Class Label Embedding)

  • 使用可学习的 embedding table 将类别标签 \(y\) 映射为向量
  • \(y_{\text{emb}} = \text{Embedding}(y)\)

最终条件向量:\(c = t_{\text{emb}} + y_{\text{emb}}\)

2.4 DiT Block:Transformer Block + 条件注入

DiT Block 是标准 Transformer Block 加上条件信息注入机制。论文系统地对比了四种条件注入方式,这是论文的核心实验之一。

方式一:In-Context Conditioning

\(t_{\text{emb}}\)\(y_{\text{emb}}\) 视为两个额外的 token,与 patch tokens 拼接后一起输入标准 Transformer block:

\[\text{Input} = [\underbrace{t_{\text{emb}}, y_{\text{emb}}}_{\text{条件 tokens}}, \underbrace{z_1, z_2, ..., z_T}_{\text{patch tokens}}]\]
  • 优点:实现简单,无需修改 Transformer 架构
  • 缺点:条件信息与 patch tokens 的交互依赖 self-attention 自行学习,效率较低

方式二:Cross-Attention

在每个 Transformer block 中加入一个 cross-attention 层,patch tokens 作为 Query,条件信息作为 Key/Value:

\[\text{CrossAttn}(Q = z_{\text{patches}}, K = V = c)\]
  • 优点:这是 Stable Diffusion U-Net 中注入文本条件的方式,已被验证有效
  • 缺点:增加额外参数和计算量

方式三:Adaptive LayerNorm (adaLN)

用条件向量 \(c\) 回归 LayerNorm 的 scale (\(\gamma\)) 和 shift (\(\beta\)) 参数:

\[\gamma, \beta = \text{MLP}(c)$$ $$\text{adaLN}(x, c) = \gamma(c) \cdot \frac{x - \mu}{\sigma} + \beta(c)\]
  • 优点:参数高效,不增加额外 attention 层
  • 缺点:条件信息只通过归一化层注入,表达力可能受限

方式四:adaLN-Zero(最优方案)

在 adaLN 基础上,额外引入一个 scale 参数 \(\alpha\),并将所有 \(\gamma, \beta, \alpha\) 初始化为使每个 DiT block 在初始时等价于恒等函数。这是 DiT 论文最终采用的方案。

核心结论

论文实验结果表明,四种条件注入方式的性能排序为:adaLN-Zero > adaLN > Cross-Attention > In-Context。adaLN-Zero 在所有模型尺寸上都取得了最优 FID。

2.5 Unpatchify:从 Tokens 还原 Latent

DiT 的最后一步是将 Transformer 输出的 token 序列还原为与输入相同尺寸的 latent:

  1. 最终 LayerNorm(由 adaLN-Zero 调制)
  2. 线性投影:将每个 token 从 \(d\) 维映射到 \(p \times p \times 2C\) 维(预测噪声和对角协方差)
  3. Reshape:将 token 序列重新排列为空间结构 \(H \times W \times 2C\)

3. adaLN-Zero 详解

adaLN-Zero 是 DiT 的关键创新,值得深入理解。

3.1 从标准 LayerNorm 到 adaLN-Zero

标准 LayerNorm

\[y = \gamma \cdot \frac{x - \mu}{\sigma} + \beta\]

其中 \(\gamma\)\(\beta\) 是可学习参数,与输入无关。

Adaptive LayerNorm (adaLN)

\[\gamma, \beta = \text{MLP}(c)$$ $$y = \gamma(c) \cdot \frac{x - \mu}{\sigma} + \beta(c)\]

\(\gamma\)\(\beta\) 由条件向量 \(c\) 动态生成,实现了条件依赖的归一化。

adaLN-Zero

在 adaLN 基础上,每个残差连接前额外引入一个 scale 参数 \(\alpha\)

\[y = x + \alpha(c) \cdot \text{Block}(\text{adaLN}(x, c))\]

其中 \(\alpha\) 同样由条件 \(c\) 回归得到。

3.2 零初始化的关键

adaLN-Zero 的"Zero"体现在初始化策略上:

  • MLP 回归 \(\alpha\) 的最后一层权重初始化为 全零
  • 这意味着在训练初始时,\(\alpha = 0\)
  • 因此每个 DiT block 的输出为:\(y = x + 0 \cdot \text{Block}(\cdot) = x\)
adaLN-Zero 初始化时的等效行为:

输入 x ──────────────────────────────── 输出 x
       |                           ^
       v                           | (alpha = 0, 被屏蔽)
  +----------+    +----------+     |
  | adaLN    |--->| Attention |--x--+
  |          |    | or FFN    |  ^
  +----------+    +----------+  |
       ^                    alpha = 0
       |
  条件 c ──> MLP ──> (gamma, beta, alpha)
                      全部初始化为合理值
                      alpha 初始化为 0

为什么零初始化有效?

这一策略借鉴了 ResNet 中的零初始化残差思想。核心好处是:

  1. 训练稳定性:模型初始时等价于恒等映射,梯度可以无损地流过整个网络,避免深层网络的训练不稳定问题
  2. 渐进式学习:每个 block 从"什么都不做"开始,逐渐学习有意义的变换,训练过程更加平滑
  3. 可以堆叠更深:由于初始时不会破坏信号,即使堆叠很多 block 也不会导致训练崩溃

3.3 adaLN-Zero 的完整计算流程

一个 DiT Block 内部的完整计算如下:

DiT Block (adaLN-Zero) 内部结构
================================================================

条件向量 c ──> MLP ──> (gamma_1, beta_1, alpha_1,
                        gamma_2, beta_2, alpha_2)

输入 x
  |
  v
  adaLN(x, gamma_1, beta_1)
  |
  v
  Multi-Head Self-Attention
  |
  v
  x alpha_1 (逐元素乘以 scale)
  |
  v
  + x (残差连接) ──> x'
  |
  v
  adaLN(x', gamma_2, beta_2)
  |
  v
  Pointwise FeedForward (MLP)
  |
  v
  x alpha_2 (逐元素乘以 scale)
  |
  v
  + x' (残差连接) ──> 输出

================================================================

每个 block 从条件 \(c\) 回归出 6 个向量:attention 分支的 \((\gamma_1, \beta_1, \alpha_1)\) 和 FFN 分支的 \((\gamma_2, \beta_2, \alpha_2)\)


4. Latent DiT 工作流

DiT 与 Latent Diffusion Model (LDM) 的结合方式非常直接 -- 将 U-Net 替换为 DiT,其余保持不变。

4.1 训练流程

训练阶段
================================================================

真实图像 x_0         时间步 t        类别标签 y
    |                   |                |
    v                   |                |
+----------+            |                |
| VAE      |            v                v
| Encoder  |       sinusoidal +     Embedding
+----------+        MLP = t_emb     Table = y_emb
    |                   |                |
    v                   +-------+--------+
  Latent z_0                    |
    |                     c = t_emb + y_emb
    v                           |
  加噪声 (q(z_t|z_0))          |
  z_t = sqrt(a_t)*z_0          |
     + sqrt(1-a_t)*eps          |
    |                           |
    v                           v
  +-------------------------------+
  |          DiT(z_t, c)          |
  |   预测噪声 eps_theta 或 x_0  |
  +-------------------------------+
              |
              v
   Loss = ||eps - eps_theta||^2

================================================================

4.2 推理流程

推理阶段
================================================================

随机噪声 z_T ~ N(0, I)     条件信息 (t, y)
    |                           |
    v                           v
  +-------------------------------+
  |          DiT(z_t, c)          |    x N 步迭代
  |   预测噪声 eps_theta          |<---+
  +-------------------------------+    |
              |                        |
              v                        |
    去噪一步: z_{t-1} ----------------+
              |
              v (当 t=0)
         干净 latent z_0
              |
              v
         +----------+
         | VAE      |
         | Decoder  |
         +----------+
              |
              v
          生成图像 x_0

================================================================

4.3 与 Stable Diffusion 的对比

架构对比

DiT 与 Stable Diffusion 的区别仅在去噪网络部分:

组件 Stable Diffusion Latent DiT
图像编码器 VAE Encoder VAE Encoder(相同)
潜空间 32x32x4 latent 32x32x4 latent(相同)
去噪网络 U-Net DiT (Transformer)
噪声调度 DDPM / DDIM DDPM / DDIM(相同)
图像解码器 VAE Decoder VAE Decoder(相同)

可以说,DiT 是 Stable Diffusion 的"drop-in replacement" -- 换了一个去噪网络,其他一切不变。


5. Scaling 实验结果

DiT 论文的一大贡献是系统地研究了 Transformer-based Diffusion 模型的 scaling 行为。

5.1 模型配置

论文设计了四个不同大小的模型,命名规则为 DiT-{Size}/{Patch}:

模型 隐藏维度 \(d\) 层数 \(N\) Attention Heads 参数量 (p=2) Gflops (p=2)
DiT-S 384 12 6 33M 6.06
DiT-B 768 12 12 130M 23.0
DiT-L 1024 24 16 458M 80.7
DiT-XL 1152 28 16 675M 119

5.2 ImageNet 256x256 生成结果

在 class-conditional ImageNet 256x256 生成任务上(使用 classifier-free guidance):

模型 FID-50K IS
DiT-S/2 68.4 27.1
DiT-B/2 43.5 44.1
DiT-L/2 23.3 76.3
DiT-XL/2 9.62 121.5
DiT-XL/2 (长时间训练) 2.27 278.2

关键发现

  1. 模型越大,FID 越低:从 DiT-S 到 DiT-XL,FID 单调下降,且下降幅度显著
  2. Patch 越小,效果越好\(p = 2\) 优于 \(p = 4\) 优于 \(p = 8\),因为更小的 patch 保留更多空间信息
  3. Scaling 曲线平滑:FID 随计算量 (Gflops) 的增加呈现非常平滑的下降趋势
  4. adaLN-Zero 在所有尺寸上最优:条件注入方式的优劣排序在不同模型大小上保持一致

5.3 Scaling 曲线

FID-50K vs. 计算量 (Gflops)  [示意]
================================================================

  FID
   ^
70 |  * DiT-S/2
   |
60 |
   |
50 |
   |    * DiT-B/2
40 |
   |
30 |
   |        * DiT-L/2
20 |
   |
10 |              * DiT-XL/2
   |
 2 |                    * DiT-XL/2 (长训练)
   +-----+-----+------+------+------> Gflops
        6     23     81    119

================================================================

这条曲线是 DiT 论文最核心的结果之一:它表明 Diffusion Transformer 具有与 LLM 类似的 scaling law -- 可预测地通过增加计算获得性能提升

5.4 与其他模型的对比

DiT-XL/2 在充分训练后,FID 达到 2.27,与当时最好的 Diffusion 模型(如 ADM、LDM)相当或更优:

模型 FID-50K 骨干网络
ADM (Dhariwal & Nichol) 10.94 U-Net
ADM + Classifier Guidance 4.59 U-Net
LDM-4 (Rombach et al.) 10.56 U-Net
DiT-XL/2 2.27 Transformer

6. DiT 的后续影响

DiT 的出现标志着 Diffusion 模型从"U-Net 时代"向"Transformer 时代"的转变。其后续影响深远。

6.1 Sora (OpenAI, 2024)

OpenAI 的 Sora 视频生成模型据传基于 DiT 架构的扩展,将 DiT 从 2D 图像扩展到 3D 时空 (spacetime) patches。Sora 的技术报告中明确引用了 DiT,并描述了将视频视为"spacetime patches"序列的思路。

6.2 SD3 / FLUX: MM-DiT

Stable Diffusion 3 和 FLUX 系列采用了 MM-DiT (Multi-Modal DiT) 架构:

  • 双流 (dual-stream) 设计:图像 tokens 和文本 tokens 各有独立的 Transformer 流
  • 在 attention 层进行交互:两个流的 tokens 拼接后做 joint attention
  • 比原始 DiT 的 cross-attention 更加灵活

6.3 PixArt-alpha

PixArt-alpha 关注 DiT 的高效训练:

  • 使用预训练的 T5 文本编码器提供文本条件
  • 分阶段训练策略,降低训练成本
  • 证明 DiT 架构可以在较少计算资源下也能训练出高质量模型

6.4 SiT (Scalable Interpolant Transformers)

SiT 保留 DiT 的 Transformer 架构,但将底层的 DDPM 框架替换为 stochastic interpolant / flow matching 框架,获得了进一步的性能提升。

6.5 其他

  • Hunyuan-DiT (腾讯):中文理解能力强的文生图 DiT
  • Open-Sora:开源的基于 DiT 的视频生成方案
  • Latte:将 DiT 扩展到视频生成的早期工作

7. DiT vs U-Net 对比

维度 U-Net DiT (Transformer)
架构来源 医学图像分割 (2015) 视觉识别 ViT (2020)
核心结构 Encoder-Decoder + Skip Connection 堆叠 Transformer Blocks
Scaling 行为 不够清晰,难以预测 平滑且可预测,类似 LLM
条件注入 Cross-Attention + Time Embed adaLN-Zero(更高效)
多分辨率特征 天然支持(U 形结构) 需要额外设计
计算效率 中等 高(可利用 FlashAttention 等)
工程生态 成熟但特定 可复用 LLM 训练基础设施
参数效率 较低 较高(参数利用率更高)
视频扩展 需要 3D 卷积改造 自然扩展到时空 tokens
多模态扩展 困难 自然(token 拼接即可)
归纳偏置 强(局部性、多尺度) 弱(数据驱动)
小数据表现 较好 需要更多数据

趋势判断

从 2023 年开始,新提出的 SOTA 生成模型几乎全部转向 Transformer 架构。U-Net 在 Diffusion 领域的主导地位已经让位于 DiT 及其变体。这一趋势与 ViT 取代 CNN 成为视觉骨干的趋势一脉相承。


8. 思考与讨论

8.1 为什么 Transformer 能替代 U-Net?

表面上看,U-Net 的多尺度特征和 skip connection 似乎是图像生成的必要结构。但 DiT 的成功揭示了几个深层原因:

  1. Attention 可以学到多尺度特征:虽然 Transformer 没有显式的多尺度结构,但 self-attention 可以自适应地关注不同距离的 token,隐式地实现多尺度建模
  2. 足够的模型容量可以弥补归纳偏置的缺失:与 ViT vs CNN 的故事一致,当模型足够大、数据足够多时,强归纳偏置不再是必要条件
  3. Latent space 降低了分辨率需求:在 latent space(如 32x32)而非 pixel space(如 256x256)中工作,大幅降低了 token 数量,使 Transformer 的 \(O(n^2)\) 复杂度可以接受

8.2 DiT 对视频生成的意义

DiT 对视频生成的推动意义尤为深远:

  • 自然的时空扩展:将 2D patch 扩展为 3D spacetime patch,Transformer 天然支持处理变长序列
  • 统一的训练框架:同一架构可以同时处理图像和视频,只需调整 patch 的切割方式
  • Scaling 的可行性:视频生成需要更大的模型,而 DiT 的 scaling law 保证了加大模型是有回报的

Sora 的出现正是这一路线的最佳验证。

8.3 DiT 是否是"统一架构"的又一例证?

从更宏观的视角看,DiT 是 Transformer 作为"统一架构"征服又一个领域的案例:

Transformer 的领域扩张路线
================================================================

  2017  NLP (原始 Transformer)
    |
  2018  NLP 预训练 (BERT, GPT)
    |
  2020  视觉识别 (ViT)
    |
  2021  多模态理解 (CLIP)
    |
  2023  图像生成 (DiT)        <--- 我们在这里
    |
  2024  视频生成 (Sora)
    |
  2024+ 音频、3D、机器人控制...

================================================================

这一趋势背后的深层逻辑是:Transformer 的通用性和 scaling 能力使其成为"计算的通用基底"。不同领域的任务差异通过 tokenization 方式(文本 token、image patch、video patch、audio patch)来处理,而核心计算引擎保持一致。

值得警惕的问题

统一架构并不意味着最优架构。Transformer 的 \(O(n^2)\) attention 复杂度在超长序列(如高分辨率视频)上仍然是瓶颈。线性 attention、state space models (Mamba) 等替代方案仍在发展中。DiT 的成功不应被过度解读为"Transformer 永远是最优选择",而应理解为"在当前的计算范式和数据规模下,Transformer 是最具 scaling 潜力的架构"。


9. 总结

DiT 的核心贡献可以归纳为三点:

  1. 验证了 Transformer 可以替代 U-Net 作为 Diffusion 模型的去噪骨干,且性能相当或更优
  2. 提出了 adaLN-Zero 这一高效的条件注入方式,通过零初始化实现稳定训练
  3. 系统地研究了 scaling 行为,证明 Diffusion Transformer 具有与 LLM 类似的可预测 scaling law

DiT 的影响远超论文本身。它开启了生成式 AI 从"特定架构"向"统一架构"的转变,为 Sora、SD3 等后续工作奠定了基础,也为"Transformer 是通用计算架构"这一观点提供了又一个有力证据。


评论 #