pytorch数据预处理

一,数据加载

数据路径:

#coding:utf-8
import torch as t
from torch.utils import data
import os
from PIL import Image
import numpy as np

class DogCat(data.Dataset):
    def __init__(self, path):
        imgs = os.listdir(path)
        # 所有图片的绝对路径
        # 这里不实际加载图片,只是指定路径,当调用__getitem__时才会真正读图片
        self.imgs_list_path = [os.path.join(path, i) for i in imgs]

    def __getitem__(self, index):
        img_path = self.imgs_list_path[index]
        # dog->1, cat->0
        label = 1 if 'dog' in img_path.split('/')[-1] else 0
        pil_img = Image.open(img_path)
        array = np.asarray(pil_img)
        img = t.from_numpy(array)
        return img_path,img, label

    def __len__(self):
        return len(self.imgs_list_path)
if __name__ == '__main__':
    dataset = DogCat('./data/dogcat/')
    # img, label = dataset[0]  # 相当于调用dataset.__getitem__(0)
    print('len(dataset)=',len(dataset))
    for img_path,img, label in dataset:
        print(img_path,img.size(), img.float().mean(), label)

打印结果:

二,数据归一化 

PyTorch提供了torchvision1。它是一个视觉工具包,提供了很多视觉图像处理的工具,其中transforms模块提供了对PIL Image对象和Tensor对象的常用操作。

对PIL Image的操作包括:

  • Scale:调整图片尺寸,长宽比保持不变
  • CenterCropRandomCropRandomResizedCrop: 裁剪图片
  • Pad:填充
  • ToTensor:将PIL Image对象转成Tensor,会自动将[0, 255]归一化至[0, 1]

对Tensor的操作包括:

  • Normalize:标准化,即减均值,除以标准差
  • ToPILImage:将Tensor转为PIL Image对象
#coding:utf-8
import torch as t
from torch.utils import data
import os
from PIL import Image
import numpy as np
from torchvision import transforms

transform = transforms.Compose([
    transforms.Resize(224), # 缩放图片(Image),保持长宽比不变,最短边为224像素
    transforms.CenterCrop(224), # 从图片中间切出224*224的图片
    transforms.ToTensor(), # 将图片(Image)转成Tensor,归一化至[0, 1]
    transforms.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5]) # 标准化至[-1, 1],规定均值和标准差
    #input[channel] = (input[channel] - mean[channel]) / std[channel]
])

class DogCat(data.Dataset):
    def __init__(self, path,transforms=None):
        imgs = os.listdir(path)
        # 所有图片的绝对路径
        # 这里不实际加载图片,只是指定路径,当调用__getitem__时才会真正读图片
        self.imgs_list_path = [os.path.join(path, i) for i in imgs]
        self.transforms=transforms

    def __getitem__(self, index):
        img_path = self.imgs_list_path[index]
        # dog->1, cat->0
        label = 1 if 'dog' in img_path.split('/')[-1] else 0
        pil_img = Image.open(img_path)
        if self.transforms:
            pil_img=self.transforms(pil_img)
        array = np.asarray(pil_img)
        img = t.from_numpy(array)
        return img_path,img, label

    def __len__(self):
        return len(self.imgs_list_path)
if __name__ == '__main__':
    dataset = DogCat('./data/dogcat/',transforms=transform)
    # img, label = dataset[0]  # 相当于调用dataset.__getitem__(0)
    print('len(dataset)=',len(dataset))
    for img_path,img, label in dataset:
        print(img_path,img.size(), img.float().mean(), label)

猜你喜欢

转载自blog.csdn.net/fanzonghao/article/details/86543538