前言
这一节笔记中主要针对继承Dataset,分次加载到内存,这种数据集一般很大,不适合一次性加载完毕,需要分批加载处理。
构建数据集
1、Dataset
pytorch geometric
构建数据集分两种:
1、继承InMemoryDataset
,一次性加载所有的数据到内存
2、继承Dataset
,分次加载到内存
Mini-Batching
:将一组样本组合成一个统一的表示形式,进行并行处理
2、官方文档例子
首先还是看下引入的库文件,对比一下InMemoryDataset
,这里我们引入的是Dataset
,对比一下这两个库,初始化的参数完全一致
主要是Dataset
多了len()
与get()
:
-
torch_geometric.data.Dataset.len()
: 返回数据集中的样本数 -
torch_geometric.data.Dataset.get()
: 实现加载单个图的逻辑
下面来看对比分析:
import os.path as osp # 调用系统路径
import torch
from torch_geometric.data import Dataset
class MyOwnDataset(Dataset):
def __init__(self, root, transform=None, pre_transform=None):
super(MyOwnDataset, self).__init__(root, transform, pre_transform)# 对比InMemoryDataset 这块少了直接加载的代码,因为我们需要分次加载
@property
def raw_file_names(self):
return ['some_file_1', 'some_file_2', ...]
@property
def processed_file_names(self): # 对比一次性加载,这里会有多个
return ['data_1.pt', 'data_2.pt', ...]
def download(self):
# Download to `self.raw_dir`.
def process(self):
i = 0
for raw_path in self.raw_paths:
# Read data from `raw_path`.
data = Data(...)
if self.pre_filter is not None and not self.pre_filter(data):
continue
if self.pre_transform is not None:
data = self.pre_transform(data)
torch.save(data, osp.join(self.processed_dir, 'data_{}.pt'.format(i)))
i += 1
def len(self): # 返回数目
return len(self.processed_file_names)
def get(self, idx): # 一个一个文件手动加载到内存中
data = torch.load(osp.join(self.processed_dir, 'data_{}.pt'.format(idx)))
return data
3、process解读
下面是我从某博客中引用的一段话,个人觉得说的挺好的。process()
方法存在的意义是:
- 原始的格式可能是
csv
或者mat
,在process()
函数里可以转化为pt
格式的文件。 - 这样在
get()
方法中,就可以直接使用torch.load()
函数,读取pt
格式的文件,返回的是torch_geometric.data.Data
类型的数据。 - 而不用在
get()
方法里面做数据转换操作,比如说,把其他格式的数据转换为torch_geometric.data.Data
类型的数据。 - 当然我们也可以提前把数据转换为
torch_geometric.data.Data
类型,使用pt
格式保存在self.processed_dir
中。
Ps:上面这段话部分针对的是Dataset
,其实InMemoryDataset
也差不多,只不过最后不需要逐个加载进内存,而是直接加载进内存
扫描二维码关注公众号,回复:
11675286 查看本文章
MINI-BATCHING
官方文档地址:
https://pytorch-geometric.readthedocs.io/en/latest/notes/batching.html#pairs-of-graphs
我觉得这里我没有理解,视频的解释实在是太少了,之后再补