目录
基本信息
作者:Zhanghao Wu, Paras Jain, Matthew A. Wright, Azalia Mirhoseini, Joseph E. Gonzalez, Ion Stoica
机构:UC Berkeley, Google Brain
摘要
本文提出了一种新的图神经网络架构Graph Transformer (GraphTrans),以学习图中的长程依赖关系。通过在标准GNN模块之后添加一个Transformer子模块,GraphTrans可以显式地计算图中所有节点对之间的关系,从而学习全局信息。实验结果表明,GraphTrans在多个图分类任务上都取得了state-of-the-art的结果,显著优于那些试图通过分层聚类学习长程依赖的方法。
Code for GraphTrans is available at https://github.com/ucbrise/graphtrans
方法
图神经网络(GNNs)是一种强大的结构数据表示学习方法。但是当前的GNN方法在学习长程依赖上存在困难。简单增加GNN的深度和宽度无法扩大感受野,因为更大的GNN会遇到梯度消失和过度平滑等问题。基于pooling的方法如分层聚类虽然在理论上可以学习更广范围的信息,但其效果还不如在计算机视觉任务中那样普适。
最近的一些研究表明,在计算机视觉任务中,注意力机制可以取代卷积操作,学习相似的局部关系。在更高级的任务中,去掉结构先验的模块反而获得了更好的效果,提示结构先验对于建模长程依赖可能是无用或者有害的。
启发于这一发现,本文提出了Graph Transformer (GraphTrans),使用GNN子模块学习局部的短程关系,使用Transformer子模块学习长程的全局关系。删掉位置编码使得Transformer对图的节点顺序不敏感,因此适合用于建模图结构。
本文的实验结果表明,相比那些试图编码结构先验的Baseline,简单的GraphTrans架构取得了多个图分类任务上的最优结果。这表明与GNN不同,对图的长程依赖建模,使用纯基于学习的方法而不强制编码图结构信息可能是更合适的。
模型架构
GraphTrans由两个主要模块组成:
GNN子模块:用于学习节点的局部邻域信息
Transformer子模块:用于学习全局的长程依赖关系
GNN子模块可以是任意现有的图卷积网络,用来学习每个节点的局部表征。
Transformer子模块在GNN模块之后,对GNN模块输出的节点表征进行全局的自注意力计算,学习节点之间的全局关系。这里的Transformer使用的是无位置编码的结构,以保证对图节点的permutation invariance。
最后,通过一个特殊的CLS标记,将Transformer模块学习到的全局信息聚合为整个图的表征,进行图分类。
Transformer模块
-
-
计算自注意力ighed attentions:
-
-
多头注意力机制
-
Feed Forward子层:Dropout → Layer Norm → 全连接 → non-linearity → Dropout → 全连接 → Dropout → Layer Norm
-
残差连接
CLS标记作为图表示
-
在节点特征序列后面追加一个可学习的CLSembedding $h_{}$
-
Transformer输出对应的$h_{}^{L_{TF}}$作为整个图的表征
-
通过线性层和softmax生成预测
实验结果
-
在OpenGraphBenchmark的多个数据集上,GraphTrans都取得了SOTA的结果,明显优于那些通过分层聚类学习长程依赖的Baseline
-
在NCI生物分子数据集上也取得明显提升
-
即使不使用位置编码也能有效学习长程依赖
-
CLS标记优于mean/max-pooling等其它图表示方法
图2展示了一个来自OGB Code2数据集的示例图,以及经过GraphTrans模型后的注意力图(attention map)。
横轴对应目标节点,纵轴对应源节点,所以每一行的注意力权重和为1。
可以观察到:
- 节点17向节点8赋予了较高的注意力权重,尽管这两个节点在图中相距5个跳数
- 第18列对应CLS标记的embedding,可以看到多个节点向其赋予较高的注意力权重,说明这些节点正在向CLS标记传递全局信息
这表明Transformer模块内的全局自注意力可以捕获图中长程的依赖关系,哪怕两个节点在图中路径长度很长也可以建立起联系。而CLS标记也的确学会了从各个节点汇聚信息,获得整个图的表达。这与Transformer在NLP任务中的表现非常相似,每个token向CLS标记传递全局语义信息。所以这张注意力图从视觉上佐证了文章所述,Transformer可以高效地学习图上的长程依赖。
结论
-
本文提出了GraphTrans,一种简单有效的GNN架构,通过添加Transformer模块来学习长程依赖
-
Transformer的全局自注意力机制可以显式地学习节点之间的全局关系
-
CLS标记为图学习提供了一种有效的表示方法
-
GraphTrans在多个图分类任务上都取得了state-of-the-art的结果
-
结果表明,与GNN不同,Transformer可以适用于学习图上的长程依赖
-
GraphTrans提供了一种简单通用的方式来提升GNN在图分类任务上的精度
文章评论