
GRU

The Gated Recurrent Unit (GRU), proposed by Cho et al. in 2014, is a simplified variant of LSTM. It reduces LSTM's three gates to two and eliminates the separate cell state, achieving comparable performance with fewer parameters. It has become another important gated architecture in the RNN family.


Background and Motivation

Problems with LSTM

LSTM successfully addresses the vanishing gradient problem by introducing the forget gate, input gate, output gate, and an independent cell state \(C_t\). However, this comes at a cost:

  • Large number of parameters: 4 sets of weight matrices (forget gate, input gate, candidate value, output gate), resulting in roughly 4x the parameters of a Vanilla RNN
  • High computational overhead: Each time step requires 4 matrix multiplications plus multiple activation function evaluations
  • Structural complexity: Two states (\(h_t\) and \(C_t\)) must be maintained simultaneously, making the model harder to understand and debug

The Core Question

Are all three gates and the independent cell state in LSTM truly necessary? Can a simpler structure achieve similar results?

The Core Idea Behind GRU

Cho et al. (2014) proposed GRU in the paper "Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine Translation", with the following key simplifications:

  1. Merge the forget gate and input gate into a single Update Gate \(z_t\): Use \(z_t\) to simultaneously control "how much old information to forget" and "how much new information to accept" in a complementary fashion (one control variable governs two things)
  2. Remove the independent cell state: Keep only the hidden state \(h_t\), letting it serve directly as the information carrier
  3. Introduce a Reset Gate \(r_t\): Controls how much historical information is referenced when generating the candidate hidden state

Result: 2 gates + 1 state, approximately 25% fewer parameters than LSTM, faster training, and comparable performance.


GRU Architecture in Detail

Overall Structure Diagram

            ┌─────────────────────────────────────────────────┐
            │                   GRU Cell                       │
            │                                                  │
  h_{t-1} ──┼──┬──────────────────┬───────────────┐            │
            │  │                  │               │            │
  x_t ──────┼──┤                  │               │            │
            │  │                  │               │            │
            │  ▼                  ▼               │            │
            │ ┌──────┐      ┌──────┐              │            │
            │ │Reset │      │Update│              │            │
            │ │ Gate │      │ Gate │              │            │
            │ │ r_t  │      │ z_t  │              │            │
            │ │  σ   │      │  σ   │              │            │
            │ └──┬───┘      └──┬───┘              │            │
            │    │             │                  │            │
            │    ▼             │                  │            │
            │  r_t ⊙ h_{t-1}  │                  │            │
            │    │             │                  │            │
            │    ▼             │                  │            │
            │ ┌────────┐      │                  │            │
            │ │Candidate│     │                  │            │
            │ │  h~_t   │     │                  │            │
            │ │  tanh   │     │                  │            │
            │ └───┬────┘      │                  │            │
            │     │           │                  │            │
            │     │           ▼                  ▼            │
            │     └──→  (1 - z_t) ⊙ h_{t-1} + z_t ⊙ h~_t      │
            │                 │                               │
            │                 ▼                               │
            │               h_t ──────────────────────────────┼──→ h_t
            │                                                 │
            └─────────────────────────────────────────────────┘

The most significant difference from LSTM: there is no independent cell state \(C_t\) — only a single hidden state \(h_t\) that handles all responsibilities.

The Two Gates

Reset Gate

\[ r_t = \sigma(W_r \cdot [h_{t-1}, x_t] + b_r) \]
  • Output range: \((0, 1)\), applied element-wise
  • Role: Determines how much historical information to "ignore" when generating the candidate hidden state
  • \(r_t \approx 1\): Fully preserve historical information (the candidate state references the complete \(h_{t-1}\))
  • \(r_t \approx 0\): Ignore historical information (the candidate state is determined almost entirely by the current input \(x_t\))

Intuitive Understanding

The reset gate allows GRU to "forget" irrelevant history. For example, when processing text, upon encountering a period that starts a new sentence, the reset gate can output values close to 0, letting the model "start fresh" without interference from the previous sentence.

Update Gate

\[ z_t = \sigma(W_z \cdot [h_{t-1}, x_t] + b_z) \]
  • Output range: \((0, 1)\), applied element-wise
  • Role: Determines the balance of "how much old state to keep vs. how much new state to accept" in the final hidden state
  • \(z_t \approx 1\): Almost entirely adopt the candidate new state (accept new information)
  • \(z_t \approx 0\): Almost entirely preserve the old state (keep memory unchanged)

Note on Sign Convention

The meaning of \(z_t\) may be reversed in different references. In this document, the convention is: the larger \(z_t\) is, the more the model favors the new candidate state. Some references define \(z_t\) large as retaining the old state. Pay attention to this distinction when reading papers.

Candidate Hidden State

\[ \tilde{h}_t = \tanh(W_h \cdot [r_t \odot h_{t-1}, x_t] + b_h) \]

The key here is \(r_t \odot h_{t-1}\): the reset gate first "filters" the historical hidden state, which is then concatenated with the current input \(x_t\), passed through a linear transformation and \(\tanh\) activation to produce the candidate state.

Final Hidden State Update

\[ h_t = (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t \]

This is the most central equation of GRU. The update gate \(z_t\) simultaneously controls two things:

  • \((1 - z_t) \odot h_{t-1}\): the proportion of the old state retained
  • \(z_t \odot \tilde{h}_t\): the proportion of the new state accepted

The weights of these two terms sum to exactly 1 (complementary), so \(h_t\) is a convex combination of the old state and the candidate state, which keeps the hidden state bounded and the update well-behaved.
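
A minimal single-step sketch of these equations in PyTorch (the function name, argument layout, and concatenated weight shapes here are illustrative choices, not the internal layout of nn.GRU):

import torch

def gru_step(x_t, h_prev, W_z, W_r, W_h, b_z, b_r, b_h):
    # W_* : (hidden, hidden + input), b_* : (hidden,), matching the
    # concatenated [h_{t-1}, x_t] formulation used in the equations above
    hx = torch.cat([h_prev, x_t])
    z = torch.sigmoid(W_z @ hx + b_z)                                # update gate
    r = torch.sigmoid(W_r @ hx + b_r)                                # reset gate
    h_tilde = torch.tanh(W_h @ torch.cat([r * h_prev, x_t]) + b_h)   # candidate state
    return (1 - z) * h_prev + z * h_tilde                            # convex combination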

Key Difference from LSTM

In LSTM, the forget gate \(f_t\) and input gate \(i_t\) are independent — they can both be fully open simultaneously (\(f_t \approx 1\) and \(i_t \approx 1\)), meaning LSTM can simultaneously retain old information and write in large amounts of new information. In GRU, \((1-z_t)\) and \(z_t\) are complementary, forcing a trade-off between retention and update. This represents a minor loss of expressiveness due to GRU's simplification.


Intuitive Understanding of the Gating Mechanism

Update Gate \(z_t\): A Fusion of Two LSTM Gates

In LSTM, the forget gate \(f_t\) and input gate \(i_t\) separately control "discarding old information" and "writing new information." GRU's update gate \(z_t\) controls both with a single parameter:

  • LSTM: \(C_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_t\) (\(f_t\) and \(i_t\) are independent)
  • GRU: \(h_t = (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t\) (complementary constraint)

Reset Gate \(r_t\): Flexible History Forgetting

The reset gate operates during the candidate state generation phase, not during the final state update. When \(r_t \approx 0\):

\[ \tilde{h}_t = \tanh(W_h \cdot [\mathbf{0}, x_t] + b_h) \]

The candidate state makes no reference to history at all — essentially a "fresh start." This enables GRU to completely switch context when needed.

Correspondence with LSTM Gates

| GRU | LSTM | Function |
|---|---|---|
| Update gate \(z_t\) | Forget gate \(f_t\) + Input gate \(i_t\) | Controls the ratio of information retention vs. update |
| Reset gate \(r_t\) | (No direct counterpart) | Controls how much the candidate state depends on history |
| (None) | Output gate \(o_t\) | GRU outputs \(h_t\) directly without additional filtering |
| Hidden state \(h_t\) | Cell state \(C_t\) + Hidden state \(h_t\) | GRU has only one state serving both roles |

Forward Pass: A Complete Numerical Example

Setup: input_size = 2, hidden_size = 3, processing 2 time steps.

Initialization

Assume weight matrices (each \(W \in \mathbb{R}^{3 \times 5}\), since \([h_{t-1}, x_t] \in \mathbb{R}^{3+2=5}\)):

\[ W_z = \begin{bmatrix} 0.1 & 0.2 & -0.1 & 0.3 & 0.1 \\ -0.2 & 0.1 & 0.3 & -0.1 & 0.2 \\ 0.3 & -0.1 & 0.2 & 0.1 & -0.2 \end{bmatrix}, \quad b_z = [0, 0, 0] \]
\[ W_r = \begin{bmatrix} 0.2 & -0.1 & 0.1 & 0.2 & -0.1 \\ 0.1 & 0.3 & -0.2 & 0.1 & 0.3 \\ -0.1 & 0.2 & 0.1 & -0.1 & 0.2 \end{bmatrix}, \quad b_r = [0, 0, 0] \]
\[ W_h = \begin{bmatrix} -0.1 & 0.2 & 0.3 & 0.1 & -0.2 \\ 0.2 & -0.1 & 0.1 & 0.3 & 0.1 \\ 0.1 & 0.1 & -0.2 & -0.1 & 0.3 \end{bmatrix}, \quad b_h = [0, 0, 0] \]

Initial hidden state \(h_0 = [0, 0, 0]\).

Time Step 1: \(x_1 = [1.0, 0.5]\)

Concatenated input: \([h_0, x_1] = [0, 0, 0, 1.0, 0.5]\)

Compute the update gate:

\[ z_1 = \sigma(W_z \cdot [h_0, x_1] + b_z) = \sigma([0.3 \times 1.0 + 0.1 \times 0.5,\ -0.1 \times 1.0 + 0.2 \times 0.5,\ 0.1 \times 1.0 + (-0.2) \times 0.5]) \]
\[ = \sigma([0.35, 0.0, 0.0]) = [0.587, 0.500, 0.500] \]

Compute the reset gate:

\[ r_1 = \sigma(W_r \cdot [h_0, x_1] + b_r) = \sigma([0.2 \times 1.0 + (-0.1) \times 0.5,\ 0.1 \times 1.0 + 0.3 \times 0.5,\ -0.1 \times 1.0 + 0.2 \times 0.5]) \]
\[ = \sigma([0.15, 0.25, 0.0]) = [0.537, 0.562, 0.500] \]

Compute the candidate hidden state (since \(h_0 = \mathbf{0}\), \(r_1 \odot h_0 = \mathbf{0}\)):

\[ \tilde{h}_1 = \tanh(W_h \cdot [\mathbf{0}, x_1] + b_h) = \tanh([0.1 \times 1.0 + (-0.2) \times 0.5,\ 0.3 \times 1.0 + 0.1 \times 0.5,\ -0.1 \times 1.0 + 0.3 \times 0.5]) \]
\[ = \tanh([0.0, 0.35, 0.05]) = [0.0, 0.336, 0.050] \]

Compute the final hidden state:

\[ h_1 = (1 - z_1) \odot h_0 + z_1 \odot \tilde{h}_1 = [0, 0, 0] + [0.587, 0.500, 0.500] \odot [0.0, 0.336, 0.050] \]
\[ = [0.0, 0.168, 0.025] \]

Time Step 2: \(x_2 = [-0.5, 0.8]\)

Concatenated input: \([h_1, x_2] = [0.0, 0.168, 0.025, -0.5, 0.8]\)

Compute the update gate:

\[ z_2 = \sigma(W_z \cdot [h_1, x_2] + b_z) \]
\[ = \sigma([0.0 + 0.034 - 0.003 - 0.15 + 0.08,\ 0.0 + 0.017 + 0.008 + 0.05 + 0.16,\ 0.0 - 0.017 + 0.005 - 0.05 - 0.16]) \]
\[ = \sigma([-0.039, 0.235, -0.222]) = [0.490, 0.559, 0.445] \]

Compute the reset gate:

\[ r_2 = \sigma(W_r \cdot [h_1, x_2] + b_r) \]
\[ = \sigma([0.0 - 0.017 + 0.003 - 0.1 - 0.08,\ 0.0 + 0.050 - 0.005 - 0.05 + 0.24,\ 0.0 + 0.034 + 0.003 + 0.05 + 0.16]) \]
\[ = \sigma([-0.194, 0.235, 0.247]) = [0.452, 0.559, 0.561] \]

Compute the candidate hidden state (\(r_2 \odot h_1 = [0.452 \times 0.0,\ 0.559 \times 0.168,\ 0.561 \times 0.025] = [0.0, 0.094, 0.014]\)):

\[ \tilde{h}_2 = \tanh(W_h \cdot [0.0, 0.094, 0.014, -0.5, 0.8] + b_h) \]
\[ = \tanh([0.0 + 0.019 + 0.004 - 0.05 - 0.16,\ 0.0 - 0.009 + 0.001 - 0.15 + 0.08,\ 0.0 + 0.009 - 0.003 + 0.05 + 0.24]) \]
\[ = \tanh([-0.187, -0.078, 0.296]) = [-0.185, -0.078, 0.287] \]

Compute the final hidden state:

\[ h_2 = (1 - z_2) \odot h_1 + z_2 \odot \tilde{h}_2 \]
\[ = [0.510 \times 0.0,\ 0.441 \times 0.168,\ 0.555 \times 0.025] + [0.490 \times (-0.185),\ 0.559 \times (-0.078),\ 0.445 \times 0.287] \]
\[ = [0.0, 0.074, 0.014] + [-0.091, -0.044, 0.128] = [-0.091, 0.030, 0.142] \]

Observations

Even in this simple 2-step example, we can see the GRU state being continuously updated: the update gate determines the mixing ratio of old and new information, while the reset gate affects how much the candidate state depends on history. The entire process requires only 3 sets of matrix multiplications (compared to LSTM's 4), making the computation noticeably more streamlined.
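
The two steps above can be reproduced with a few lines of NumPy. This is a small self-contained check (same weights as in the setup, zero biases), not part of the original walkthrough:

import numpy as np

sigmoid = lambda a: 1 / (1 + np.exp(-a))

W_z = np.array([[ 0.1,  0.2, -0.1,  0.3,  0.1],
                [-0.2,  0.1,  0.3, -0.1,  0.2],
                [ 0.3, -0.1,  0.2,  0.1, -0.2]])
W_r = np.array([[ 0.2, -0.1,  0.1,  0.2, -0.1],
                [ 0.1,  0.3, -0.2,  0.1,  0.3],
                [-0.1,  0.2,  0.1, -0.1,  0.2]])
W_h = np.array([[-0.1,  0.2,  0.3,  0.1, -0.2],
                [ 0.2, -0.1,  0.1,  0.3,  0.1],
                [ 0.1,  0.1, -0.2, -0.1,  0.3]])

def gru_step(h, x):
    hx = np.concatenate([h, x])
    z = sigmoid(W_z @ hx)                                  # update gate
    r = sigmoid(W_r @ hx)                                  # reset gate
    h_tilde = np.tanh(W_h @ np.concatenate([r * h, x]))    # candidate state
    return (1 - z) * h + z * h_tilde

h = np.zeros(3)
for x in [np.array([1.0, 0.5]), np.array([-0.5, 0.8])]:
    h = gru_step(h, x)
    print(np.round(h, 3))
# Prints values close to the hand-worked results [0, 0.168, 0.025] and
# [-0.091, 0.030, 0.142]; tiny differences come from intermediate rounding above.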


Backpropagation Through Time (BPTT) and Gradient Flow

The Linear Gradient Path Through the Update Gate

The GRU state update equation:

\[ h_t = (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t \]

When computing \(\frac{\partial h_t}{\partial h_{t-1}}\), since \(h_{t-1}\) appears in multiple places (the update gate, reset gate, and candidate state all depend on it), the full gradient expression is quite complex. However, the most important gradient path is the direct path:

\[ \frac{\partial h_t}{\partial h_{t-1}} \bigg|_{\text{direct}} = \text{diag}(1 - z_t) \]

This path comes from the \((1 - z_t) \odot h_{t-1}\) term — the gradient only needs to be multiplied by \((1 - z_t)\), without passing through any matrix multiplication or squashing activation function.

Why GRU Also Mitigates Vanishing Gradients

By analogy with LSTM's cell state "highway":

| | LSTM | GRU |
|---|---|---|
| Information highway | \(C_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_t\) | \(h_t = (1-z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t\) |
| Direct gradient path | \(\frac{\partial C_t}{\partial C_{t-1}} = f_t\) | \(\frac{\partial h_t}{\partial h_{t-1}}\big\vert_{\text{direct}} = 1 - z_t\) |
| Gradient across \(T\) steps | \(\prod_{t} f_t\) | \(\prod_{t} (1 - z_t)\) |
| Condition for preserving gradients | \(f_t \approx 1\) (forget gate close to 1) | \(z_t \approx 0\) (update gate close to 0) |

When the model learns that certain information needs to be retained long-term:

  • LSTM: Forget gate \(f_t \approx 1\), information passes through the cell state nearly unchanged
  • GRU: Update gate \(z_t \approx 0\), hidden state \(h_t \approx h_{t-1}\), information "passes through" directly

The two mechanisms are essentially equivalent: both provide an additive, controllable gradient propagation path, avoiding the problem in Vanilla RNNs where gradients must pass through the \(W_{hh}\) matrix multiplication.

Additive Structure Is the Key

Whether in LSTM or GRU, the core mechanism for mitigating vanishing gradients is the additive structure: new state = weighted old state + weighted new information. The gradient of addition is simply the coefficient, avoiding the eigenvalue issues associated with matrix multiplication. This idea is entirely consistent with the residual connection in ResNet: \(y = x + F(x)\).
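
A deliberately simplified illustration of this additive path (an assumption of this sketch: the candidate state is generated independently of \(h_{t-1}\), so only the direct path carries gradient):

import torch

T, d = 50, 4
z = torch.full((d,), 0.02)               # update gate close to 0 -> keep the old state
h0 = torch.ones(d, requires_grad=True)
h = h0
for t in range(T):
    h_tilde = torch.tanh(torch.randn(d)) # stand-in candidate, independent of h_{t-1}
    h = (1 - z) * h + z * h_tilde        # additive update
h.sum().backward()
print(h0.grad)                           # each entry equals (1 - 0.02) ** 50 ≈ 0.364
print((1 - z) ** T)                      # the gradient is the product of (1 - z_t) factors

With \(z_t\) close to 1 instead, the same product collapses toward 0, which is exactly the retention/update trade-off the update gate controls.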


GRU vs. LSTM: A Detailed Comparison

| Dimension | GRU | LSTM |
|---|---|---|
| Year proposed | 2014 (Cho et al.) | 1997 (Hochreiter & Schmidhuber) |
| Number of gates | 2 (update gate, reset gate) | 3 (forget gate, input gate, output gate) |
| State variables | Only \(h_t\) | \(h_t\) and \(C_t\) |
| Parameter count | \(3 d_h(d_h + d) + 3d_h\) | \(4 d_h(d_h + d) + 4d_h\) |
| Parameter ratio | Approximately 75% of LSTM (see the check below the table) | Baseline (100%) |
| Training speed | Faster (one fewer set of matrix operations) | Slower |
| Information retention/update | Complementary constraint (\(z_t\) and \(1-z_t\)) | Independent control (\(f_t\) and \(i_t\) are separate) |
| Output filtering | None (outputs \(h_t\) directly) | Yes (output gate \(o_t\) filters) |
| Expressiveness | Slightly weaker (limited by the complementary constraint) | Slightly stronger (independent gating is more flexible) |
| Performance on most tasks | On par with LSTM | On par with GRU |
| Small datasets | Potentially better (fewer parameters, less prone to overfitting) | May overfit |
| Very long sequences | Slightly inferior | Slightly superior (independent cell state is more stable) |
| Code complexity | Simpler | More complex |
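
The parameter-count row can be checked directly in PyTorch (note that nn.GRU and nn.LSTM keep two bias vectors per gate, so the absolute counts are \(3d_h(d_h + d) + 6d_h\) and \(4d_h(d_h + d) + 8d_h\), but the 75% ratio is unchanged):

import torch.nn as nn

d_in, d_h = 10, 20
gru, lstm = nn.GRU(d_in, d_h), nn.LSTM(d_in, d_h)
count = lambda m: sum(p.numel() for p in m.parameters())
print(count(gru), count(lstm), count(gru) / count(lstm))   # 1920 2560 0.75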

Rule of Thumb

In practice, the performance gap between GRU and LSTM is typically small. Jozefowicz et al. (2015) found in large-scale experiments that neither architecture dominates across all tasks. The general recommendation: start with GRU (faster), and switch to LSTM if the results are insufficient. For particularly long sequences or tasks requiring fine-grained control of information flow, LSTM may be more suitable.


Code Implementation (PyTorch)

Basic Usage

import torch
import torch.nn as nn

# Parameter settings
input_size = 10     # Input feature dimension
hidden_size = 20    # Hidden state dimension
num_layers = 2      # Number of GRU layers
batch_size = 3      # Batch size
seq_len = 5         # Sequence length

# Create GRU
gru = nn.GRU(
    input_size=input_size,
    hidden_size=hidden_size,
    num_layers=num_layers,
    batch_first=True,      # Input format: (batch, seq, feature)
    dropout=0.1,           # Inter-layer dropout (only effective when num_layers > 1)
    bidirectional=False     # Whether to use bidirectional GRU
)

# Input data
x = torch.randn(batch_size, seq_len, input_size)    # (3, 5, 10)
h0 = torch.zeros(num_layers, batch_size, hidden_size)  # (2, 3, 20)

# Forward pass
output, h_n = gru(x, h0)

print(f"output shape: {output.shape}")   # (3, 5, 20) - Output at each time step
print(f"h_n shape: {h_n.shape}")         # (2, 3, 20) - Hidden state of each layer at the last time step

nn.GRU Parameter Reference

| Parameter | Type | Description |
|---|---|---|
| input_size | int | Feature dimension of input \(x_t\) |
| hidden_size | int | Dimension of hidden state \(h_t\) |
| num_layers | int | Number of stacked GRU layers, default 1 |
| bias | bool | Whether to use bias terms, default True |
| batch_first | bool | If True, input shape is (batch, seq, feature), default False |
| dropout | float | Inter-layer dropout rate (not applied to the last layer), default 0 |
| bidirectional | bool | Whether to use bidirectional GRU, default False |

Output Reference

| Output | Shape | Description |
|---|---|---|
| output | (batch, seq, hidden_size * num_directions), with batch_first=True | \(h_t\) from the last layer at all time steps |
| h_n | (num_layers * num_directions, batch, hidden_size) | \(h_t\) from all layers at the last time step |
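
A quick way to see how the two outputs relate in the unidirectional case (an illustrative check, not from the original text): the last time step of output is the same state as the last layer's entry in h_n.

import torch
import torch.nn as nn

gru = nn.GRU(input_size=10, hidden_size=20, num_layers=2, batch_first=True)
x = torch.randn(3, 5, 10)
output, h_n = gru(x)
# output[:, -1, :] is the top layer's h_t at the final step; h_n[-1] is the same state
print(torch.allclose(output[:, -1, :], h_n[-1]))   # True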

GRU-Based Text Classification Model

class GRUClassifier(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_size, num_classes,
                 num_layers=1, dropout=0.5):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.gru = nn.GRU(
            embed_dim, hidden_size, num_layers,
            batch_first=True, dropout=dropout if num_layers > 1 else 0,
            bidirectional=True
        )
        self.dropout = nn.Dropout(dropout)
        # Bidirectional GRU doubles the hidden state dimension
        self.fc = nn.Linear(hidden_size * 2, num_classes)

    def forward(self, x):
        # x: (batch, seq_len) - token indices
        embeds = self.dropout(self.embedding(x))     # (batch, seq, embed_dim)
        output, h_n = self.gru(embeds)               # output: (batch, seq, hidden*2)

        # Take the output at the last time step (bidirectional concatenation)
        # Alternatively, use average pooling: out = output.mean(dim=1)
        out = output[:, -1, :]                        # (batch, hidden*2)
        out = self.dropout(out)
        logits = self.fc(out)                         # (batch, num_classes)
        return logits

# Usage example
model = GRUClassifier(
    vocab_size=10000, embed_dim=128,
    hidden_size=256, num_classes=2
)
x = torch.randint(0, 10000, (32, 50))  # batch=32, seq_len=50
logits = model(x)                        # (32, 2)

Practical Tip

When using bidirectional GRU, h_n has shape (num_layers * 2, batch, hidden_size). To extract and concatenate the hidden states from both directions of the last layer, use torch.cat([h_n[-2], h_n[-1]], dim=-1) rather than simply taking h_n[-1].
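
A minimal sketch of that extraction (the sizes here are arbitrary, chosen only for illustration):

import torch
import torch.nn as nn

bigru = nn.GRU(input_size=8, hidden_size=16, num_layers=2,
               batch_first=True, bidirectional=True)
x = torch.randn(4, 12, 8)               # (batch, seq, feature)
output, h_n = bigru(x)                  # h_n: (num_layers * 2, batch, hidden_size)
# For the last layer, h_n[-2] is the forward direction and h_n[-1] the backward direction
sentence_repr = torch.cat([h_n[-2], h_n[-1]], dim=-1)   # (4, 32)
print(sentence_repr.shape)              # torch.Size([4, 32])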


Discussion and Reflections

When Does GRU Outperform LSTM?

  1. Small datasets: GRU has approximately 25% fewer parameters, providing a natural regularization effect and making it less prone to overfitting
  2. Limited computational resources: On edge devices or in real-time inference scenarios, GRU's speed advantage is significant
  3. Moderate sequence lengths: When sequences range from tens to a few hundred steps, the difference between GRU and LSTM is minimal, making the simpler GRU the more sensible choice
  4. Rapid prototyping: GRU trains faster, making it suitable for quick experimental iterations

Why Is Their Performance So Similar?

The fundamental reason is that both address vanishing gradients through essentially the same mechanism — additive updates that provide a linear gradient path. The additional output gate and independent cell state in LSTM offer finer-grained control, but in most practical tasks, this extra expressiveness is not the performance bottleneck. Model performance is more heavily influenced by data quality, hyperparameter tuning, and regularization strategies.

GRU's Role in Modern Deep Learning

With the rise of the Transformer architecture, the use of both LSTM and GRU has declined significantly in NLP. However, GRU remains active in the following areas:

  • Time series forecasting: Financial data, sensor data, etc., where sequence lengths are moderate and GRU offers a good cost-performance ratio
  • Speech processing: Streaming processing in real-time speech recognition and speech synthesis
  • Reinforcement learning: Some policy networks use GRU to handle historical information in partially observable environments
  • Edge computing / embedded devices: Scenarios where model size and inference speed are critical
  • As a component in larger models: In some hybrid architectures, GRU is used as a local sequence modeling module

Historical Significance

GRU's greatest contribution lies not in performance improvement, but in demonstrating an important point: the complexity of LSTM is not necessary. This insight inspired subsequent researchers to continually explore simpler sequence modeling approaches, ultimately contributing to the emergence of new architectures like the Transformer. Simplicity is a value in itself.


Summary

Complete set of GRU equations (full forward pass, 4 formulas in total):

\[ r_t = \sigma(W_r \cdot [h_{t-1}, x_t] + b_r) \]
\[ z_t = \sigma(W_z \cdot [h_{t-1}, x_t] + b_z) \]
\[ \tilde{h}_t = \tanh(W_h \cdot [r_t \odot h_{t-1}, x_t] + b_h) \]
\[ h_t = (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t \]

Key takeaways:

  • GRU = a simplified version of LSTM: 2 gates instead of 3, 1 state instead of 2
  • The update gate \(z_t\) takes on the roles of both LSTM's forget gate and input gate (complementary constraint)
  • The reset gate \(r_t\) controls how much the candidate state depends on history
  • The additive update structure provides a linear gradient path, mitigating vanishing gradients
  • Parameter count is approximately 75% of LSTM, training is faster, and performance is generally on par

Next topic: Seq2Seq (Encoder-Decoder architecture, where either GRU or LSTM can serve as the base component)

