Getting Started with PyTorch: General Methods for Data Loading and Preprocessing

Reposted from: CSDN
Original: https://blog.csdn.net/Hungryof/article/details/76649006

What torchvision is mainly used for.

There are two common dataset layouts:

  1. All images are in a single folder. (A custom subclass of torch.utils.data.Dataset is all you need.)
  2. Images of different classes are in different folders. (Use torchvision.datasets.ImageFolder('image_dir_root').)

Most tasks probably use the first layout. The second is typical of classification tasks: ImageNet, for example, has 1000 classes stored in 1000 folders.

The directory structure of the second layout looks like this:

root/ants/xxx.png
root/ants/xxy.jpeg
root/ants/xxz.png
...
root/bees/123.jpg
root/bees/nsdf3.png
root/bees/asd932_.png
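ImageFolder derives the class labels from the sub-folder names. The sketch below is a simplified, torch-free re-implementation of that idea, assuming the behavior of current torchvision releases: folder names are sorted and numbered from 0, and file extensions are matched case-insensitively. `find_classes` and `make_samples` are hypothetical helpers, and the sketch accepts only three extensions, whereas the real class accepts more and has extra options.

```python
import os
import tempfile

IMG_EXTENSIONS = (".png", ".jpg", ".jpeg")


def find_classes(root):
    """Return sorted class (folder) names and a name -> index mapping."""
    classes = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
    class_to_idx = {name: i for i, name in enumerate(classes)}
    return classes, class_to_idx


def make_samples(root, class_to_idx):
    """List (path, class_index) pairs for every image file under root/<class>/."""
    samples = []
    for name in sorted(class_to_idx):
        class_dir = os.path.join(root, name)
        for fname in sorted(os.listdir(class_dir)):
            if fname.lower().endswith(IMG_EXTENSIONS):  # skips e.g. notes.txt
                samples.append((os.path.join(class_dir, fname), class_to_idx[name]))
    return samples


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as root:
        for rel in ["ants/xxx.png", "ants/xxy.jpeg", "bees/123.jpg", "bees/notes.txt"]:
            path = os.path.join(root, rel)
            os.makedirs(os.path.dirname(path), exist_ok=True)
            open(path, "w").close()  # empty placeholder file
        classes, class_to_idx = find_classes(root)
        print(class_to_idx)  # {'ants': 0, 'bees': 1}
        for path, label in make_samples(root, class_to_idx):
            print(os.path.relpath(path, root), label)
```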

Note: the torchvision package has three main uses:

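  1. It provides popular model architectures, along with ready-made loaders for common datasets.
  2. It extends torch.utils.data.Dataset, mainly to read data stored with one folder per class: torchvision.datasets.ImageFolder is a subclass of torch.utils.data.Dataset. Like any map-style Dataset, it is indexed with dataset[i] and has a length; it is not an iterator. The DataLoader is what you iterate over.
  3. It provides ready-made transforms in torchvision.transforms, so you don't have to write them yourself.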
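transforms.Compose chains callables, so each transform's output becomes the next one's input. Here is a minimal torch-free sketch of that behavior. The `Compose` class below is a re-implementation for illustration, not torchvision's code, and the toy transforms operate on a list of numbers instead of an image.

```python
class Compose:
    """Apply a sequence of callables in order, like torchvision.transforms.Compose."""

    def __init__(self, transforms):
        self.transforms = list(transforms)

    def __call__(self, x):
        for t in self.transforms:
            x = t(x)  # feed each transform the previous transform's output
        return x


def scale_to_unit(pixels):
    # Map 0..255 integers to 0.0..1.0 floats (ToTensor rescales uint8 images the same way).
    return [p / 255 for p in pixels]


def clip_high(pixels, limit=0.5):
    # Toy second step: cap every value at `limit`.
    return [min(p, limit) for p in pixels]


pipeline = Compose([scale_to_unit, clip_high])
print(pipeline([0, 51, 255]))  # [0.0, 0.2, 0.5]
```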

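Two ways to read data
Both usually involve these steps:

  1. torch.utils.data.Dataset (the low-level base class), used through a custom subclass or through torchvision.datasets.ImageFolder, which also subclasses it.
  2. Apply torchvision.transforms to the images read in step 1.
  3. Wrap the dataset from step 2 in torch.utils.data.DataLoader, which batches the samples and can load them in parallel with multiple worker processes.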
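In step 3, DataLoader draws several samples from the dataset and merges them into one batch. For (input, target) tuples, the default collate function roughly groups all inputs together and all targets together, stacking them into tensors. The stdlib-only sketch below shows only that regrouping and uses lists instead of tensors. `collate` is a hypothetical name: the real `default_collate` also handles tensors, numbers, dicts, and more.

```python
def collate(samples):
    """Turn [(input0, target0), (input1, target1), ...] into ([inputs...], [targets...]).

    Expects a non-empty list of 2-tuples.
    """
    inputs, targets = zip(*samples)  # transpose the list of pairs
    return list(inputs), list(targets)


batch = collate([([0.1, 0.2], 3), ([0.5, 0.6], 7)])
print(batch)  # ([[0.1, 0.2], [0.5, 0.6]], [3, 7])
```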

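Reading pipeline

  1. Write a custom dataset class. This is the lowest layer. Subclass torch.utils.data.Dataset and override at least __getitem__ and __len__, plus __init__ to set things up.
    The class reads images from storage, but those images usually need further processing such as resizing. So pass the transform into __init__. In __getitem__, first load the image, then call img_transformed = self.transform(img), where self.transform is the argument that was passed to __init__.

  2. Pass a torchvision.transforms.Compose object as that argument to the custom dataset class.

  3. Wrap the dataset from step 2 in torch.utils.data.DataLoader to read it with multiple worker processes.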
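Steps 1 and 2 together amount to a dataset class that receives the transform in `__init__` and applies it in `__getitem__`. The sketch below is torch-free so that it runs anywhere. `InMemoryDataset` is a hypothetical name, and it holds plain Python values instead of loading files. With PyTorch, you would subclass torch.utils.data.Dataset and implement the same three methods.

```python
class InMemoryDataset:
    """Map-style dataset: supports len(ds) and ds[i], like torch.utils.data.Dataset."""

    def __init__(self, items, transform=None):
        self.items = list(items)     # stands in for a list of file paths
        self.transform = transform   # any callable, e.g. a Compose object

    def __getitem__(self, index):
        item = self.items[index]     # real code would load the image here
        if self.transform is not None:
            item = self.transform(item)
        return item

    def __len__(self):
        return len(self.items)


ds = InMemoryDataset([1, 2, 3], transform=lambda x: x * 10)
print(len(ds), ds[0], ds[2])  # 3 10 30
```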

Using torch.utils.data.Dataset when all images are in one folder

Take the official super_resolution example. First, in main.py:

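from torch.utils.data import DataLoader

from data import get_training_set, get_test_set

# opt holds the parsed command-line options (upscale_factor, threads, batchSize, testBatchSize)
train_set = get_training_set(opt.upscale_factor)
test_set = get_test_set(opt.upscale_factor)
training_data_loader = DataLoader(dataset=train_set, num_workers=opt.threads, batch_size=opt.batchSize, shuffle=True)
testing_data_loader = DataLoader(dataset=test_set, num_workers=opt.threads, batch_size=opt.testBatchSize, shuffle=False)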
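With shuffle=True, the DataLoader draws a fresh random permutation of the dataset indices every epoch and cuts it into chunks of batch_size. With the default drop_last=False, the final batch can be smaller. Below is a stdlib sketch of that index logic. `batch_indices` is a hypothetical helper: PyTorch implements this with RandomSampler or SequentialSampler plus BatchSampler.

```python
import random


def batch_indices(n, batch_size, shuffle=False, drop_last=False, rng=None):
    """Yield lists of dataset indices, one list per batch."""
    indices = list(range(n))
    if shuffle:
        (rng or random).shuffle(indices)  # a new order on every call, i.e. every epoch
    for start in range(0, n, batch_size):
        batch = indices[start:start + batch_size]
        if drop_last and len(batch) < batch_size:
            break  # discard the incomplete last batch
        yield batch


print(list(batch_indices(10, 4)))                  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
print(list(batch_indices(10, 4, drop_last=True)))  # [[0, 1, 2, 3], [4, 5, 6, 7]]
print(list(batch_indices(10, 4, shuffle=True, rng=random.Random(0))))
```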

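Next, follow get_training_set into data.py. This script downloads and extracts the dataset. It also defines the transforms and the functions that build the training and test sets: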

from os.path import exists, join, basename
from os import makedirs, remove
import urllib.request  # the original used six.moves.urllib for Python 2 compatibility
import tarfile
# Scale is deprecated (and removed in newer torchvision releases); Resize replaces it
from torchvision.transforms import Compose, CenterCrop, ToTensor, Resize

from dataset import DatasetFromFolder


def download_bsd300(dest="dataset"):
    output_image_dir = join(dest, "BSDS300/images")

    if not exists(output_image_dir):
        makedirs(dest, exist_ok=True)  # don't fail if dest exists but the images are missing
        url = "http://www2.eecs.berkeley.edu/Research/Projects/CS/vision/bsds/BSDS300-images.tgz"
        print("downloading url ", url)

        data = urllib.request.urlopen(url)

        file_path = join(dest, basename(url))
        with open(file_path, 'wb') as f:
            f.write(data.read())

        print("Extracting data")
        with tarfile.open(file_path) as tar:
            for item in tar:
                tar.extract(item, dest)

        remove(file_path)

    return output_image_dir


def calculate_valid_crop_size(crop_size, upscale_factor):
    # Round crop_size down to a multiple of upscale_factor
    return crop_size - (crop_size % upscale_factor)


def input_transform(crop_size, upscale_factor):
    # Low-resolution network input: crop, then shrink by upscale_factor
    return Compose([
        CenterCrop(crop_size),
        Resize(crop_size // upscale_factor),
        ToTensor(),
    ])


def target_transform(crop_size):
    # High-resolution target: the same crop at full size
    return Compose([
        CenterCrop(crop_size),
        ToTensor(),
    ])
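calculate_valid_crop_size rounds the crop down to the nearest multiple of upscale_factor. That way, the low-resolution input (crop_size // upscale_factor pixels) scaled back up by upscale_factor matches the target size exactly. Here are the worked numbers for the 256-pixel crop used in get_training_set. The function is copied from data.py.

```python
def calculate_valid_crop_size(crop_size, upscale_factor):
    return crop_size - (crop_size % upscale_factor)


for factor in (2, 3, 4):
    crop = calculate_valid_crop_size(256, factor)  # target side length
    low_res = crop // factor                       # input side length
    print(factor, crop, low_res, low_res * factor == crop)
# 2 256 128 True
# 3 255 85 True
# 4 256 64 True
```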

This is where the custom dataset class gets called:

def get_training_set(upscale_factor):
    root_dir = download_bsd300()
    train_dir = join(root_dir, "train")
    crop_size = calculate_valid_crop_size(256, upscale_factor)

    # Build the custom dataset, passing in the transforms. Note that input_transform(...)
    # is called here, so the argument is the Compose object it returns,
    # not the function itself.
    return DatasetFromFolder(train_dir,
                             input_transform=input_transform(crop_size, upscale_factor),
                             target_transform=target_transform(crop_size))

def get_test_set(upscale_factor):
    root_dir = download_bsd300()
    test_dir = join(root_dir, "test")
    crop_size = calculate_valid_crop_size(256, upscale_factor)

    return DatasetFromFolder(test_dir,
                             input_transform=input_transform(crop_size, upscale_factor),
                             target_transform=target_transform(crop_size))

Next, open dataset.py, where the custom dataset class is defined.

import torch.utils.data as data

from os import listdir
from os.path import join
from PIL import Image


def is_image_file(filename):
    return any(filename.endswith(extension) for extension in [".png", ".jpg", ".jpeg"])


def load_img(filepath):
    img = Image.open(filepath).convert('YCbCr')
    y, _, _ = img.split()
    return y


class DatasetFromFolder(data.Dataset):

    def __init__(self, image_dir, input_transform=None, target_transform=None):
        super(DatasetFromFolder, self).__init__()
        self.image_filenames = [join(image_dir, x) for x in listdir(image_dir) if is_image_file(x)]

        self.input_transform = input_transform
        self.target_transform = target_transform

    # __getitem__ loads the image and applies the transforms passed in, e.g.
    # input = self.input_transform(input). self.input_transform is an instance of a
    # transform class (here a Compose object). Because that class defines __call__,
    # the instance can be called like a function: instance(argument).
    # (The previous post covered this.)
    def __getitem__(self, index):
        input = load_img(self.image_filenames[index])
        target = input.copy()  # the target starts out as the same image
        if self.input_transform:
            input = self.input_transform(input)
        if self.target_transform:
            target = self.target_transform(target)

        return input, target

    def __len__(self):
        return len(self.image_filenames)
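`self.input_transform(input)` works because transform objects are instances of classes that define `__call__`. The minimal sketch below uses `HalveValues`, a hypothetical transform made up for illustration, and a list of numbers stands in for an image.

```python
class HalveValues:
    """A transform object: configure it once in __init__, apply it by calling it."""

    def __init__(self, factor=0.5):
        self.factor = factor

    def __call__(self, pixels):
        # transform(pixels) runs this method, because the class defines __call__.
        return [p * self.factor for p in pixels]


transform = HalveValues()
print(callable(transform))   # True
print(transform([2, 4, 6]))  # [1.0, 2.0, 3.0]
```

The design mirrors torchvision: parameters such as a crop size are fixed when the object is constructed, so the dataset only needs to call one argument-free-looking object per sample.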

A look inside torchvision.datasets.MNIST (code from an older torchvision release):

import os

import torch
import torch.utils.data as data
from PIL import Image


class MNIST(data.Dataset):
    """`MNIST <http://yann.lecun.com/exdb/mnist/>`_ Dataset.
    Args:
        root (string): Root directory of dataset where ``processed/training.pt``
            and  ``processed/test.pt`` exist.
        train (bool, optional): If True, creates dataset from ``training.pt``,
            otherwise from ``test.pt``.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that  takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """
    urls = [
        'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz',
        'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz',
        'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz',
        'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz',
    ]
    raw_folder = 'raw'
    processed_folder = 'processed'
    training_file = 'training.pt'
    test_file = 'test.pt'

    def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.target_transform = target_transform
        self.train = train  # training set or test set

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError('Dataset not found.' +
                               ' You can use download=True to download it')

        if self.train:
            self.train_data, self.train_labels = torch.load(
                os.path.join(self.root, self.processed_folder, self.training_file))
        else:
            self.test_data, self.test_labels = torch.load(
                os.path.join(self.root, self.processed_folder, self.test_file))

    # As you can see, the transform is applied the same way here: img = self.transform(img).
    # (download() and _check_exists() are omitted from this excerpt.)
    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        if self.train:
            img, target = self.train_data[index], self.train_labels[index]
        else:
            img, target = self.test_data[index], self.test_labels[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img.numpy(), mode='L')

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        if self.train:
            return len(self.train_data)
        else:
            return len(self.test_data)
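MNIST applies `transform` to the image and `target_transform` to the label, independently of each other. Below is a stdlib sketch of that split. `TinyLabeledDataset` and the `one_hot` target transform are hypothetical examples, and lists stand in for image tensors. Like this older MNIST class, the sketch holds all data in memory once `__init__` finishes.

```python
def one_hot(label, num_classes=10):
    """Example target_transform: class index -> one-hot list."""
    vec = [0] * num_classes
    vec[label] = 1
    return vec


class TinyLabeledDataset:
    def __init__(self, images, labels, transform=None, target_transform=None):
        # Everything is loaded up front, as the old MNIST class does with torch.load.
        self.images, self.labels = list(images), list(labels)
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        img, target = self.images[index], self.labels[index]
        if self.transform is not None:
            img = self.transform(img)                 # acts on the image only
        if self.target_transform is not None:
            target = self.target_transform(target)    # acts on the label only
        return img, target

    def __len__(self):
        return len(self.images)


ds = TinyLabeledDataset([[0, 255], [255, 0]], [3, 7],
                        transform=lambda px: [p / 255 for p in px],
                        target_transform=lambda y: one_hot(y, num_classes=8))
print(ds[0])  # ([0.0, 1.0], [0, 0, 0, 1, 0, 0, 0, 0])
```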

Using torchvision.datasets.ImageFolder when each class has its own folder
Take the ImageNet example code:

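import torch
import torch.utils.data
from torchvision import datasets, transforms

# Per-channel RGB mean and std of the ImageNet training images
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])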
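transforms.Normalize computes output = (input - mean[c]) / std[c] for each channel c. It runs on a tensor that ToTensor has already scaled to [0, 1]. Below is a worked sketch for single RGB pixels using the ImageNet statistics above. It is plain Python, and `normalize_pixel` is a hypothetical helper.

```python
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]


def normalize_pixel(rgb, mean=MEAN, std=STD):
    """Apply (x - mean) / std channel by channel to an (R, G, B) triple in [0, 1]."""
    return [(x - m) / s for x, m, s in zip(rgb, mean, std)]


white = normalize_pixel([1.0, 1.0, 1.0])
black = normalize_pixel([0.0, 0.0, 0.0])
print([round(v, 3) for v in white])  # [2.249, 2.429, 2.64]
print([round(v, 3) for v in black])  # [-2.118, -2.036, -1.804]
```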

Steps 1 and 2

Read the dataset with ImageFolder, passing in the transforms:

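# traindir: path to the ImageNet train folder (one subfolder per class), set earlier in the script
train_dataset = datasets.ImageFolder(
    traindir,
    transforms.Compose([
        transforms.RandomResizedCrop(224),  # called RandomSizedCrop in old torchvision releases
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ]))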
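RandomHorizontalFlip mirrors each image left-to-right with probability p, which defaults to 0.5. Because the flip is decided anew each time a sample is loaded, every epoch sees differently augmented copies. Below is a torch-free sketch on a nested-list "image" (rows of pixels). `random_hflip` is a hypothetical stand-in, not torchvision's implementation.

```python
import random


def random_hflip(image, p=0.5, rng=random):
    """Return the image mirrored left-right with probability p, else unchanged."""
    if rng.random() < p:  # random() is in [0, 1), so p=1.0 always flips
        return [list(reversed(row)) for row in image]
    return image


img = [[1, 2, 3],
       [4, 5, 6]]
print(random_hflip(img, p=1.0))  # [[3, 2, 1], [6, 5, 4]]
print(random_hflip(img, p=0.0))  # [[1, 2, 3], [4, 5, 6]]
```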

Step 3

Load the data with DataLoader, using multiple worker processes:

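# In the full ImageNet script, train_sampler is a DistributedSampler for distributed
# training and None otherwise. DataLoader rejects shuffle=True combined with a sampler,
# hence shuffle=(train_sampler is None). args comes from the script's argument parser.
train_sampler = None
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
    num_workers=args.workers, pin_memory=True, sampler=train_sampler)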

val_loader = torch.utils.data.DataLoader(
    datasets.ImageFolder(valdir, transforms.Compose([
        transforms.Resize(256),  # called Scale in old torchvision releases
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])),
    batch_size=args.batch_size, shuffle=False,
    num_workers=args.workers, pin_memory=True)
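For validation, Resize(256) with a single int scales the shorter side to 256 and keeps the aspect ratio. CenterCrop(224) then cuts out the middle 224×224 region. The sketch below works through the size arithmetic. The helper names are hypothetical, and the formulas are assumed to match torchvision's functional resize and center_crop (long side truncated with int(), crop offsets rounded with round()); check them against your installed version.

```python
def resize_shorter_side(width, height, size):
    """Output (width, height) when the shorter side is scaled to `size`."""
    if width <= height:
        return size, int(size * height / width)
    return int(size * width / height), size


def center_crop_box(width, height, crop):
    """(left, top, right, bottom) of a centered crop x crop box."""
    left = int(round((width - crop) / 2.0))
    top = int(round((height - crop) / 2.0))
    return left, top, left + crop, top + crop


w, h = resize_shorter_side(500, 375, 256)
print((w, h))                      # (341, 256)
print(center_crop_box(w, h, 224))  # (58, 16, 282, 240)
```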


Reposted from blog.csdn.net/qq_18649781/article/details/89281981