
Knowledge Distillation

Overview

Knowledge Distillation transfers knowledge from a large model (teacher) to a smaller model (student), playing a critical role in model compression and deployment optimization. From Hinton's classic method to LLM-era distillation strategies, this chapter systematically covers the theory and practice of knowledge distillation.


1. Classic Knowledge Distillation

1.1 Hinton Distillation (2015)

Core Idea: The teacher model's soft labels contain "dark knowledge" about inter-class relationships.

Distillation Loss:

\[ \mathcal{L} = \alpha \mathcal{L}_{\text{CE}}(y, \sigma(z_s)) + (1 - \alpha) T^2 D_{KL}\left(\sigma(z_t / T) \| \sigma(z_s / T)\right) \]

where:

  • \(z_s, z_t\): Student and teacher logits
  • \(\sigma\): Softmax function
  • \(T\): Temperature parameter (typically \(T=3\sim20\))
  • \(\alpha\): Weight between hard and soft label losses
  • \(T^2\) factor: Compensates for the gradient scaling effect of temperature

Role of Temperature:

\[ \sigma(z_i / T) = \frac{\exp(z_i / T)}{\sum_j \exp(z_j / T)} \]
  • \(T=1\): Standard softmax
  • \(T \to \infty\): Uniform distribution
  • Higher \(T\) produces smoother distributions, revealing more inter-class relationships
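
As a concrete reference, here is a minimal PyTorch sketch of the Hinton loss above (the function name and the default values of \(T\) and \(\alpha\) are illustrative choices, not prescribed values):

```python
import torch.nn.functional as F

def hinton_kd_loss(student_logits, teacher_logits, labels, T=4.0, alpha=0.5):
    """Hard-label cross-entropy plus temperature-scaled soft-label KL."""
    # Hard-label term: standard cross-entropy against the ground-truth labels
    ce = F.cross_entropy(student_logits, labels)

    # Soft-label term: KL(teacher || student) on temperature-softened distributions.
    # F.kl_div expects log-probabilities for the first argument (the student).
    log_p_s = F.log_softmax(student_logits / T, dim=-1)
    p_t = F.softmax(teacher_logits / T, dim=-1)
    kd = F.kl_div(log_p_s, p_t, reduction="batchmean") * (T ** 2)  # T^2 rescaling

    return alpha * ce + (1.0 - alpha) * kd
```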

1.2 Why Are Soft Labels Effective?

A teacher output of \([0.7, 0.2, 0.1]\) contains more information than the hard label \([1, 0, 0]\):

  • It indicates that class 2 is more similar to the true class (class 1) than class 3 is
  • This similarity structure is valuable knowledge
  • The student learns not just the "correct answer" but relative relationships between classes

2. Feature Distillation

2.1 FitNets

FitNets (Romero et al., 2015): Distill intermediate layer features, not just outputs.

\[ \mathcal{L}_{\text{FitNet}} = \| W_s \cdot F_s - F_t \|_2^2 \]

where \(F_s, F_t\) are student and teacher intermediate features, and \(W_s\) is a projection matrix for dimension matching.

Advantage: Students learn better intermediate representations.
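
A minimal PyTorch sketch of the hint loss, assuming 2D convolutional feature maps and using a 1×1 convolution as the projection \(W_s\) (class and argument names are illustrative):

```python
import torch.nn as nn
import torch.nn.functional as F

class FitNetHintLoss(nn.Module):
    """Match student intermediate features to teacher features through a
    learned 1x1 projection that aligns the channel dimensions."""

    def __init__(self, student_channels: int, teacher_channels: int):
        super().__init__()
        # Plays the role of W_s in the formula above
        self.proj = nn.Conv2d(student_channels, teacher_channels, kernel_size=1)

    def forward(self, f_s, f_t):
        # f_s: (B, C_s, H, W) student features; f_t: (B, C_t, H, W) teacher features
        return F.mse_loss(self.proj(f_s), f_t)
```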

2.2 Attention Distillation

Attention Transfer (Zagoruyko & Komodakis, 2017): Distill attention maps.

\[ \mathcal{L}_{\text{AT}} = \sum_l \left\| \frac{A_s^l}{\|A_s^l\|_2} - \frac{A_t^l}{\|A_t^l\|_2} \right\|_2^2 \]

where the attention map \(A^l = \sum_c |F^l_c|^2\) (squared activations summed along channel dimension).
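
A sketch of the attention-transfer loss in PyTorch, assuming a list of paired student/teacher feature maps of shape (B, C, H, W); taking the mean of squared differences is a common implementation choice rather than something mandated by the paper:

```python
import torch.nn.functional as F

def attention_map(features):
    """A^l = sum_c |F_c|^2: square activations, sum over channels,
    flatten spatially, and L2-normalize per sample."""
    a = features.pow(2).sum(dim=1).flatten(1)      # (B, H*W)
    return F.normalize(a, p=2, dim=1)

def attention_transfer_loss(student_feats, teacher_feats):
    """Sum the attention-map discrepancy over the selected layer pairs."""
    return sum(
        (attention_map(f_s) - attention_map(f_t)).pow(2).mean()
        for f_s, f_t in zip(student_feats, teacher_feats)
    )
```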

2.3 Relational Distillation

RKD (Relational Knowledge Distillation): Distill inter-sample relationships.

\[ \mathcal{L}_{\text{RKD}} = \sum_{(i,j)} \left(\psi_t(t_i, t_j) - \psi_s(s_i, s_j)\right)^2 \]

where \(\psi\) can represent distance or angular relationships.
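
A sketch of the distance-wise variant in PyTorch, where \(\psi\) is the pairwise Euclidean distance normalized by its batch mean; the RKD paper uses a Huber (smooth L1) penalty rather than the plain squared error written above:

```python
import torch
import torch.nn.functional as F

def normalized_pairwise_distances(embeddings):
    """Pairwise Euclidean distances within a batch, normalized by their mean."""
    d = torch.cdist(embeddings, embeddings, p=2)   # (B, B)
    mean_d = d[d > 0].mean()
    return d / (mean_d + 1e-8)

def rkd_distance_loss(student_emb, teacher_emb):
    """Match the student's pairwise-distance structure to the teacher's."""
    return F.smooth_l1_loss(
        normalized_pairwise_distances(student_emb),
        normalized_pairwise_distances(teacher_emb),
    )
```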

2.4 Distillation Methods Summary

| Method | What Is Distilled | Level |
|---|---|---|
| Hinton KD | Soft logits | Output layer |
| FitNets | Intermediate features | Hidden layers |
| Attention Transfer | Attention maps | Attention layers |
| RKD | Inter-sample relations | Representation space |
| CRD | Contrastive representations | Representation space |
| PKD | Patient distillation (multi-layer) | Multiple layers |

3. Knowledge Distillation in NLP

3.1 DistilBERT

DistilBERT (Sanh et al., 2019): A distilled version of BERT.

Design:

  • 6 layers (original BERT has 12), 40% fewer parameters
  • Distillation loss = soft labels + hard labels + cosine similarity
\[ \mathcal{L} = \alpha \mathcal{L}_{\text{CE}} + \beta \mathcal{L}_{\text{KD}} + \gamma \mathcal{L}_{\text{cos}} \]

Result: Retains 97% of BERT's performance, 60% faster.
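
A rough PyTorch sketch of the combined objective above, assuming token-level logits and last-layer hidden states from both models (the weights, temperature, and masking details are simplified relative to the actual DistilBERT training code):

```python
import torch
import torch.nn.functional as F

def distilbert_loss(student_logits, teacher_logits, labels,
                    student_hidden, teacher_hidden,
                    T=2.0, alpha=1.0, beta=1.0, gamma=1.0):
    """CE on hard labels + temperature-scaled KL on soft labels
    + cosine alignment of hidden states."""
    vocab = student_logits.size(-1)
    s_logits = student_logits.view(-1, vocab)
    t_logits = teacher_logits.view(-1, vocab)

    # Hard-label (masked LM) cross-entropy; ignored positions marked with -100
    ce = F.cross_entropy(s_logits, labels.view(-1), ignore_index=-100)

    # Soft-label KL between teacher and student token distributions
    kd = F.kl_div(F.log_softmax(s_logits / T, dim=-1),
                  F.softmax(t_logits / T, dim=-1),
                  reduction="batchmean") * (T ** 2)

    # Cosine-embedding loss pulling student hidden states toward the teacher's
    s_h = student_hidden.view(-1, student_hidden.size(-1))
    t_h = teacher_hidden.view(-1, teacher_hidden.size(-1))
    target = torch.ones(s_h.size(0), device=s_h.device)
    cos = F.cosine_embedding_loss(s_h, t_h, target)

    return alpha * ce + beta * kd + gamma * cos
```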

3.2 TinyBERT

TinyBERT (Jiao et al., 2020): More comprehensive BERT distillation.

Four-Layer Distillation:

  1. Embedding layer: \(\mathcal{L}_{\text{emb}} = \text{MSE}(E_s W_e, E_t)\)
  2. Attention layer: \(\mathcal{L}_{\text{attn}} = \text{MSE}(A_s, A_t)\)
  3. Hidden layer: \(\mathcal{L}_{\text{hid}} = \text{MSE}(H_s W_h, H_t)\)
  4. Prediction layer: \(\mathcal{L}_{\text{pred}} = \text{KD}\)

Two Stages: General distillation (pretraining phase) + task distillation (fine-tuning phase)
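
A sketch of the per-layer losses for one mapped student/teacher layer pair, with a learned linear map standing in for \(W_e\)/\(W_h\); a full implementation would also define the student-to-teacher layer mapping and sum these losses over all mapped layers:

```python
import torch.nn as nn
import torch.nn.functional as F

class TinyBERTLayerLoss(nn.Module):
    """Hidden-state MSE (with a width-matching projection) plus
    attention-matrix MSE for one mapped student/teacher layer pair."""

    def __init__(self, d_student: int, d_teacher: int):
        super().__init__()
        # Plays the role of W_h (and W_e for the embedding layer)
        self.proj = nn.Linear(d_student, d_teacher, bias=False)

    def forward(self, h_s, h_t, attn_s, attn_t):
        hidden_loss = F.mse_loss(self.proj(h_s), h_t)   # L_hid
        attn_loss = F.mse_loss(attn_s, attn_t)          # L_attn
        return hidden_loss + attn_loss
```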

3.3 MiniLM

MiniLM (Wang et al., 2020): Distills the self-attention distributions (Q-K relations) and the value relations (V-V) of the teacher's last Transformer layer into the student's last layer.

\[ \mathcal{L} = D_{KL}\left(A_t^{L} \,\Big\|\, A_s^{l}\right) + D_{KL}\left(\mathrm{softmax}\!\left(\frac{V_t^{L} {V_t^{L}}^\top}{\sqrt{d_t}}\right) \Big\| \mathrm{softmax}\!\left(\frac{V_s^{l} {V_s^{l}}^\top}{\sqrt{d_s}}\right)\right) \]

where \(L\) and \(l\) denote the last layers of the teacher and student, and \(d_t, d_s\) are the per-head dimensions.
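
A sketch of the two terms in PyTorch, assuming per-head attention probabilities of shape (B, H, L, L) and value vectors of shape (B, H, L, d); the small epsilon and the reduction mode are implementation details, not from the paper:

```python
import torch.nn.functional as F

def minilm_loss(attn_s, attn_t, v_s, v_t, eps=1e-9):
    """KL on last-layer attention distributions plus KL on softmax-normalized
    value-relation matrices (V V^T / sqrt(d))."""
    # Attention-distribution term: KL(teacher || student) per query position
    at = F.kl_div(attn_s.clamp_min(eps).log(), attn_t, reduction="batchmean")

    # Value-relation term, each model scaled by its own head dimension
    vr_s = F.softmax(v_s @ v_s.transpose(-1, -2) / v_s.size(-1) ** 0.5, dim=-1)
    vr_t = F.softmax(v_t @ v_t.transpose(-1, -2) / v_t.size(-1) ** 0.5, dim=-1)
    vr = F.kl_div(vr_s.clamp_min(eps).log(), vr_t, reduction="batchmean")

    return at + vr
```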

4. Knowledge Distillation in the LLM Era

4.1 LLM Distillation Challenges

| Challenge | Description |
|---|---|
| Scale gap | Teachers are often 100B+ parameters; students are typically 1-7B |
| White-box vs. black-box access | Many API-served models do not expose logits |
| Task diversity | LLMs are general-purpose, not single-task |
| Emergent abilities | Small models may not reproduce emergent behaviors |

4.2 White-Box Distillation

When teacher logits are accessible:

Standard Approach:

\[ \mathcal{L} = \lambda \mathcal{L}_{\text{KD}}(p_t, p_s) + (1-\lambda) \mathcal{L}_{\text{NTP}}(y, p_s) \]

where \(\mathcal{L}_{\text{NTP}}\) is the standard next-token prediction (language modeling) loss on the ground-truth data.

MiniLLM: Replaces forward KL with reverse KL:

\[ \mathcal{L}_{\text{MiniLLM}} = D_{KL}(p_s \| p_t) \]

Reverse KL encourages the student to concentrate on high-probability regions of the teacher (mode-seeking), avoiding over-dispersion.
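
A token-level sketch of the reverse-KL objective in PyTorch (MiniLLM itself optimizes this over sampled sequences with policy-gradient-style techniques; this only illustrates the per-token quantity):

```python
import torch.nn.functional as F

def reverse_kl(student_logits, teacher_logits):
    """D_KL(p_s || p_t): the expectation is taken under the *student*
    distribution, which makes the objective mode-seeking."""
    log_p_s = F.log_softmax(student_logits, dim=-1)
    log_p_t = F.log_softmax(teacher_logits, dim=-1)
    p_s = log_p_s.exp()
    # sum_v p_s(v) * (log p_s(v) - log p_t(v)), averaged over positions
    return (p_s * (log_p_s - log_p_t)).sum(dim=-1).mean()
```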

4.3 Black-Box Distillation

When only the teacher's text outputs are available:

  • Data generation: Use the teacher to generate high-quality training data
  • Alpaca/Vicuna approach: Use strong proprietary models to generate instruction and dialogue data for training smaller models
  • Self-Instruct: Use the teacher to automatically generate instruction-response pairs (a minimal sketch follows this list)
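
A minimal sketch of the data-generation recipe, assuming a hypothetical `query_teacher` wrapper around a black-box API that returns only text; the function name and the JSONL schema are illustrative, not from any particular library:

```python
import json

def query_teacher(prompt: str) -> str:
    """Hypothetical wrapper around a black-box teacher API; only generated
    text is available, never logits."""
    raise NotImplementedError("replace with an actual API call")

def build_distillation_dataset(seed_instructions, out_path="distill_data.jsonl"):
    """Self-Instruct-style loop: the teacher writes responses to instructions,
    and the resulting pairs become supervised fine-tuning data for the student."""
    with open(out_path, "w", encoding="utf-8") as f:
        for instruction in seed_instructions:
            response = query_teacher(instruction)
            record = {"instruction": instruction, "output": response}
            f.write(json.dumps(record, ensure_ascii=False) + "\n")
```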

4.4 Representative Works

| Method | Teacher | Student | Strategy |
|---|---|---|---|
| Alpaca | text-davinci-003 | LLaMA-7B | Black-box data distillation |
| Vicuna | ChatGPT (ShareGPT conversations) | LLaMA-13B | Black-box dialogue distillation |
| Orca | GPT-4 | LLaMA-13B | Explanation-augmented distillation |
| MiniLLM | GPT-2 XL | GPT-2 | White-box reverse KL |
| GKD | PaLM | - | On-policy (online) distillation |

5. Self-Distillation

5.1 Concept

The model serves as its own teacher:

  • Born-Again Networks: Train multiple generations, each using the previous as teacher
  • DINO/DINOv2: EMA teacher + student self-distillation (see the EMA update sketch after this list)
  • Noisy Student: Self-training + noise
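
A sketch of the EMA teacher update used in DINO-style self-distillation (the momentum value and usage pattern are illustrative):

```python
import copy
import torch

@torch.no_grad()
def ema_update(teacher, student, momentum=0.996):
    """Teacher weights are an exponential moving average of the student's,
    so the model effectively distills from a slowly-moving copy of itself."""
    for p_t, p_s in zip(teacher.parameters(), student.parameters()):
        p_t.data.mul_(momentum).add_(p_s.data, alpha=1.0 - momentum)

# Usage sketch: the teacher starts as a frozen copy of the student and is
# updated only via EMA after each optimizer step, never by backpropagation.
#   teacher = copy.deepcopy(student).eval()
#   for batch in loader:
#       ...student forward / backward / optimizer.step()...
#       ema_update(teacher, student)
```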

5.2 Self-Distillation in LLMs

  • Self-Improve: Model generates data to train itself
  • STaR: Self-Taught Reasoner
  • Constitutional AI: Self-critique and improvement

6. Practical Guide

6.1 Distillation Hyperparameters

| Hyperparameter | Recommended Range | Notes |
|---|---|---|
| Temperature \(T\) | 3-20 | Higher \(T\) gives smoother distributions |
| Soft-label loss weight | 0.5-0.9 | Corresponds to \(1-\alpha\) in the loss of Section 1.1 |
| Student depth | 1/2-2/3 of teacher | Too shallow hurts quality |
| Student width | 1/2-3/4 of teacher | Width matters more than depth |

6.2 Practical Insights

  • Teacher-student gap should not be too large (otherwise hard to learn)
  • Multi-step distillation (Teacher → TA → Student) can work better
  • Data quality matters more than quantity
  • Feature distillation often outperforms logits distillation

7. Summary

```mermaid
graph TD
    A[Knowledge Distillation] --> B[Logits Distillation]
    A --> C[Feature Distillation]
    A --> D[Relational Distillation]
    A --> E[Data Distillation]

    B --> B1[Hinton KD]
    B --> B2[MiniLLM]

    C --> C1[FitNets]
    C --> C2[TinyBERT]

    D --> D1[RKD]
    D --> D2[CRD]

    E --> E1[Alpaca]
    E --> E2[Vicuna]
```

Key Takeaways:

  1. Classic distillation leverages dark knowledge in soft labels
  2. Feature distillation transfers intermediate representations
  3. LLM distillation faces scale and black-box challenges
  4. Black-box distillation (data generation) is mainstream in the LLM era
  5. Self-distillation is a growing research direction

References

  • Hinton et al., "Distilling the Knowledge in a Neural Network," 2015
  • Romero et al., "FitNets: Hints for Thin Deep Nets," ICLR 2015
  • Sanh et al., "DistilBERT, a distilled version of BERT," NeurIPS Workshop 2019
  • Jiao et al., "TinyBERT: Distilling BERT for Natural Language Understanding," EMNLP 2020
  • Gu et al., "MiniLLM: Knowledge Distillation of Large Language Models," ICLR 2024
