通用数据加载器
官方给出的,可以不局限于给定的数据集,加载自己的数据集。
CLASS torchvision.datasets.DatasetFolder(
root: str,
loader: Callable[[str], Any],
extensions: Union[Tuple[str, ...], NoneType] = None,
transform: Union[Callable, NoneType] = None,
target_transform: Union[Callable, NoneType] = None,
is_valid_file: Union[Callable[[str], bool], NoneType] = None
) → None
参数含义:
- root(string)–根目录路径。
- loader(callable)–在给定路径的情况下加载样本的函数。
- extensions(tuple [string])–允许的扩展名列表。 扩展名和is_valid_file不应同时传递。
- transform (callable, optional)–接收样本并返回转换版本的函数/转换。 例如对图像进行transforms.RandomCrop。
- target_transform(callable, optional)–接收目标并对其进行转换的函数/转换。
- is_valid_file –接受文件路径并检查文件是否为有效文件(用于检查损坏的文件)的函数,不应同时传递扩展名和is_valid_file。
文件夹组织:
- 应以如下结构组织文件
root/class_x/xxx.ext
root/class_x/xxy.ext
root/class_x/[...]/xxz.ext
root/class_y/123.ext
root/class_y/nsdf3.ext
root/class_y/[...]/asd932_.ext
通用图像数据加载器
官方给出的,可以不局限于给定的图像数据集,加载自己的图像数据集。
CLASS torchvision.datasets.ImageFolder(
root: str,
transform: Union[Callable, NoneType] = None,
target_transform: Union[Callable, NoneType] = None,
loader: Callable[[str], Any] = <function default_loader>,
is_valid_file: Union[Callable[[str], bool], NoneType] = None
)
参数含义:
- root(string)–根目录路径。
- transform (callable, optional)–接收PIL图像并返回转换版本的函数/转换。 例如对图像进行transforms.RandomCrop。
- target_transform(callable, optional)–接收目标并对其进行转换的函数/转换。
- loader(callable)–在给定路径的情况下加载样本的函数。
- is_valid_file –接受文件路径并检查文件是否为有效文件(用于检查损坏的文件)的函数。
文件夹组织:
- 应以如下结构组织文件
root/dog/xxx.png
root/dog/xxy.png
root/dog/[...]/xxz.png
root/cat/123.png
root/cat/nsdf3.png
root/cat/[...]/asd932_.png
附:以上类的源代码
自己定义的数据加载器
- 常规思路
# 读取文件位置
def get_path('path-str'):
...
return file_path
# 读取图片
def loader_img(file_path):
# 根据图片的位置读取图片并返回读取的图片和标签
# 对图片进行处理
...
return imgs_list, label_list
# 获取batchsize大小的数据
def get_train_data(imgs_list,label_list,batchsize):
...
return img[1],img[2],...,img[batchsize]
常规思路无法将加载出来的数据集使用pytorch的DataLoader加载,无法以batch的形式去训练,故可以按照pytorch中的Dataset类,写一个自己的类。
- 模仿pytorch中的类,定义自己的类
class MyDataset(torch.utils.data.Dataset): # 需要继承torch.utils.data.Dataset
def __init__(self):
# 初始化文件路径或文件名列表。
# 初始化该类的一些基本参数。
pass
def __getitem__(self, index):
# TODO
#1.从文件中读取一个数据(例如,plt.imread)。
#2.预处理数据(例如torchvision.Transform)。
#3.返回数据对(例如图像和标签)。
# 这里需要注意的是,第一步:read one data,是一个data
pass
def __len__(self):
# 返回数据集的总大小。
这种方法中,标签信息一般在文件名中,可以使用函数将文件名的标签信息存储起来。
- 一种更加具有鲁棒性的类的实现
gen_txt函数
# coding:utf-8
import os
'''
为数据集生成对应的txt文件
'''
train_txt_path = os.path.join("..", "..", "Data", "train.txt")
train_dir = os.path.join("..", "..", "Data", "train")
valid_txt_path = os.path.join("..", "..", "Data", "valid.txt")
valid_dir = os.path.join("..", "..", "Data", "valid")
def gen_txt(txt_path, img_dir):
f = open(txt_path, 'w')
for root, s_dirs, _ in os.walk(img_dir, topdown=True): # 获取 train文件下各文件夹名称
for sub_dir in s_dirs:
i_dir = os.path.join(root, sub_dir) # 获取各类的文件夹 绝对路径
img_list = os.listdir(i_dir) # 获取类别文件夹下所有png图片的路径
for i in range(len(img_list)):
if not img_list[i].endswith('png'): # 若不是png文件,跳过
continue
label = img_list[i].split('_')[0]
img_path = os.path.join(i_dir, img_list[i])
line = img_path + ' ' + label + '\n'
f.write(line)
f.close()
if __name__ == '__main__':
gen_txt(train_txt_path, train_dir)
gen_txt(valid_txt_path, valid_dir)
MyDataset类
# coding: utf-8
from PIL import Image
from torch.utils.data import Dataset
class MyDataset(Dataset):
def __init__(self, txt_path, transform=None, target_transform=None):
fh = open(txt_path, 'r')
imgs = []
for line in fh:
line = line.rstrip()
words = line.split()
imgs.append((words[0], int(words[1])))
self.imgs = imgs # 最主要就是要生成这个list, 然后DataLoader中给index,通过getitem读取图片数据
self.transform = transform
self.target_transform = target_transform
def __getitem__(self, index):
fn, label = self.imgs[index]
img = Image.open(fn).convert('RGB') # 像素值 0~255,在transfrom.totensor会除以255,使像素值变成 0~1
if self.transform is not None:
img = self.transform(img) # 在这里做transform,转为tensor等等
return img, label
def __len__(self):
return len(self.imgs)
方法就是先使用gen_txt函数
生成数据集的txt文档,在MyDataset类
中使用,生成DataLoader可以直接加载的数据集。
代码来源:Pytorch模型训练实用教程
- 标签为图像的情况
这个问题困扰了我很长时间,很感谢好心人,直接附上链接:Pytorch 构建自己的数据集 输入与标签皆为图片