그래프 신경망에 대한 간략한 소개

그래프 신경망에 대한 간략한 소개

GNN(Graph Neural Network)은 그래프 구조의 데이터를 처리하기 위한 딥러닝 방법입니다. 이 자습서에서는 그래프 신경망의 기본 개념, 주요 모델, 응용 시나리오 및 코드 구현을 자세히 소개합니다.

그래프 신경망이란?

GNN(Graph Neural Networks)은 그래프 구조의 데이터를 처리하기 위한 신경망 모델의 일종으로 이미지와 같은 일반 데이터 구조를 처리하는 데 있어서 기존의 신경망(Convolutional Neural Network, Recurrent Neural Network 등)과 다릅니다 , 시간 순서), 그래프 신경망은 소셜 네트워크, 지식 그래프 등과 같은 불규칙한 그래프 구조의 데이터를 처리하는 데 특화되어 있습니다. 그래프 구조 데이터는 노드와 에지로 구성된 복잡한 관계형 네트워크로, 노드는 엔터티를 나타내고 에지는 엔터티 간의 관계를 나타냅니다. 기존의 신경망과 달리 그래프 신경망은 노드 간의 관계를 고려해야 하므로 노드와 에지를 표현하는 새로운 방식이 필요합니다.

그래프 신경망의 핵심 아이디어는 각 노드의 특징을 주변 노드의 특징과 결합하여 새로운 노드 표현을 형성하는 것입니다. 이 프로세스는 각 노드가 이웃 노드로부터 메시지를 수신하고 이러한 메시지를 새 노드 표현으로 집계하는 메시지 전달을 통해 구현될 수 있습니다. 이 방법은 보다 포괄적인 그래프 구조 정보를 얻기 위해 여러 번 반복될 수 있습니다.

그래프 신경망의 구조는 일반적으로 여러 계층으로 구성되며 각 계층에는 노드 임베딩, 메시지 전달 및 풀링과 같은 작업이 포함됩니다. 노드 임베딩 작업에서 각 노드의 기능은 신경망에서 쉽게 학습하고 처리할 수 있도록 저차원 벡터 표현으로 변환됩니다. 메시지 전달 작업에서 각 노드는 이웃 노드에 대한 정보를 수신하고 이러한 정보를 새 노드 표현으로 집계합니다. 풀링 작업에서 노드 표현은 그래프 수준 작업에 대한 예측을 용이하게 하기 위해 전체 그래프의 표현으로 병합됩니다.

현재 그래프 신경망은 소셜 네트워크 분석, 약물 발견, 추천 시스템 등과 같은 많은 분야에서 널리 사용되고 있습니다. 동시에 GCN(Graph Convolutional Network), GAT(Graph Attention Network) 등과 같이 다양한 작업 및 데이터 유형에 적응하기 위해 그래프 신경망 모델의 다양한 변형이 등장했습니다.

그래프 신경망의 기본 개념

알겠습니다. 각 섹션을 추가하겠습니다.

1. 그래프

그래프는 정점(Vertex)과 가장자리(Edge)를 사용하여 엔터티와 해당 관계를 나타내는 수학적 구조입니다. 그래프는 G = (V, E)로 나타낼 수 있습니다. 여기서 V는 정점 집합이고 E는 가장자리 집합입니다. 그래프에서 정점은 엔터티를 나타내고 가장자리는 엔터티 간의 관계를 나타냅니다. 예를 들어 소셜 네트워크에서 정점은 사용자를 나타내고 가장자리는 사용자 간의 팔로우 또는 친구 관계를 나타낼 수 있습니다.

그래프에는 유향 그래프와 무향 그래프의 두 가지 유형이 있습니다. 무방향 그래프에서는 간선에 방향이 없고, 유방향 그래프에서는 간선에 방향이 있습니다. 또한 그래프에는 엔터티 간의 관계 강도를 나타내는 가중치도 포함될 수 있습니다. 가중치가 있는 그래프를 종종 가중치 그래프라고 합니다.

2. 인접 행렬

인접 행렬 A는 그래프에서 정점 간의 관계를 나타내는 행렬입니다. V를 꼭지점 집합이라고 하고 A의 크기는 |V|×|V|입니다. 무방향 그래프의 경우 A의 요소 A(i, j) = 1은 꼭짓점 i와 꼭짓점 j가 인접한다는 즉, 간선이 있음을 의미하고, A(i, j) = 0은 꼭짓점 i와 꼭짓점 j가 서로 인접함을 의미합니다. 인접하지 않습니다. 유향 그래프의 인접 행렬은 유향 간선을 나타냅니다.

인접 행렬은 그래프 구조를 나타내는 데 사용할 수 있으며 그래프 알고리즘을 구현하는 데에도 사용할 수 있습니다. 인접 행렬을 통해 두 꼭짓점 사이에 가장자리가 있는지 여부와 가장자리의 유형 및 가중치와 같은 정보를 빠르게 쿼리할 수 있습니다.

또한, 인접행렬은 그래프의 표현이라고 볼 수 있으며, 인접행렬은 다른 표현(인접리스트, 에지리스트 등)에 비해 조밀한 그래프를 표현하는데 사용할 수 있으며, 공간활용도가 높고 쿼리 효율성.장점.

3. 그래프 신호

그래프 신호는 그래프 꼭지점에 정의된 신호로서 꼭지점의 특징이라고 할 수 있다. 정점 집합 V에 대해 그래프 신호를 |V|×d 행렬 X로 나타낼 수 있습니다. 여기서 d는 특징 차원입니다.

그래프 신호는 버텍스의 속성 또는 특징을 나타낼 수 있습니다.예를 들어 소셜 네트워크에서 각 버텍스는 성별, 나이, 직업 등과 같은 사용자 속성을 포함하는 벡터로 나타낼 수 있습니다. 화학 분자에서 각 원자는 전자 친화력, 전하 등과 같은 화학적 특성을 포함하는 벡터로 나타낼 수 있습니다.

그래프 신호는 그래프 분류, 클러스터링 및 예측과 같은 그래프 구조의 데이터를 분석하고 처리하는 데 사용할 수 있습니다. 일반적으로 그래프 구조 데이터의 특성을 더 잘 활용하려면 그래프 신호를 그래프의 위상 구조와 결합해야 합니다. 예를 들어 그래프 컨볼루션 신경망에서 그래프 신호와 인접 행렬의 컨볼루션 연산을 통해 노드의 특징 표현을 추출할 수 있습니다.

4. 그래프 컨벌루션

그래프 컨볼루션은 전통적인 컨볼루션의 개념을 그래프 구조의 데이터로 확장한 작업입니다. 그래프 컨벌루션은 일반적으로 인접 행렬 A와 그래프 신호 X: f(A, X) 사이의 함수로 표현됩니다. 이 기능을 통해 정보 전달 및 특징 추출을 그래프에서 구현할 수 있습니다. 기존 컨볼루션과 달리 그래프 컨볼루션은 그래프에서 노드의 토폴로지를 고려하므로 노드 간의 관계와 종속성을 캡처합니다.

그래프 컨벌루션을 구현하는 방법에는 여러 가지가 있으며 가장 일반적인 방법은 스펙트럼 도메인 기반 방법과 공간 도메인 기반 방법입니다. 스펙트럼 영역 기반 방법은 컨볼루션 연산을 정의하기 위해 그래프의 라플라시안 행렬의 고유값과 고유벡터를 활용합니다. 공간 도메인 기반 방법은 컨볼루션 작업을 위해 인접 행렬 및 노드 기능을 활용합니다.

그래프 컨벌루션은 노드 분류, 그래프 분류, 링크 예측 등과 같은 그래프 구조 데이터에 대한 많은 작업에 사용할 수 있습니다. 그래프 신경망에서 그래프 컨볼루션은 그래프 구조 데이터에서 특징 표현을 추출하여 보다 효율적이고 정확한 그래프 구조 데이터 분석 및 처리를 가능하게 하는 핵심 작업입니다.

메인 그래프 신경망 모델

1. GCN(그래프 컨벌루션 네트워크)

GCN은 Spectral Domain을 기반으로 하는 그래프 컨벌루션 방법입니다. 그래프의 Laplacian 행렬을 사용하여 특징 벡터에 컨볼루션 연산을 수행하여 정보 전달 및 특징 추출을 수행합니다. GCN의 핵심 아이디어는 인접 행렬 A와 그래프 신호 X의 곱을 통해 정보 전송을 실현하는 것입니다.

GCN에서 각 노드의 고유 벡터는 이웃 노드의 고유 벡터로 가중되고 평균화됩니다. 가중치는 인접 행렬 A의 값에 의해 결정됩니다. 구체적으로 GCN의 컨볼루션 연산은 다음과 같이 표현할 수 있습니다.

Z = f(A, X) = D⁻¹ A X W

여기서 D는 A의 차수 행렬이고 W는 학습 가능한 가중치 행렬입니다.

GCN은 노드 분류, 그래프 분류, 링크 예측 및 기타 작업에 널리 사용되었으며 좋은 결과를 얻었습니다. 그러나 GCN의 한계는 컨볼루션 연산이 1차 이웃 노드만 고려하고 더 긴 범위의 관계와 전역 정보를 캡처할 수 없다는 것입니다. 따라서 후속 연구에서는 아래에 소개된 GAT 및 GraphSAGE와 같은 그래프 컨벌루션 네트워크의 많은 개선된 버전을 제안했습니다.

2. GAT(그래프 어텐션 네트워크)

GAT는 Spatial Domain 기반의 그래프 컨볼루션 방법입니다. GAT는 어텐션 메커니즘을 도입하여 모델이 서로 다른 에지에 서로 다른 가중치를 할당하여 노드 간의 관계를 더 잘 포착할 수 있도록 합니다. GAT에서 각 노드의 특징 벡터는 이웃 노드의 특징 벡터로 가중 및 평균화되며 가중치는 어텐션 메커니즘에 의해 계산됩니다.

구체적으로 GAT의 컨볼루션 연산은 다음과 같이 표현할 수 있습니다.

Z = f(A, X) = CONCAT(ATTENTION(A, X)W)

그 중 ATTENTION은 어텐션 가중치를 계산하는 함수이고 CONCAT은 연결 연산이다. ATTENTION 함수는 일반적으로 두 단계를 포함합니다. 먼저 각 인접 노드와 현재 노드 간의 유사도를 계산한 다음 softmax 함수를 사용하여 유사도를 가중치로 변환합니다. 이러한 방식으로 각 이웃 노드의 가중치가 다를 수 있으므로 노드 간의 관계를 더 잘 표현할 수 있습니다.

GAT의 어텐션 메커니즘은 모델이 다양한 그래프 구조 및 작업에 더 잘 적응할 수 있도록 하며 그래프 분류, 노드 분류 및 링크 예측과 같은 많은 작업에서 좋은 결과를 얻었습니다.

3. GraphSAGE(그래프 샘플 및 AggregatE)

GraphSAGE는 모델이 대규모 그래프 데이터를 처리할 수 있도록 이웃 샘플링 및 집계 전략을 제안하는 공간 도메인 기반 그래프 컨볼루션 방법입니다. GraphSAGE에서 각 노드의 고유 벡터는 이웃의 고유 벡터와 함께 집계됩니다. GCN 및 GAT와 달리 GraphSAGE는 모든 이웃 노드를 평균화하거나 가중 평균하지 않고 특정 이웃 샘플링 전략을 채택하고 집계를 위해 노드의 1차 또는 k차 이웃 노드만 고려하여 계산 복잡성을 줄입니다.

구체적으로 GraphSAGE의 컨볼루션 작업은 다음과 같이 표현할 수 있습니다.

Z = f(A, X) = AGGREGATE(NEIGHBORS(A, X)) W

그 중 NEIGHBORS는 이웃 샘플링 함수로 노드의 이웃 중에서 일부 노드를 집계의 입력으로 무작위로 샘플링하는 데 사용되며, AGGREGATE는 다음과 같은 이웃 노드의 특징 벡터를 집계하는 데 사용되는 집계 함수입니다. 평균값, 최대값 등 W는 학습 가능한 의 가중치 매트릭스입니다.

GraphSAGE는 대규모 그래프 데이터를 처리할 수 있으며 노드 분류, 그래프 분류 및 링크 예측과 같은 작업에서 좋은 결과를 얻었습니다. 이웃 샘플링 및 집계 전략도 차용되어 많은 후속 그래프 컨벌루션 네트워크에서 개선됩니다.

그래프 신경망의 응용 시나리오

그래프 신경망은 주로 다음과 같은 많은 분야에서 널리 사용됩니다.

  1. 노드 분류: 소셜 네트워크에서 사용자 관심 태그 예측과 같이 그래프의 노드 범주를 예측합니다.
  2. 링크 예측: 지식 그래프에서 엔터티 간의 관계를 예측하는 등 그래프에서 노드 간 에지가 있는지 예측합니다.
  3. 그래프 분류: 생체 분자 네트워크에서 분자의 활동을 예측하는 등 전체 그래프의 범주를 예측합니다.
  4. 그래프 생성: 특정 토폴로지 특성을 만족하는 네트워크 생성과 같은 특정 속성을 가진 그래프를 생성합니다.

암호

다음은 PyTorch Geometric 라이브러리를 사용하여 구현된 Cora 데이터 세트를 사용하여 그래프 컨벌루션 신경망을 교육하기 위한 예제 코드입니다.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv

# 判断是否有可用的GPU,如果有就使用GPU,否则使用CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 加载Cora数据集
dataset = Planetoid(root='data/Cora', name='Cora')

# 定义一个GCN模型
class GCN(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(GCN, self).__init__()
        # 第一个图卷积层
        self.conv1 = GCNConv(input_dim, hidden_dim)
        # 第二个图卷积层
        self.conv2 = GCNConv(hidden_dim, output_dim)
        # 全连接层
        self.fc = nn.Linear(output_dim, dataset.num_classes)
        
    def forward(self, x, edge_index):
        # 第一次图卷积,使用ReLU激活函数
        x = F.relu(self.conv1(x, edge_index))
        # 第二次图卷积
        x = self.conv2(x, edge_index))
        # Dropout操作,防止过拟合
        x = F.dropout(x, training=self.training)
        # 全连接层
        x = self.fc(x)
        # 使用log_softmax进行分类
        return F.log_softmax(x, dim=1)

# 创建GCN模型实例,并将模型移动到设备上(GPU或CPU)
model = GCN(dataset.num_features, 16, dataset.num_classes).to(device)

# 定义优化器和损失函数
optimizer = optim.Adam(model.parameters(), lr=0.01)
criterion = nn.NLLLoss()

# 定义训练函数
def train(model, optimizer, criterion, data):
    model.train()
    optimizer.zero_grad()
    out = model(data.x.to(device), data.edge_index.to(device))
    loss = criterion(out[data.train_mask], data.y[data.train_mask].to(device))
    loss.backward()
    optimizer.step()

# 定义测试函数
def test(model, data):
    model.eval()
    out = model(data.x.to(device), data.edge_index.to(device))
    pred = out.argmax(dim=1)
    acc = pred[data.test_mask].eq(data.y[data.test_mask].to(device)).sum().item() / data.test_mask.sum().item()
    return acc

# 进行模型训练和测试,并输出测试集准确率
for epoch in range(200):
    train(model, optimizer, criterion, dataset[0])
    test_acc = test(model, dataset[0])
    print('Epoch: {:03d}, Test Acc: {:.4f}'.format(epoch, test_acc))

위의 코드에서 먼저 PyTorch 관련 라이브러리와 Planetoid데이터 세트 클래스 및 GCNConv그래프 컨볼루션 레이어 클래스를 가져옵니다. 그런 다음 사용 가능한 GPU가 있는지 확인합니다. 그렇다면 GPU를 사용하고 그렇지 않으면 CPU를 사용합니다. 다음으로, Cora 데이터 세트를 로드하고 GCN그래프 컨벌루션 신경망 모델을 만들기 위한 클래스를 정의합니다. 클래스 에서는 GCN2개의 그래프 컨벌루션 레이어와 완전 연결 레이어를 정의하고 forward모델의 순방향 전파를 메서드에서 완료합니다. 다음으로 인스턴스화된 모델이 생성되어 장치(GPU 또는 CPU)로 이동됩니다. 그런 다음 옵티마이저 및 손실 함수와 훈련 및 테스트 함수가 정의됩니다. 마지막으로 모델은 루프에서 훈련 및 테스트되며 테스트 세트 정확도가 출력됩니다.

요약하다

이 튜토리얼에서는 GNN(Graph Neural Network)의 기본 개념, 주요 모델 및 애플리케이션 시나리오와 PyTorch 및 PyTorch Geometric을 사용하여 GCN을 구현하기 위한 샘플 코드를 소개합니다. 그 중 그래프는 엔터티와 이들의 관계를 나타내기 위해 꼭지점과 모서리를 사용하는 수학적 구조이고, 인접행렬은 그래프에서 꼭지점 간의 관계를 나타내는 행렬이며, 그래프 신호는 그래프 꼭지점에 정의된 신호이다. , 그래프 컨볼루션은 그래프 구조로 확장된 전통적인 컨볼루션 작업입니다. 주요 GNN 모델에는 스펙트럼 도메인 기반의 GCN, 공간 도메인 기반의 GAT 및 GraphSAGE가 포함됩니다. GNN은 노드 분류, 링크 예측, 그래프 분류 및 그래프 생성과 같은 분야에서 널리 사용됩니다. 이 튜토리얼의 학습을 통해 독자는 GNN을 더 잘 이해하고 실제 문제에서 문제를 해결하기 위해 GNN을 적용할 수 있습니다.

추천

출처blog.csdn.net/qq_36693723/article/details/130856632