【pytorch】训练集的读取

pytorch读取训练集是非常便捷的,只需要使用到2个类:

(1)torch.utils.data.Dataset

(2)torch.utils.data.DataLoader


常用数据集的读取

1、torchvision.datasets的使用

对于常用数据集,可以使用torchvision.datasets直接进行读取。torchvision.dataset是torch.utils.data.Dataset的实现

该包提供了以下数据集的读取

  • MNIST
  • COCO (Captioning and Detection)
  • LSUN Classification
  • ImageFolder
  • Imagenet-12
  • CIFAR10 and CIFAR100
  • STL10

下面以cifar10为例:

[python] view plain copy
print ?
  1. import torch  
  2. import torchvision  
  3. from PIL import Image  
  4.   
  5. cifarSet = torchvision.datasets.CIFAR10(root = "../data/cifar/", train= True, download = True)  
  6. print(cifarSet[0])  
  7. img, label = cifarSet[0]  
  8. print (img)  
  9. print (label)  
  10. print (img.format, img.size, img.mode)  
  11. img.show()  
import torch
import torchvision
from PIL import Image

cifarSet = torchvision.datasets.CIFAR10(root = "../data/cifar/", train= True, download = True)
print(cifarSet[0])
img, label = cifarSet[0]
print (img)
print (label)
print (img.format, img.size, img.mode)
img.show()

2、实例化torch.utils.data.DataLoader

[python] view plain copy
print ?
  1. mytransform = transforms.Compose([  
  2.     transforms.ToTensor()  
  3.     ]  
  4. )  
  5.   
  6. # torch.utils.data.DataLoader  
  7. cifarSet = torchvision.datasets.CIFAR10(root = "../data/cifar/", train= True, download = True, transform = mytransform )  
  8. cifarLoader = torch.utils.data.DataLoader(cifarSet, batch_size= 10, shuffle= False, num_workers= 2)  
mytransform = transforms.Compose([
    transforms.ToTensor()
    ]
)

# torch.utils.data.DataLoader
cifarSet = torchvision.datasets.CIFAR10(root = "../data/cifar/", train= True, download = True, transform = mytransform )
cifarLoader = torch.utils.data.DataLoader(cifarSet, batch_size= 10, shuffle= False, num_workers= 2)

下面就可以进行读取数据的显示,以进行简单测试是否读取成功:

[python] view plain copy
print ?
  1. for i, data in enumerate(cifarLoader, 0):  
  2.     print(data[i][0])  
  3.     # PIL  
  4.     img = transforms.ToPILImage()(data[i][0])  
  5.     img.show()  
  6.     break  
for i, data in enumerate(cifarLoader, 0):
    print(data[i][0])
    # PIL
    img = transforms.ToPILImage()(data[i][0])
    img.show()
    break


自定义标签数据集的读取

1、实现torch.utils.data.Dataset

假设我们有一个标签test_images.txt,内容如下:


对应的图像位于images目录下。

首先要继承torch.utils.data.Dataset类,完成图像及标签的读取。

[python] view plain copy
print ?
  1. import os  
  2. import torch  
  3. import torch.utils.data as data  
  4. from PIL import Image  
  5.   
  6. def default_loader(path):  
  7.     return Image.open(path).convert('RGB')  
  8.   
  9. class myImageFloder(data.Dataset):  
  10.     def __init__(self, root, label, transform = None, target_transform=None, loader=default_loader):  
  11.         fh = open(label)  
  12.         c=0  
  13.         imgs=[]  
  14.         class_names=[]  
  15.         for line in  fh.readlines():  
  16.             if c==0:  
  17.                 class_names=[n.strip() for n in line.rstrip().split('   ')]  
  18.             else:  
  19.                 cls = line.split()   
  20.                 fn = cls.pop(0)  
  21.                 if os.path.isfile(os.path.join(root, fn)):  
  22.                     imgs.append((fn, tuple([float(v) for v in cls])))  
  23.             c=c+1  
  24.         self.root = root  
  25.         self.imgs = imgs  
  26.         self.classes = class_names  
  27.         self.transform = transform  
  28.         self.target_transform = target_transform  
  29.         self.loader = loader  
  30.   
  31.     def __getitem__(self, index):  
  32.         fn, label = self.imgs[index]  
  33.         img = self.loader(os.path.join(self.root, fn))  
  34.         if self.transform is not None:  
  35.             img = self.transform(img)  
  36.         return img, torch.Tensor(label)  
  37.   
  38.     def __len__(self):  
  39.         return len(self.imgs)  
  40.       
  41.     def getName(self):  
  42.         return self.classes  
import os
import torch
import torch.utils.data as data
from PIL import Image

def default_loader(path):
    return Image.open(path).convert('RGB')

class myImageFloder(data.Dataset):
    def __init__(self, root, label, transform = None, target_transform=None, loader=default_loader):
        fh = open(label)
        c=0
        imgs=[]
        class_names=[]
        for line in  fh.readlines():
            if c==0:
                class_names=[n.strip() for n in line.rstrip().split('	')]
            else:
                cls = line.split() 
                fn = cls.pop(0)
                if os.path.isfile(os.path.join(root, fn)):
                    imgs.append((fn, tuple([float(v) for v in cls])))
            c=c+1
        self.root = root
        self.imgs = imgs
        self.classes = class_names
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

    def __getitem__(self, index):
        fn, label = self.imgs[index]
        img = self.loader(os.path.join(self.root, fn))
        if self.transform is not None:
            img = self.transform(img)
        return img, torch.Tensor(label)

    def __len__(self):
        return len(self.imgs)
    
    def getName(self):
        return self.classes

2、实例化torch.utils.data.DataLoader

[python] view plain copy
print ?
  1. mytransform = transforms.Compose([  
  2.     transforms.ToTensor()  
  3.     ]  
  4. )  
  5.   
  6. # torch.utils.data.DataLoader  
  7. imgLoader = torch.utils.data.DataLoader(  
  8.          myFloder.myImageFloder(root = "../data/testImages/images", label = "../data/testImages/test_images.txt", transform = mytransform ),   
  9.          batch_size= 2, shuffle= False, num_workers= 2)  
  10.   
  11. for i, data in enumerate(imgLoader, 0):  
  12.     print(data[i][0])  
  13.     # opencv  
  14.     img2 = data[i][0].numpy()*255  
  15.     img2 = img2.astype('uint8')  
  16.     img2 = np.transpose(img2, (1,2,0))  
  17.     img2=img2[:,:,::-1]#RGB->BGR  
  18.     cv2.imshow('img2', img2)  
  19.     cv2.waitKey()  
  20.     break  
mytransform = transforms.Compose([
    transforms.ToTensor()
    ]
)

# torch.utils.data.DataLoader
imgLoader = torch.utils.data.DataLoader(
         myFloder.myImageFloder(root = "../data/testImages/images", label = "../data/testImages/test_images.txt", transform = mytransform ), 
         batch_size= 2, shuffle= False, num_workers= 2)

for i, data in enumerate(imgLoader, 0):
    print(data[i][0])
    # opencv
    img2 = data[i][0].numpy()*255
    img2 = img2.astype('uint8')
    img2 = np.transpose(img2, (1,2,0))
    img2=img2[:,:,::-1]#RGB->BGR
    cv2.imshow('img2', img2)
    cv2.waitKey()
    break

相关代码可以查看:tfygg/pytorch-tutorials


---------------------------------------------------------------------------------------------------

在各方小伙伴的努力和支持下,pytorch中文文档 第一版终于上线啦!!!(鼓掌)文档还有很多小瑕疵,但是大体可以放心使用了~我们遵循快速迭代的原则,所以赶紧上线第一版来接受广大开源社区的意见和建议。欢迎加入我们

猜你喜欢

转载自blog.csdn.net/zhuiqiuk/article/details/80297587