PyTorch读取训练集非常便捷,只需要用到2个类:
(1) torch.utils.data.Dataset;(2) torch.utils.data.DataLoader。
常用数据集的读取
1、torchvision.datasets的使用
对于常用数据集,可以使用torchvision.datasets直接进行读取。torchvision.datasets中的数据集类都是torch.utils.data.Dataset的实现,该包提供了以下数据集的读取:
- MNIST
- COCO (Captioning and Detection)
- LSUN Classification
- ImageFolder
- Imagenet-12
- CIFAR10 and CIFAR100
- STL10
下面以CIFAR10为例:
import torch
import torchvision
from PIL import Image

# Download CIFAR10 into ../data/cifar/ if it is not there yet, then load the
# training split. No transform is given, so samples come back untransformed.
cifarSet = torchvision.datasets.CIFAR10(root="../data/cifar/", train=True, download=True)
print(cifarSet[0])

# Each sample unpacks into an (image, label) pair.
img, label = cifarSet[0]
print(img)
print(label)
# The image is queried and displayed through the PIL image API.
print(img.format, img.size, img.mode)
img.show()
2、实例化torch.utils.data.DataLoader
import torch
import torchvision
from torchvision import transforms  # `transforms` was used below but never imported

# Convert each PIL image to a tensor so the DataLoader can batch the samples.
mytransform = transforms.Compose([
    transforms.ToTensor(),
])

# torch.utils.data.DataLoader
cifarSet = torchvision.datasets.CIFAR10(
    root="../data/cifar/", train=True, download=True, transform=mytransform
)
# NOTE: num_workers > 0 starts worker processes when the loader is iterated.
# On platforms that spawn processes (Windows, macOS), iterate it only under an
# `if __name__ == "__main__":` guard.
cifarLoader = torch.utils.data.DataLoader(cifarSet, batch_size=10, shuffle=False, num_workers=2)
下面取出第一个批次并显示其中的第一张图像,以简单测试数据是否读取成功:
from torchvision import transforms

if __name__ == "__main__":
    # Guard the iteration: with num_workers > 0 the worker processes may
    # re-import this script, and they must not run this code again.
    # Take only the first batch and unpack it into (images, labels) instead of
    # indexing the batch with a loop counter.
    images, labels = next(iter(cifarLoader))
    print(images[0])
    # Convert the first image tensor back to a PIL image for display.
    img = transforms.ToPILImage()(images[0])
    img.show()
自定义标签数据集的读取
1、实现torch.utils.data.Dataset
假设我们有一个标签文件test_images.txt:第一行是以空格分隔的类别名,其后每一行是图像文件名及其对应的各个标签值(同样以空格分隔)。例如(示意):
cat dog
0001.jpg 1 0
0002.jpg 0 1
对应的图像位于images目录下。
首先要继承torch.utils.data.Dataset类,并实现__getitem__和__len__方法,完成图像及标签的读取。
import os
import torch
import torch.utils.data as data


def default_loader(path):
    """Open the image at *path* with Pillow and convert it to RGB."""
    # Imported here so the dataset class can be used with a custom `loader`
    # even where Pillow is not installed.
    from PIL import Image
    return Image.open(path).convert('RGB')


class myImageFloder(data.Dataset):
    """Dataset of images listed in a text label file.

    Label file format, as parsed below:
      * the first line holds the class names, separated by whitespace;
      * every following non-blank line holds an image filename (relative to
        ``root``) followed by whitespace-separated numeric label values.
    Entries whose image file does not exist under ``root`` are skipped.

    Args:
        root: directory containing the images.
        label: path to the label text file.
        transform: optional callable applied to each loaded image.
        target_transform: optional callable applied to each label tensor.
        loader: callable that maps an image path to an image object.
    """

    def __init__(self, root, label, transform=None, target_transform=None, loader=default_loader):
        imgs = []
        class_names = []
        # `with` guarantees the label file is closed; it used to leak.
        with open(label) as fh:
            for line_no, line in enumerate(fh):
                if line_no == 0:
                    # split() ignores repeated spaces; split(' ') produced empty names.
                    class_names = line.split()
                    continue
                fields = line.split()
                if not fields:
                    # Skip blank lines (e.g. a trailing empty line) instead of
                    # crashing on an empty field list.
                    continue
                fn, *values = fields
                if os.path.isfile(os.path.join(root, fn)):
                    imgs.append((fn, tuple(float(v) for v in values)))
        self.root = root
        self.imgs = imgs
        self.classes = class_names
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

    def __getitem__(self, index):
        """Return ``(image, target)`` for sample *index*.

        The target is a ``torch.Tensor`` built from the sample's label values.
        ``transform`` and ``target_transform`` are applied when given.
        """
        fn, label = self.imgs[index]
        img = self.loader(os.path.join(self.root, fn))
        if self.transform is not None:
            img = self.transform(img)
        target = torch.Tensor(label)
        # target_transform was stored before but never applied.
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target

    def __len__(self):
        """Return the number of samples whose image file was found."""
        return len(self.imgs)

    def getName(self):
        """Return the list of class names read from the label file header."""
        return self.classes
2、实例化torch.utils.data.DataLoader
import cv2
import numpy as np
import torch
from torchvision import transforms

import myFloder  # presumably the module that defines myImageFloder — TODO confirm the file name

# Convert each PIL image to a tensor so the DataLoader can batch the samples.
mytransform = transforms.Compose([
    transforms.ToTensor(),
])

if __name__ == "__main__":
    # Guarded: with num_workers > 0, spawned worker processes re-import this
    # script and must not run this code again.
    # torch.utils.data.DataLoader
    imgLoader = torch.utils.data.DataLoader(
        myFloder.myImageFloder(
            root="../data/testImages/images",
            label="../data/testImages/test_images.txt",
            transform=mytransform,
        ),
        batch_size=2, shuffle=False, num_workers=2)

    # Take only the first batch and unpack it into (images, labels) instead of
    # indexing the batch with a loop counter.
    images, labels = next(iter(imgLoader))
    print(images[0])

    # opencv: CHW float image -> HWC uint8, RGB -> BGR.
    # Round before casting: astype('uint8') truncates, so e.g. 254.999 -> 254.
    img2 = np.rint(images[0].numpy() * 255).astype('uint8')
    img2 = np.transpose(img2, (1, 2, 0))
    # Channel flip gives a negative-stride view; hand OpenCV a contiguous copy.
    img2 = np.ascontiguousarray(img2[:, :, ::-1])  # RGB->BGR
    cv2.imshow('img2', img2)
    cv2.waitKey()
    cv2.destroyAllWindows()
相关代码可以查看:tfygg/pytorch-tutorials
---------------------------------------------------------------------------------------------------
在各方小伙伴的努力和支持下,pytorch中文文档 第一版终于上线啦!!!(鼓掌)文档还有很多小瑕疵,但是大体可以放心使用了~我们遵循快速迭代的原则,所以赶紧上线第一版来接受广大开源社区的意见和建议。欢迎加入我们