Skip to content

Initialization

Initialization

方差传播的数学推导

本节基于 Glorot & Bengio (2010) “Understanding the difficulty of training deep feedforward neural networks” 和 He et al. (2015) “Delving Deep into Rectifiers” 的理论推导。

初始化的核心问题是:如何设置权重的初始分布,使得信号在前向传播和梯度在反向传播时既不爆炸也不消失?这需要对方差传播进行严格的数学分析。

DNN 模型设置

考虑一个 \(L\) 层的全连接网络,第 \(l\) 层有 \(m_l\) 个神经元。第 \(l\) 层的前向传播为:

\[ \mathbf{y}_l = \mathbf{W}_l \mathbf{x}_{l-1} + \mathbf{b}_l, \quad \mathbf{x}_l = f(\mathbf{y}_l) \]

其中 \(\mathbf{W}_l \in \mathbb{R}^{m_l \times m_{l-1}}\) 是权重矩阵,\(\mathbf{b}_l\) 是偏置向量,\(f(\cdot)\) 是激活函数,\(\mathbf{y}_l\) 是激活前的值(pre-activation),\(\mathbf{x}_l\) 是激活后的值。

关键假设:

  1. 权重 \(w_l^{(ij)}\) 独立同分布,均值为零:\(E[w] = 0\)
  2. 输入 \(x_{l-1}^{(j)}\) 独立同分布,且与权重独立
  3. 偏置初始化为零:\(b_l = 0\)

前向传播方差分析

对于第 \(l\) 层的第 \(i\) 个神经元:

\[ y_l^{(i)} = \sum_{j=1}^{m_{l-1}} w_l^{(ij)} x_{l-1}^{(j)} \]

由于 \(w\)\(x\) 相互独立且 \(E[w] = 0\),利用 \(\text{Var}(wx) = \text{Var}(w) \cdot E[x^2]\)(当 \(E[w]=0\) 时),求和后得到:

\[ \boxed{\text{Var}[y_l] = m_{l-1} \cdot \text{Var}[w_l] \cdot E[x_{l-1}^2]} \]

这个公式是所有初始化方法的基础。

对称激活函数(Sigmoid/Tanh)的情况:

当激活函数关于原点对称时(如 Tanh),输出均值为零,即 \(E[x] = 0\),因此 \(E[x^2] = \text{Var}(x)\)。此时:

\[ \text{Var}[y_l] = m_{l-1} \cdot \text{Var}[w_l] \cdot \text{Var}[x_{l-1}] = m_{l-1} \cdot \text{Var}[w_l] \cdot \text{Var}[y_{l-1}] \]

(最后一步在线性区域近似 \(x \approx y\) 时成立)

要让方差在层间保持不变(\(\text{Var}[y_l] = \text{Var}[y_{l-1}]\)),需要:

\[ \boxed{m_{l-1} \cdot \text{Var}[w_l] = 1} \quad \Longrightarrow \quad \text{Var}[w_l] = \frac{1}{m_{l-1}} \]

这就是 Xavier 条件

ReLU 激活函数的情况:

ReLU 将负半轴截断为零,因此 \(E[x] \neq 0\)。假设 \(y\) 关于零对称分布,ReLU 只保留正半部分,所以:

\[ E[x^2] = E[\text{ReLU}(y)^2] = \frac{1}{2} E[y^2] = \frac{1}{2} \text{Var}(y) \]

代入方差传播公式:

\[ \text{Var}[y_l] = m_{l-1} \cdot \text{Var}[w_l] \cdot \frac{1}{2}\text{Var}[y_{l-1}] \]

要让方差保持不变,需要:

\[ \boxed{\frac{1}{2} \cdot m_{l-1} \cdot \text{Var}[w_l] = 1} \quad \Longrightarrow \quad \text{Var}[w_l] = \frac{2}{m_{l-1}} \]

这就是 Kaiming/He 条件——额外的因子 2 正是为了补偿 ReLU 丢弃负半轴带来的方差减半。

反向传播梯度方差分析

类似地,可以分析梯度在反向传播时的方差。设损失对第 \(l\) 层激活前值的梯度为 \(\delta_l = \frac{\partial L}{\partial y_l}\),则:

\[ \delta_{l-1} = f'(y_{l-1}) \odot (\mathbf{W}_l^T \delta_l) \]

对梯度方差做类似分析:

\[ \text{Var}[\delta_{l-1}] = m_l \cdot \text{Var}[w_l] \cdot \text{Var}[\delta_l] \]

要让梯度方差保持不变,需要 \(m_l \cdot \text{Var}[w_l] = 1\),即 \(\text{Var}[w_l] = \frac{1}{m_l}\)

注意前向条件要求 \(\text{Var}[w] = \frac{1}{m_{l-1}}\)(fan-in),反向条件要求 \(\text{Var}[w] = \frac{1}{m_l}\)(fan-out)。这两个条件一般不能同时满足(除非 \(m_{l-1} = m_l\)),因此 Xavier 取折中:

\[ \text{Var}[w_l] = \frac{2}{m_{l-1} + m_l} = \frac{2}{n_{\text{in}} + n_{\text{out}}} \]

均匀分布的参数推导

对于均匀分布 \(U(-a, a)\),其方差为 \(\text{Var} = \frac{a^2}{3}\)

Xavier 均匀分布:令 \(\frac{a^2}{3} = \frac{2}{n_{\text{in}} + n_{\text{out}}}\),解得:

\[ a = \sqrt{\frac{6}{n_{\text{in}} + n_{\text{out}}}} \]

Kaiming 均匀分布:令 \(\frac{a^2}{3} = \frac{2}{n_{\text{in}}}\),解得:

\[ a = \sqrt{\frac{6}{n_{\text{in}}}} \]

Initialization Strategies

初始化策略决定了你需要初始化的随机数的范围标准差。如果用一个装随机数的盒子来类比的话,初始化策略决定了盒子的大小和边界。

Xavier/Glorot Initialization

Glorot & Bengio, “Understanding the difficulty of training deep feedforward neural networks”, AISTATS 2010.

Xavier Initialization的核心思想是保持网络中输入和输出信号的方差相同,以确保信号在整个网络中不会发散或衰减。

最初主要用于Sigmoid或Tanh这种对称激活函数,因为它们在大值和小值区域接近线性,且输出均值接近于零。Xavier 条件来自上面推导的 \(m \cdot \text{Var}[w] = 1\)(前向)和折中前向/反向条件。

正态分布形式:

\[ W \sim \mathcal{N}\left(0, \; \frac{2}{n_{\text{in}} + n_{\text{out}}}\right) \]

均匀分布形式:

权重 \(W\) 从一个均匀分布 \(U(-a, a)\) 中采样,其中:

\[ a = \sqrt{\frac{6}{n_{in} + n_{out}}} \]
  • \(n_{in}\):当前层的输入神经元数量(扇入数,fan-in)。
  • \(n_{out}\):当前层的输出神经元数量(扇出数,fan-out)。

举例来说,假设在代码里定义了一个 Linear(in_features=100, out_features=100) 的层:

计算 \(a\)**** :

\[ a = \sqrt{\frac{6}{100 + 100}} = \sqrt{0.03} \approx 0.173 \]

执行分布 :该层所有的权重 \(W\) 都会从 \(\text{Uniform}(-0.173, 0.173)\) 中随机采样。

  • 比如 \(W_1 = 0.05\)
  • \(W_2 = -0.12\)
  • \(W_3 = 0.17\)
  • ...但绝对不会出现 \(0.2\),因为它超过了 \(a\)

Kaiming/He Initialization

He et al., “Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification”, ICCV 2015.

由Kaiming He等人提出,专门为 ReLU(Rectified Linear Unit) 激活函数设计。由于 ReLU 在负半轴的输出为零(即”杀死”了一半的神经元),它会导致信号的方差在每层传播后减半。Kaiming 初始化通过设置 \(\text{Var}[w] = \frac{2}{n_{\text{in}}}\) 来补偿 ReLU 带来的这种损失(对比 Xavier 的 \(\frac{1}{n_{\text{in}}}\),多出的因子 2 正好抵消 ReLU 的半方差效应)。

适用场景主要是ReLU及其变体(如Leaky ReLU)。

正态分布形式:

\[ W \sim \mathcal{N}\left(0, \; \frac{2}{n_{\text{in}}}\right) \]

均匀分布形式:

权重 \(W\) 从一个均匀分布 \(U(-a, a)\) 中采样,其中:

\[ a = \sqrt{\frac{6}{n_{in}}} \]
  • 注意: 公式中只使用了输入神经元数量 \(n_{in}\)(fan-in mode)。Kaiming 初始化也可以使用 fan-out mode(\(\text{Var}[w] = \frac{2}{n_{\text{out}}}\)),这在某些卷积层中更合适。

Initialization Distributions

初始化权重策略决定了初始化随机数产生的形状。如果用装随机数的盒子来类比的话,分布决定了装入盒子的随机数的形状。

Uniform Distribution

均匀分布,权重值 \(W\) 在一个特定的区间 \([-\mathbf{A}, +\mathbf{A}]\)等概率地随机选取。权重值均匀地分布在设定的范围内。Xavier 和 Kaiming 的公式会计算出这个区间半宽 \(A\) 的值。

Normal Distribution

正态分布,权重值 \(W\) 从一个均值为 \(\mu\)、标准差为 \(\sigma\)* 的高斯(正态)分布中随机选取。靠近均值 (*\(\mu=0\)) 的权重值被选中的概率最高,离均值越远,被选中的概率越低。

在初始化时,均值通常设为 \(\mu=0\)。Xavier 和 Kaiming 的公式会用于计算标准差 \(\sigma\) 的值(例如 $ \sigma^2 = \frac{2}{n_{in} + n_{out}}$ 或 $ \sigma^2 = \frac{2}{n_{in}}$)。




为什么不能随便初始化权重

对称性破缺(Symmetry Breaking)

初始化的一个基本要求是打破对称性。如果同一层中所有神经元的权重完全相同,那么它们会计算出完全相同的输出,接收到完全相同的梯度,并进行完全相同的更新——这意味着无论训练多久,这些神经元都是彼此的"克隆体",整层退化为只有一个有效神经元。

随机初始化的本质目的就是让每个神经元从不同的起点出发,学习到不同的特征。即使是很小的随机扰动就足以破坏对称性,但权重的尺度必须精心选择(即 Xavier/Kaiming 条件),否则会出现梯度爆炸或消失。

随便初始化的后果

我们必须避免几种会导致显著恶果的初始化:

  • 全0初始化:无法打破对称性,网络退化为一个极简单的线性模型
  • 相同常数初始化:出现和全0初始化一样的问题——所有神经元永远相同
  • 全正初始化:均值偏移、饱和、锯齿现象
  • 较大的初始化:会导致梯度爆炸或神经元死亡(Sigmoid/Tanh 进入饱和区)
  • 较小的初始化:会导致梯度消失(信号逐层衰减至零)

其他初始化方法

正交初始化 (Orthogonal Initialization)

正交初始化将权重矩阵初始化为正交矩阵(或其近似),使得矩阵的所有奇异值都为 1。这保证了信号在前向传播和反向传播时既不放大也不缩小。

方法: 先从标准正态分布中采样一个矩阵,然后对其做 QR 分解或 SVD 分解,取正交部分作为权重矩阵。

\[ W = Q, \quad \text{其中 } A = QR, \; A_{ij} \sim \mathcal{N}(0, 1) \]

适用场景: 正交初始化特别适合 RNN/LSTM 中的循环权重矩阵。在 RNN 中,隐状态需要反复乘以同一个权重矩阵,如果这个矩阵的谱范数偏离 1,就会导致梯度在时间步上指数级地爆炸或消失。正交矩阵的谱范数恰好为 1,因此能有效缓解这个问题。

# PyTorch 中的正交初始化
import torch.nn as nn

nn.init.orthogonal_(layer.weight)

Scaling Initialization / LSUV (Layer-Sequential Unit-Variance)

Mishkin & Matas, "All you need is a good init", ICLR 2016.

Scaling Initialization(也称 LSUV)是一种数据驱动的初始化方法。其核心思想是:理论推导(如 Xavier/Kaiming)依赖于对激活函数和数据分布的假设,而这些假设在实际深度网络中可能不完全成立。LSUV 直接通过实际数据来校准每一层的输出方差:

  1. 用正交初始化对所有层做初始化(正交矩阵作为起点)
  2. 将一个 mini-batch 的数据输入网络
  3. 从第一层开始,逐层检查每层输出的方差
  4. 如果方差不为 1,则缩放该层的权重使输出方差为 1:\(W_l \leftarrow W_l / \sqrt{\text{Var}[\mathbf{x}_l]}\)
  5. 重复直到所有层的输出方差都接近 1

优点: 不依赖对激活函数的假设,适用于任意网络结构和激活函数(包括非标准激活如 Mish 等)。实验表明 LSUV 在 GoogLeNet、VGG 等网络上能达到与 BN 相当的效果,而不需要在网络中添加额外的 BN 层。

预训练初始化 (Pre-trained Initialization)

在迁移学习和微调场景中,使用预训练模型的权重作为初始化是最常见也最有效的做法:

  • ImageNet 预训练:对于视觉任务,通常使用 ImageNet 预训练的 ResNet/ViT 权重作为初始化
  • 大语言模型:LLM 的微调本质上就是以预训练权重为初始化,在特定任务数据上继续训练
  • 新增层的初始化:在预训练模型上添加新的分类头或适配层时,新增层通常用 Xavier 或 Kaiming 初始化,而预训练层保持原始权重

预训练初始化的效果通常远好于随机初始化,因为预训练权重已经编码了丰富的特征表示。


PyTorch 中的初始化实践

常用初始化函数

import torch.nn as nn

# Xavier 初始化
nn.init.xavier_uniform_(layer.weight)   # 均匀分布
nn.init.xavier_normal_(layer.weight)    # 正态分布

# Kaiming 初始化
nn.init.kaiming_uniform_(layer.weight, nonlinearity='relu')
nn.init.kaiming_normal_(layer.weight, nonlinearity='relu')

# 正交初始化
nn.init.orthogonal_(layer.weight)

# 常数初始化(通常用于 bias)
nn.init.zeros_(layer.bias)
nn.init.constant_(layer.bias, 0.01)

对整个模型进行初始化

def init_weights(module):
    if isinstance(module, nn.Linear):
        nn.init.kaiming_normal_(module.weight, nonlinearity='relu')
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Conv2d):
        nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
    elif isinstance(module, (nn.BatchNorm2d, nn.LayerNorm)):
        nn.init.ones_(module.weight)
        nn.init.zeros_(module.bias)

model.apply(init_weights)

初始化方法选择建议

激活函数 推荐初始化 原因
Sigmoid / Tanh Xavier 保持输入输出方差一致
ReLU / Leaky ReLU Kaiming 补偿 ReLU 的方差减半效应
GELU / SiLU Kaiming 与 ReLU 类似的非对称特性
RNN 循环权重 正交初始化 保持时间步上的梯度稳定
Transformer Xavier(原始)/ 缩放初始化 深层 Transformer 常对残差分支缩放 \(1/\sqrt{2N}\)
微调场景 预训练权重 已包含丰富特征表示

深层网络的缩放初始化: 在非常深的网络(如 GPT 等深层 Transformer)中,残差连接会导致信号在层数增多时逐渐增大。一种常见做法是对残差分支的最后一层权重乘以 \(1/\sqrt{2N}\)\(N\) 为层数),使得信号在残差路径上的增长受到控制。GPT-2 的论文中就采用了这种做法。


评论 #