PyG使用scatter对图节点特征进行聚合

对于做图任务时,有时我们需要汇聚整张图的节点特征,例如一张图共有5个节点,每个节点有自己的 feature,如果要对图进行分类,那么我们就需要用一个 Graph Embedding 来表示这张图,也就是用一个向量来表示这张图,然后将 Graph Embedding 交给下游的分类任务网路中进行分类。

这个简单说一点有点像池化操作,例如最大池化、平均池化等,对于平均池化来讲,我们可以在定义模型结构时,依次遍历每张图的所有节点特征,然后添加计算均值向量。

对于这个需求,在PyG中已经为我们实现好了相关的模块来实现这个需求,这个模块就是 torch_scatter

常用参数:

  • src:需要池化的节点的特征矩阵
  • index:对应索引信息
  • dim:在哪个维度上进行池化
print(dataset[0].x.shape)
>>

猜你喜欢

转载自blog.csdn.net/m0_47256162/article/details/128750752