从代码角度深入浅出图神经网络系列笔记(二)

前言

这一节笔记中主要针对继承InMemoryDataset,一次性加载所有的数据到内存,这种数据集一般不是很大,所以直接一次性加载完毕

构建数据集

1、Dataset

pytorch geometric 构建数据集分两种:
1、继承InMemoryDataset,一次性加载所有的数据到内存
2、继承Dataset,分次加载到内存

在自定义的Dataset的初始化方法种传入数据存放的路径,然后pytorch geometric 会在这个路径下再划分2个文件夹:
1、raw_dir:存放原始数据的路径(一般是csv、mat等格式)
2、processed_dir:存放处理后的数据(一般pt格式,由process方法实现)
但是pytorch中,实际上是没有这两个文件夹的

来看官方文件:

https://pytorch-geometric.readthedocs.io/en/latest/notes/create_dataset.html#creating-in-memory-datasets

在这里插入图片描述
在示例代码第二行就引入了InMemoryDataset函数,首先我们去看下,这个函数的使用以及参数

2、InMemoryDataset解读

在这里插入图片描述

  • root是数据集存储的根目录
  • tansformpre_transform 有相同有不同,相同点是,都是一个用于接受数据并返回转换后版本的数据;不同点是,tansform在每次访问前转化,pre_transform是保存到磁盘之前进行转化
  • pre_filter 是一个用于接受数据并返回布尔值的函数,用于指示数据对象是否应该保存在最终数据集中

3、官方文档例子

再回来继续看代码,我把说明整合到代码的注释了,另外,这里有一些地方视频中解释的不是很清楚,我结合文章 Hands-on Graph Neural Networks with PyTorch & PyTorch Geometric 增加了一些注释以及自己的理解。

# 官方代码 https://pytorch-geometric.readthedocs.io/en/latest/notes/create_dataset.html#creating-in-memory-datasets

import torch
from torch_geometric.data import InMemoryDataset # https://pytorch-geometric.readthedocs.io/en/latest/modules/data.html  CLASS InMemoryDataset


class MyOwnDataset(InMemoryDataset):
    def __init__(self, root, transform=None, pre_transform=None): # 初始化函数
        super(MyOwnDataset, self).__init__(root, transform, pre_transform) # super用于说明MyOwnDataset继承InMemoryDataset初始化结果
        self.data, self.slices = torch.load(self.processed_paths[0]) # 详见说明1

    @property # 修饰方法,使方法可以像属性一样访问(保护变量/只读函数转变)详见说明2
    def raw_file_names(self): # 返回一个包含没有处理的数据的名字的list
        return ['some_file_1', 'some_file_2', ...]

    @property
    def processed_file_names(self): # 返回一个包含所有处理过的数据的list
        return ['data.pt']

    def download(self): # 下载数据集函数,不需要的话直接填充pass
        # Download to `self.raw_dir`.

    # 整合你的数据成一个包含data的list,然后调用 self.collate()去计算将用于 DataLodadr 的片段
    def process(self):
        # Read data into huge `Data` list.
        data_list = [...] # 创建并读取了数据的列表

        if self.pre_filter is not None:
            data_list = [data for data in data_list if self.pre_filter(data)] # 判断数据对象是否应该保存

        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list] # 保存到磁盘之前进行转化

        data, slices = self.collate(data_list)# 将数据对象的python列表整理为内部存储格式 torch_geometric.data.InMemoryDataset
        torch.save((data, slices), self.processed_paths[0])

1、说明1:

这部分参考了pytorch_geometric自制数据集
制作数据集需要定义dataslices

  • data指的是以pytorch_geometric定义的数据类型Data构建的图数据集;
  • slices指的是切片,即数据集中不同graph的划分,如slices[‘x’]=[0,75,150]指的是数据集中按照75个节点划分,共三个图,slices[‘y’]slices['edge_index ']以此类推。slices用于区分不同的graph与实现shuffle等功能。值得注意slices需要inttensor类型,否则DataLoader不支持切片操作。

2、说明2:

这部分参考了python @property的介绍与使用
我觉得这个文章说明的比up主找的解析更容易理解一些,已经很简洁了,我就不摘到我的文章中了,大家仔细前往观看即可

亚马逊代码例子

# https://github.com/rusty1s/pytorch_geometric/blob/master/torch_geometric/datasets/amazon.py
import torch
from torch_geometric.data import InMemoryDataset, download_url # download_url为了下载数据
from torch_geometric.io import read_npz 


class Amazon(InMemoryDataset):
    r"""The Amazon Computers and Amazon Photo networks from the
    `"Pitfalls of Graph Neural Network Evaluation"
    <https://arxiv.org/abs/1811.05868>`_ paper.
    Nodes represent goods and edges represent that two goods are frequently
    bought together.
    Given product reviews as bag-of-words node features, the task is to
    map goods to their respective product category.
    Args:
        root (string): Root directory where the dataset should be saved.
        name (string): The name of the dataset (:obj:`"Computers"`,
            :obj:`"Photo"`).
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
    """

    url = 'https://github.com/shchur/gnn-benchmark/raw/master/data/npz/'

    def __init__(self, root, name, transform=None, pre_transform=None):
        self.name = name.lower() # lower将字符串所有大小转小写
        assert self.name in ['computers', 'photo'] # 利用断言判断 name 值的范围是不是在 computers/photo 范围内
        super(Amazon, self).__init__(root, transform, pre_transform) # 继承初始化值
        self.data, self.slices = torch.load(self.processed_paths[0])  

    @property
    def raw_file_names(self):
        return 'amazon_electronics_{}.npz'.format(self.name)

    @property
    def processed_file_names(self):
        return 'data.pt'

    def download(self):
        download_url(self.url + self.raw_file_names, self.raw_dir)

    def process(self):
        data = read_npz(self.raw_paths[0]) # 读取 npz 格式数据集
        data = data if self.pre_transform is None else self.pre_transform(data)
        data, slices = self.collate([data])
        torch.save((data, slices), self.processed_paths[0])

    def __repr__(self):
        return '{}{}()'.format(self.__class__.__name__, self.name.capitalize())

基本上和官网文档给的结构一致,部分处理细节稍微有点不同。其实看到这个地方的时候,我已经有点懵了,因为比如说代码中self.processed_paths[0]并没有被定义赋值,为什么可以直接调用。

这部分疑惑等后面得到解答之后再回来继续改笔记

猜你喜欢

转载自blog.csdn.net/wy_97/article/details/108547022