跳转至

GraphSAGE

GraphSAGE(SAmple and aggreGatE)是 Hamilton et al. 在 2017 年提出的归纳式图神经网络框架。与 GCN 的转导学习方式不同,GraphSAGE 通过采样邻居聚合特征的方式来生成节点嵌入,使其能够泛化到训练时未见过的新节点,是工业界大规模图学习的重要基石。

学习路线: GCN 基础 → 转导 vs 归纳学习 → 采样策略 → 聚合函数 → 小批量训练 → 工业应用


GraphSAGE 概述

归纳学习 vs 转导学习

GCN 等早期图神经网络面临一个关键局限:它们是转导式的——在训练时需要看到整张图(包括测试节点),无法处理动态变化的图结构。

维度 转导学习(Transductive) 归纳学习(Inductive)
训练时可见 整张图(含测试节点的特征) 仅训练子图
新节点处理 需要重新训练整个模型 可直接对新节点推断
代表方法 GCN、DeepWalk GraphSAGE、GAT
适用场景 静态图、固定节点集 动态图、不断有新节点加入
工业可行性 难以扩展到大规模图 可扩展,适合工业部署

GraphSAGE 的核心思想:学习的不是每个节点的固定嵌入,而是一个聚合邻居信息的函数。这个函数可以应用到任何新节点上,只要它有邻居结构和特征。

与 GCN 的关键区别

GCN 的传播规则对所有节点同时更新

\[ H^{(l+1)} = \sigma\left(\tilde{D}^{-1/2}\tilde{A}\tilde{D}^{-1/2}H^{(l)}W^{(l)}\right) \]

而 GraphSAGE 对每个节点独立地执行采样-聚合操作,不依赖全图的邻接矩阵。


采样与聚合框架

算法流程

GraphSAGE 的前向传播分为三步:

第 1 步:邻居采样

对每个目标节点 \(v\),从其邻居集 \(\mathcal{N}(v)\)均匀随机采样固定数量 \(S\) 个邻居:

\[ \mathcal{N}_S(v) = \text{SAMPLE}(\mathcal{N}(v), S) \]

固定采样数量保证了计算复杂度可控,不受高度节点影响。

第 2 步:聚合邻居信息

用聚合函数将采样邻居的特征汇聚成单一向量:

\[ \mathbf{h}_{\mathcal{N}(v)}^{(l)} = \text{AGGREGATE}^{(l)}\left(\left\{\mathbf{h}_u^{(l-1)}, \forall u \in \mathcal{N}_S(v)\right\}\right) \]

第 3 步:更新节点表示

将节点自身的特征与聚合后的邻居信息拼接,经过线性变换和非线性激活:

\[ \mathbf{h}_v^{(l)} = \sigma\left(W^{(l)} \cdot \text{CONCAT}\left(\mathbf{h}_v^{(l-1)}, \mathbf{h}_{\mathcal{N}(v)}^{(l)}\right)\right) \]

最后对嵌入进行 L2 归一化:\(\mathbf{h}_v^{(l)} \leftarrow \frac{\mathbf{h}_v^{(l)}}{\|\mathbf{h}_v^{(l)}\|_2}\)

邻居采样策略

策略 描述 优缺点
均匀采样 等概率选取邻居 简单高效,GraphSAGE 默认方案
重要性采样 根据节点重要性加权采样 更好的近似,但计算开销更大
全邻居(无采样) 使用所有邻居 精确但无法扩展到大图

多跳采样:对于 \(L\) 层的 GraphSAGE,每层采样 \(S_l\) 个邻居。总计算量为 \(\prod_{l=1}^{L} S_l\)。常见设置:\(L=2\)\(S_1=25\)\(S_2=10\)(即 2 跳内最多 250 个节点)。

聚合函数

GraphSAGE 提出了三种聚合函数:

Mean Aggregator

\[ \mathbf{h}_{\mathcal{N}(v)}^{(l)} = \text{MEAN}\left(\left\{\mathbf{h}_u^{(l-1)}, \forall u \in \mathcal{N}_S(v)\right\}\right) \]

这与 GCN 的传播规则最为相似,是最简单的选择。

LSTM Aggregator

将邻居特征序列输入 LSTM,使用最终隐藏状态作为聚合结果。由于邻居没有天然顺序,需要随机排列邻居后再输入 LSTM。

Pool Aggregator

\[ \mathbf{h}_{\mathcal{N}(v)}^{(l)} = \max\left(\left\{\sigma\left(W_{\text{pool}} \mathbf{h}_u^{(l-1)} + \mathbf{b}\right), \forall u \in \mathcal{N}_S(v)\right\}\right) \]

先对每个邻居做非线性变换,再取逐元素最大值。

聚合函数 排列不变性 表达能力 计算效率
Mean 中等
LSTM 否(依赖排列)
Pool

小批量训练

计算图构建

GraphSAGE 的一大优势是支持小批量(mini-batch)训练,无需将整张图加载到内存:

  1. 从训练集中采样一批目标节点 \(\mathcal{B}\)
  2. 对每个目标节点,递归采样 \(L\) 跳邻居,构建计算图
  3. 从最外层向最内层逐层聚合,最终得到目标节点的嵌入
目标节点 v 的 2 层计算图示意:

第 2 层采样 (S₂=2):     a1  a2    b1  b2    c1  c2
                          \  /      \  /      \  /
第 1 层采样 (S₁=3):       n1        n2        n3
                            \        |        /
目标节点:                          v

邻居扩展问题

多层采样会导致邻居扩展(Neighbor Explosion)问题:\(L\) 层采样的总节点数为 \(O(\prod_l S_l)\),指数增长。

缓解策略:

  • 限制采样数量:每层采样数不宜过大(通常 \(S \leq 25\)
  • 减少层数:通常 \(L = 2\) 即可,过深反而引起过度平滑
  • 子图采样(ClusterGCN, GraphSAINT):先采样子图,再在子图内做全邻居聚合

损失函数

监督学习(节点分类):使用交叉熵损失。

无监督学习(学习节点嵌入):GraphSAGE 原始论文使用了基于图结构的损失,鼓励相邻节点嵌入相似、不相邻节点嵌入远离:

\[ J(\mathbf{z}_v) = -\log\left(\sigma(\mathbf{z}_v^T \mathbf{z}_u)\right) - Q \cdot \mathbb{E}_{v_n \sim P_n(v)}\left[\log\left(\sigma(-\mathbf{z}_v^T \mathbf{z}_{v_n})\right)\right] \]

其中 \(u\)\(v\) 的邻居,\(v_n\) 是负采样节点,\(Q\) 是负采样数量。


与 GCN 的对比

维度 GCN GraphSAGE
学习范式 转导式 归纳式
邻居使用 全邻居(完整邻接矩阵) 采样固定数量邻居
训练方式 全图训练(full-batch) 小批量训练(mini-batch)
聚合方式 加权平均(归一化邻接矩阵) 可选(Mean/LSTM/Pool)
节点自身信息 通过自环隐式包含 显式拼接
可扩展性 受限于 GPU 内存 可扩展到百万级节点
新节点推断 需重新训练 直接推断

应用场景

社交网络

  • 用户分类:基于社交关系和用户属性,预测用户兴趣标签
  • 社区发现:无监督 GraphSAGE 学习社区结构
  • 欺诈检测:识别异常的社交行为模式

推荐系统

Pinterest 的 PinSage(GraphSAGE 的工业扩展版)是最成功的工业应用之一:

  • 在包含 30 亿节点、180 亿边的图上运行
  • 使用随机游走采样邻居(而非均匀采样)
  • 引入重要性采样和硬负样本挖掘
  • 推荐效果相比之前方法提升 40%+
应用领域 图的构建 节点 任务
社交网络 用户关系图 用户 好友/关注 节点分类
推荐系统 用户-物品二部图 用户、物品 交互行为 链接预测
学术网络 引文网络 论文 引用关系 节点分类
生物信息 蛋白质交互网络 蛋白质 相互作用 功能预测

GraphSAGE 通过采样-聚合的设计范式,成功将图神经网络从学术研究推向了工业实践,是理解现代图学习系统的必备基础。


评论 #