参考:PyTorch深度学习快速入门教程(绝对通俗易懂!)【小土堆】
本文是上面视频的笔记,up主讲的特别详细,推荐观看。
在pytorch中加载数据主要涉及到两个类:Dataset 和 Dataloader
Dataset :提供一种方式去提取数据并得到label
Dataset:对数据进行打包送到网络中去,为后面的网络提供不同的数据形式。
下面是代码及说明:
from torch.utils. data import Dataset
可看到说明,Dataset是一个抽象类,我们重写Dataset时要继承这个类,所有的子类都应该重写__getitem__()方法,这个方法作用是获取数据及对应的labe。同时我们可以选择性地去重写__len__方法,其作用是获取数据集长度。
例子:
这里我使用的是猫狗二分类的数据集,如图:
from torch.utils. data import Dataset
from PIL import Image
import os
class Mydataset(Dataset):
def __init__(self,root_dir, label_dir):
self.root_dir = root_dir ##根目录
self.label_dir = label_dir ##标签,也就是文件名
self.path = os.path.join(self.root_dir,self.label_dir) ##拼成一个完整的目录
self.img_path = os.listdir(self.path) ##获得图片的一个list
def __getitem__(self, idx):
img_name = self.img_path[idx] ##得到单个图片的名字
img_item_path = os.path.join(self.root_dir,self.label_dir,img_name) ##得到单个图片的路径
img = Image.open(img_item_path) ##图片数据
label = self.label_dir ##标签
return img, label
def __len__(self):
return len(self.img_path)
root_dir="D:/猫狗大战/data/train"
cat_label_dir = "cat"
dog_label_dir = "dog"
cat_dataset = Mydataset(root_dir,cat_label_dir)
dog_dataset = Mydataset(root_dir,dog_label_dir)
img, label = cat_dataset[1]
img.show()
print(label)
img, label = dog_dataset[1]
img.show()
print(label)
输出结果:
cat
dog
写给自己,另外,可以参考这篇博客:
https://ptorch.com/news/215.html
fastai也可以关注以下