Caltech-256 数据集处理(三) 训练集和验证集载入PyTorch Dateloader

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/tfcy694/article/details/84260766

Caltech-256数据集在PyTorch中的处理:
Caltech-256 数据集处理(一) label提取
Caltech-256 数据集处理(二) 训练集和测试集的制作
Caltech-256 数据集处理(三) 训练集和验证集载入Dateloader

  1. Caltech-256中的每张图片的大小都不一定,所以在这里需要进行crop操作。
  2. 这里偷懒了,mean和std去了imagenet的数据,严格来讲需要单独计算。
  3. rstrip()strip()可以根据具体场景灵活使用,这里保险起见多用了。
import torch
from torch.autograd import Variable
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image

root='/media/this/02ff0572-4aa8-47c6-975d-16c3b8062013/'

def default_loader(path):
    return Image.open(path).convert('RGB')

class MyDataset(Dataset):
    def __init__(self, txt, transform=None, target_transform=None, loader=default_loader):
        fh = open(txt, 'r')
        imgs = []
        for line in fh:
            line = line.rstrip()
            line = line.strip('\n')
            line = line.rstrip()
            words = line.split()
            imgs.append((words[0],int(words[1])))
        self.imgs = imgs
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

    def __getitem__(self, index):
        fn, label = self.imgs[index]
        img = self.loader(fn)
        if self.transform is not None:
            img = self.transform(img)
        return img,label

    def __len__(self):
        return len(self.imgs)

mean = [ 0.485, 0.456, 0.406 ]
std = [ 0.229, 0.224, 0.225 ]

transform = transforms.Compose([
    transforms.Scale(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean = mean, std = std),
    ])

train_data = MyDataset(txt=root+'dataset-trn.txt', transform=transform)
test_data = MyDataset(txt=root+'dataset-val.txt', transform=transform)
train_loader = DataLoader(dataset=train_data, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_data, batch_size=64)
'''
for idx, (data, target) in enumerate(test_loader):
    if(idx%10==0):
        print(str(idx)+' '+str(target))

for idx, (data, target) in enumerate(train_loader):
    if(idx%10==0):
        print(str(idx)+' '+str(target))
'''

猜你喜欢

转载自blog.csdn.net/tfcy694/article/details/84260766