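Background
- The features of each geographic object in a remote-sensing image are closely tied to those of its adjacent objects, so the goal is to make full use of the spatial relationships between geographic objects.
- The idea is to embed these spatial relationships into the remote-sensing object recognition network by computing the similarity between the scene graph formed by an object and its adjacent objects and the corresponding label graph.
- There are many functions for computing graph similarity. In my experiments, graph kernels seemed unable to cope with complex geographic environments, so I use a GCN instead.
Therefore, in a scene like the one shown in the figure, a ship or the coast is taken as the center object and a graph is built from it and its adjacent objects. Object features are extracted with the backbone and the object masks, and the similarity of the two attributed graphs is then computed.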
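Approach
- Use a GCN to transform the features from the node features and the adjacency matrix, bringing both graphs to the same dimension
- Concatenate the features of the two graphs along the channel dimension and pass them through a linear layer to obtain a 1-dimensional similarity value
- Each training sample contains n geographic objects; together with their adjacent objects they form n pairs of graphs in the remote-sensing scene graph and the label graph, which are processed one by one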
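
Written as formulas, the steps above might look as follows. This is only a sketch read off the code in this post: the symbols $X$, $A$, $D$, $W$, $b$, $H$, $w$, $c$ and $s$ are notation introduced here for illustration, and $n$ denotes the number of nodes in one object's graph.

```latex
\begin{aligned}
D &= \mathrm{diag}(d_1,\dots,d_n), \qquad d_i = \sum_{j} A_{ij}, \\
\mathrm{GCN}(X, A) &= A\,\mathrm{ReLU}\!\left(D^{-1}\left(XW + \mathbf{1}b^{\top}\right)\right), \\
H_L &= \mathrm{GCN}^{L}_{2}\!\left(\mathrm{GCN}^{L}_{1}(X_L, A),\, A\right), \qquad
H_R = \mathrm{GCN}^{R}_{2}\!\left(\mathrm{GCN}^{R}_{1}(X_R, A),\, A\right), \\
s &= \frac{1}{n}\sum_{i=1}^{n}\left(w^{\top}\,[\,h_{L,i}\,\|\,h_{R,i}\,] + c\right).
\end{aligned}
```

Here $h_{L,i}$ and $h_{R,i}$ are the $i$-th rows of $H_L$ and $H_R$, and $\|$ denotes concatenation along the feature dimension.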
Code
Step 1: Import the required libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
Step 2: GCN layer
class GraphConvolution(nn.Module):
    """
    A simple graph convolution layer.
    inputs:  node_features (n, in_features), adj (n, n)
    outputs: node_features (n, out_features)
    """
    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        self.relu = nn.ReLU()

    def forward(self, x, adj):
        # x:   node features (num_nodes, in_features)
        # adj: adjacency matrix (num_nodes, num_nodes), float
        x = self.linear(x)
        # Normalize each node by its degree (optional step);
        # clamp so that isolated nodes do not cause a division by zero
        degree = torch.sum(adj, dim=1, keepdim=True).clamp(min=1)
        x = x / degree
        # Apply the activation function (e.g., ReLU)
        x = self.relu(x)
        # Propagate features along the edges
        x = torch.matmul(adj, x)
        return x
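
For comparison, a more common way to normalize a GCN layer is the symmetric form ReLU(D^-1/2 (A + I) D^-1/2 X W). The sketch below is not the layer used in this post; `SymNormGraphConvolution` is a hypothetical name, and it assumes `adj` has no self-loops yet (they are added inside the layer).

```python
import torch
import torch.nn as nn


class SymNormGraphConvolution(nn.Module):
    """Hypothetical alternative layer: ReLU(D^-1/2 (A + I) D^-1/2 X W)."""

    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)

    def forward(self, x, adj):
        # Add self-loops so every node keeps its own features and has degree >= 1
        adj_hat = adj.float() + torch.eye(adj.shape[0], device=adj.device)
        deg_inv_sqrt = adj_hat.sum(dim=1).pow(-0.5)
        # Scale rows and columns: D^-1/2 (A + I) D^-1/2
        adj_norm = deg_inv_sqrt.unsqueeze(1) * adj_hat * deg_inv_sqrt.unsqueeze(0)
        # Propagate first, then apply the activation
        return torch.relu(adj_norm @ self.linear(x))


# Tiny usage example: node 2 is isolated, yet the output stays finite
adj = torch.tensor([[0, 1, 0], [1, 0, 0], [0, 0, 0]])
out = SymNormGraphConvolution(4, 8)(torch.randn(3, 4), adj)
print(out.shape)
```

Because of the added self-loops, no degree is ever zero, so this variant needs no special handling for isolated nodes.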
Step 3: Compute the similarity of two graphs
class GraphSimilarity(nn.Module):
    """
    Compute the similarity of two graphs that share one adjacency matrix.
    inputs:
        node_features_left:  (n, d1)
        node_features_right: (n, d2)
        adj: (n, n)
    outputs:
        similarity: scalar tensor (mean over the n per-node scores)
    """
    def __init__(self, num_features_left, num_features_right, hid_features):
        super().__init__()
        self.gcn1_left = GraphConvolution(num_features_left, hid_features)
        self.gcn2_left = GraphConvolution(hid_features, hid_features)
        self.gcn1_right = GraphConvolution(num_features_right, hid_features)
        self.gcn2_right = GraphConvolution(hid_features, hid_features)
        self.linear = nn.Linear(hid_features * 2, 1)

    def forward(self, node_features_left, node_features_right, adj):
        left = self.gcn1_left(node_features_left, adj)
        left = self.gcn2_left(left, adj)
        right = self.gcn1_right(node_features_right, adj)
        right = self.gcn2_right(right, adj)
        # similarity = F.cosine_similarity(left, right, dim=-1)  # per-node cosine similarity
        # similarity = torch.sum(left * right) / (torch.norm(left) * torch.norm(right))  # normalized dot product
        # Concatenate the embeddings of the two graphs, score each node, then average
        similarity = self.linear(torch.cat((left, right), dim=-1)).mean()
        return similarity
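
The commented-out lines in `forward` hint at a parameter-free alternative to the linear layer. One way it could look is sketched below: mean-pool the node embeddings of each graph and compare the two pooled vectors with cosine similarity. `pooled_cosine_similarity` is a hypothetical helper, not part of the module above, and the mean pooling is an assumption made here.

```python
import torch
import torch.nn.functional as F


def pooled_cosine_similarity(left, right):
    """Mean-pool node embeddings (n, d) per graph, then take the cosine similarity."""
    g_left = left.mean(dim=0)    # graph embedding, shape (d,)
    g_right = right.mean(dim=0)  # graph embedding, shape (d,)
    return F.cosine_similarity(g_left, g_right, dim=0)  # scalar tensor in [-1, 1]


left = torch.tensor([[1.0, 0.0], [1.0, 0.0]])
print(pooled_cosine_similarity(left, left))   # identical embeddings
print(pooled_cosine_similarity(left, -left))  # opposite embeddings
```

Unlike the learned linear head, this score is bounded, which may make it easier to interpret, but it has no trainable parameters of its own.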
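Step 4: Iterate over the adjacency graph of every object in a sample
def calculate_similarity(model, feature1, feature2, adj):
    """
    Inputs: node features 1, node features 2 (node labels), adjacency matrix
    :param feature1: tensor (n, d1), model outputs; n is the number of nodes in the sample, not the number of neighbors
    :param feature2: tensor (n, d2), object labels; n as above
    :param adj: tensor (n, n)
    :return: tensor (n,)
    """
    # (1) Collect the adjacent objects of the current object into a graph; n objects give n graphs
    # (2) Feed the current graph and the label graph to the network to compute the similarity
    sim_matrix = torch.zeros(feature1.shape[0], device=feature1.device)
    # Iterate over every object
    for i in range(feature1.shape[0]):
        indices = torch.where(adj[i] == 1)[0]  # indices of the adjacent objects
        if indices.numel() == 0:
            continue  # no neighbors: keep 0 instead of the NaN produced by an empty mean
        fea1 = feature1[indices, :]
        fea2 = feature2[indices, :]
        adj_ = adj[indices, :][:, indices].float()  # adjacency matrix of the current node's graph
        sim = model(fea1, fea2, adj_)
        sim_matrix[i] = sim
    return sim_matrix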
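
To make the indexing concrete, here is a tiny example on a hypothetical 4-node adjacency matrix. It also shows a variant that puts the center object in front of its neighbors; that variant is an assumption for illustration, not what `calculate_similarity` does.

```python
import torch

# Hypothetical 4-object scene: 0-1, 0-2 and 2-3 are adjacent, no self-loops
adj = torch.tensor([
    [0, 1, 1, 0],
    [1, 0, 0, 0],
    [1, 0, 0, 1],
    [0, 0, 1, 0],
])

i = 0
neighbors = torch.where(adj[i] == 1)[0]     # neighbors of object 0
sub_adj = adj[neighbors, :][:, neighbors]   # edges among the neighbors only

# Hypothetical variant: center object first, then its neighbors
nodes = torch.cat([torch.tensor([i]), neighbors])
ego_adj = adj[nodes, :][:, nodes]

print(neighbors.tolist())  # [1, 2]
print(sub_adj.tolist())    # [[0, 0], [0, 0]]
print(ego_adj.tolist())    # [[0, 1, 1], [1, 0, 0], [1, 0, 0]]
```

Note that the center object is part of the subgraph only if `adj[i, i] == 1`. In this example the neighbor-only subgraph has no edges at all, so every node in it has degree zero, which is worth keeping in mind for the degree normalization in `GraphConvolution`.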
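Step 5: Test
if __name__ == '__main__':
    n = 20                # number of objects
    d1 = 64               # node feature dimension of graph 1
    d2 = 1                # node feature dimension of graph 2
    embedding_size = 128  # common embedding dimension
    outputs = torch.randn(size=(n, d1))
    labels = torch.randn(size=(n, d2))
    adj = torch.randint(0, 2, size=(n, n))
    model = GraphSimilarity(d1, d2, embedding_size)
    sim = calculate_similarity(model, outputs, labels, adj)
    print("Graph Similarity:", sim)
    print(sim.shape)
The n objects yield a vector of length n. Each value is the similarity between graph 1 (e.g. the remote-sensing scene graph) and graph 2 (e.g. the object label graph), both built from the current node's adjacent nodes.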
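
The `torch.randint` adjacency matrix in the test is neither symmetric nor guaranteed to contain self-loops, whereas spatial adjacency between objects is usually mutual. A test matrix closer to that setting could be generated as sketched below; `random_symmetric_adj` and its parameters `p`, `self_loops` and `seed` are hypothetical names introduced here.

```python
import torch


def random_symmetric_adj(n, p=0.3, self_loops=True, seed=0):
    """Random undirected 0/1 adjacency matrix with edge probability p."""
    gen = torch.Generator().manual_seed(seed)
    # Sample the strict upper triangle, then mirror it to the lower triangle
    upper = torch.triu((torch.rand(n, n, generator=gen) < p).long(), diagonal=1)
    adj = upper + upper.T
    if self_loops:
        # With self-loops, each object is part of its own neighborhood
        adj = adj + torch.eye(n, dtype=torch.long)
    return adj


adj = random_symmetric_adj(20)
print(adj.shape, bool((adj == adj.T).all()))
```

With `self_loops=True`, `torch.where(adj[i] == 1)` also returns `i` itself, so the center object becomes a node of its own graph.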
Model structure
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Linear-1               [1, 20, 128]           8,320
              ReLU-2               [1, 20, 128]               0
  GraphConvolution-3               [1, 20, 128]               0
            Linear-4               [1, 20, 128]          16,512
              ReLU-5               [1, 20, 128]               0
  GraphConvolution-6               [1, 20, 128]               0
            Linear-7               [1, 20, 128]             256
              ReLU-8               [1, 20, 128]               0
  GraphConvolution-9               [1, 20, 128]               0
           Linear-10               [1, 20, 128]          16,512
             ReLU-11               [1, 20, 128]               0
 GraphConvolution-12               [1, 20, 128]               0
           Linear-13                 [1, 20, 1]             257
================================================================
Total params: 41,857
Trainable params: 41,857
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 39.06
Forward/backward pass size (MB): 0.23
Params size (MB): 0.16
Estimated Total Size (MB): 39.46
----------------------------------------------------------------
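
The parameter counts in the summary can be checked by hand: a linear layer with `in` inputs and `out` outputs has `in * out + out` parameters. The sketch below rebuilds only the five linear layers with the sizes used in the test (d1 = 64, d2 = 1, hidden size 128); the dictionary keys are just labels chosen here to mirror the attribute names.

```python
import torch.nn as nn

# Linear layers with the same shapes as in GraphSimilarity(64, 1, 128)
layers = {
    "gcn1_left": nn.Linear(64, 128),
    "gcn2_left": nn.Linear(128, 128),
    "gcn1_right": nn.Linear(1, 128),
    "gcn2_right": nn.Linear(128, 128),
    "linear": nn.Linear(128 * 2, 1),
}

# Count weights and biases per layer
counts = {name: sum(p.numel() for p in layer.parameters())
          for name, layer in layers.items()}
total = sum(counts.values())

print(counts)
print("Total params:", total)
```

The per-layer counts match the Linear rows of the table, and their sum matches the reported total.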
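Summary
- This post records a simple module that uses a GCN to compute the similarity between a scene graph and its corresponding label graph.
- The number of geographic objects differs from sample to sample, and so does the number of objects that meet the conditions. The batch-size dimension is therefore dropped at the model input, i.e. one sample at a time is fed to the network to compute the similarity.
- Many more sophisticated network architectures exist for similar needs and remain to be tried.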