Skip to content

Transformer在量化中的应用

概述

Transformer 架构凭借自注意力机制 (Self-Attention Mechanism) 彻底革新了序列建模范式。相比 RNN/LSTM,Transformer 支持并行计算、能直接建模任意距离的依赖关系,在量化投资 (Quantitative Investment) 领域展现出强大的潜力——从时间序列预测到多资产选股 (Stock Selection),Transformer 正成为新一代量化模型的核心架构。

自注意力机制 (Self-Attention)

缩放点积注意力 (Scaled Dot-Product Attention)

给定查询 (Query) \(Q\)、键 (Key) \(K\)、值 (Value) \(V\)

\[\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V\]

其中 \(d_k\) 为键向量的维度。缩放因子 \(\sqrt{d_k}\) 防止点积值过大导致 softmax 梯度消失。

注意力权重矩阵 \(A = \text{softmax}(QK^T / \sqrt{d_k})\) 中,\(A_{ij}\) 表示位置 \(i\) 对位置 \(j\) 的关注程度。

多头注意力 (Multi-Head Attention)

\(Q, K, V\) 线性投影到 \(h\) 个子空间,并行计算注意力后拼接:

\[\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)W^O\]
\[\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)\]

多头注意力的金融解读

不同的注意力头可学习不同的市场模式:某些头关注短期动量,某些头捕捉长期均值回复,某些头则识别跨资产联动关系。

位置编码 (Positional Encoding)

Transformer 本身不具备序列顺序感知能力,需通过位置编码注入时序信息:

\[PE_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{2i/d}}\right), \quad PE_{(pos, 2i+1)} = \cos\left(\frac{pos}{10000^{2i/d}}\right)\]

在金融场景中,也可使用可学习的位置编码 (Learnable Positional Embedding) 或基于时间间隔的编码 (Temporal Encoding)。

Temporal Fusion Transformer (TFT)

TFT 是专为时间序列预测设计的 Transformer 变体,包含以下关键组件:

变量选择网络 (Variable Selection Network)

自动学习每个输入变量的重要性权重:

\[v_t = \text{Softmax}(W_v \cdot \text{GRN}(x_t))\]

其中 GRN 为 Gated Residual Network。

时间注意力层

在编码器-解码器架构中使用 Interpretable Multi-Head Attention:

\[\hat{A} = \frac{1}{h}\sum_{i=1}^{h} A_i\]

平均注意力权重 \(\hat{A}\) 提供了模型对不同时间步关注程度的可解释性。

import torch
import torch.nn as nn

class TemporalAttentionBlock(nn.Module):
    def __init__(self, d_model=64, n_heads=4, dropout=0.1):
        super().__init__()
        self.attention = nn.MultiheadAttention(
            embed_dim=d_model,
            num_heads=n_heads,
            dropout=dropout,
            batch_first=True
        )
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # 因果掩码: 防止看到未来信息
        if mask is None:
            seq_len = x.size(1)
            mask = torch.triu(
                torch.ones(seq_len, seq_len), diagonal=1
            ).bool().to(x.device)
        attn_out, attn_weights = self.attention(
            x, x, x, attn_mask=mask
        )
        return self.norm(x + self.dropout(attn_out)), attn_weights

因果掩码 (Causal Mask)

在金融时间序列预测中,必须使用因果掩码确保模型在时刻 \(t\) 只能看到 \(t\) 及之前的信息。忽略这一点等同于引入前视偏差 (Look-ahead Bias)。

基于注意力的选股模型

架构设计

将选股问题建模为截面排序 (Cross-sectional Ranking) 任务:

class StockTransformer(nn.Module):
    def __init__(self, num_features, d_model=128, n_heads=8,
                 n_layers=3, seq_len=20, dropout=0.1):
        super().__init__()
        # 特征嵌入
        self.feature_embed = nn.Linear(num_features, d_model)
        self.pos_encoding = nn.Parameter(
            torch.randn(1, seq_len, d_model) * 0.02
        )
        # Transformer编码器
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=d_model * 4,
            dropout=dropout,
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(
            encoder_layer, num_layers=n_layers
        )
        # 预测头: 输出每只股票的预期收益率
        self.pred_head = nn.Sequential(
            nn.Linear(d_model, 64),
            nn.GELU(),
            nn.Linear(64, 1)
        )

    def forward(self, x):
        # x: (batch, seq_len, num_features)
        h = self.feature_embed(x) + self.pos_encoding
        causal_mask = torch.triu(
            torch.ones(x.size(1), x.size(1)), diagonal=1
        ).bool().to(x.device)
        h = self.transformer(h, mask=causal_mask)
        return self.pred_head(h[:, -1, :])  # 最后时刻的预测

训练策略

# 使用 ListNet 风格的排序损失
def listwise_loss(y_pred, y_true):
    """基于截面排序的损失函数"""
    pred_probs = torch.softmax(y_pred, dim=0)
    true_probs = torch.softmax(y_true, dim=0)
    return -torch.sum(true_probs * torch.log(pred_probs + 1e-8))

# 或使用 IC 作为损失的代理
def ic_loss(y_pred, y_true):
    """最大化预测与真实收益率的秩相关"""
    pred_rank = y_pred.argsort().argsort().float()
    true_rank = y_true.argsort().argsort().float()
    return -torch.corrcoef(torch.stack([pred_rank, true_rank]))[0, 1]

注意力可视化分析

def visualize_attention(model, sample_input, dates):
    """可视化注意力权重,理解模型的时间关注模式"""
    model.eval()
    with torch.no_grad():
        _, attn_weights = model.get_attention_weights(sample_input)
    # attn_weights: (n_heads, seq_len, seq_len)
    avg_attn = attn_weights.mean(dim=0)  # 跨头平均
    # 分析模型在预测时关注哪些历史时刻
    last_step_attn = avg_attn[-1, :]     # 最后一步对所有历史步的注意力

注意力与因子时效性

通过分析注意力权重的时间分布,可以揭示模型隐式学到的因子半衰期 (Factor Half-life)。如果注意力集中在近期(如最近 3-5 天),说明模型捕捉的是短期动量信号;若注意力呈分散分布,则可能融合了多尺度信息。

实践考量

挑战 应对策略
计算复杂度 \(O(n^2)\) 使用 Linear Attention 或限制序列长度
数据量不足 预训练 + 微调、数据增强
过拟合风险高 Dropout、权重衰减、早停
可解释性需求 注意力权重可视化、TFT 架构

小结

Transformer 在量化投资中的应用正从时间序列预测拓展到截面选股、多资产建模等更广泛的场景。自注意力机制不仅提供了强大的建模能力,还通过可视化注意力权重增强了模型的可解释性。TFT 等专用变体进一步融合了领域知识,是当前量化深度学习研究的前沿方向。