用于 DataLoader 的 PyTorch 数据集

暂时介绍 image-mask型数据集, 以人手分割数据集 EGTEA Gaze+ 为例.

 

准备数据文件夹

需要将 Image 和 Mask 分开存放, 对应文件的文件名 (不含扩展名) 必须保持一致. 提醒: Mask 图像一般为 png 单通道

EGTEA Gaze+ 数据集下载解压后即得到如下的目录, 无需处理

hand14k

┣━ Images

┃ ┣━ OP01-R01-PastaSalad_000014.jpg

┃ ┣━ OP01-R01-PastaSalad_000015.jpg

┃ ┣━ OP01-R01-PastaSalad_000016.jpg

┃ ┗━ ···

┗━ Masks

  ┣━ OP01-R01-PastaSalad_000014.png

  ┣━ OP01-R01-PastaSalad_000015.png

  ┣━ OP01-R01-PastaSalad_000016.png

  ┗━ ···

生成路径文件, 划分数据集

脚本如下:

import cv2 as cv

import numpy as np

import PIL.Image as Image

import os

 

# Seed NumPy's global RNG so the random train/test assignment made by
# split_dataset() is repeatable from run to run.
np.random.seed(42)

 

 

def split_dataset():
    """Pair images with masks and randomly split the pairs into train/test lists.

    Reads ``./Images/`` and ``./Masks/`` (relative to the current working
    directory), pairs files whose names are equal once the extension is
    removed, and writes one tab-separated ``image_path mask_path`` line per
    pair to either ``./train.data`` or ``./test.data``. Each pair goes to the
    training list with probability 0.8, drawn from NumPy's global RNG.

    Returns without touching anything if both list files already exist.
    """
    train_path = "./train.data"
    test_path = "./test.data"
    # Never overwrite an existing split, so reruns keep the same partition.
    if os.path.isfile(train_path) and os.path.isfile(test_path):
        return

    images_path = "./Images/"
    labels_path = "./Masks/"

    # os.listdir() returns entries in arbitrary order; sort them so the split
    # is reproducible for a fixed seed.
    images_list = sorted(os.listdir(images_path))

    # Index the masks by extension-less name. Pairing by name (instead of by
    # position in two sorted lists) means a missing or extra file on either
    # side cannot shift the pairing of every file that follows it.
    labels_by_stem = {
        os.path.splitext(name)[0]: name
        for name in sorted(os.listdir(labels_path))
    }

    split_ratio = 0.8
    # Binary mode with explicit UTF-8 and "\n" keeps the list files
    # byte-identical regardless of the operating system that wrote them.
    with open(train_path, "wb") as train_file, open(test_path, "wb") as test_file:
        # Randomly assign every image/mask pair to the train or the test list.
        for image in images_list:
            label = labels_by_stem.get(os.path.splitext(image)[0])
            if label is None:
                # Image without a matching mask: skip it.
                continue
            image = os.path.join(images_path, image)
            label = os.path.join(labels_path, label)
            target = train_file if np.random.rand() < split_ratio else test_file
            target.write((image + "\t" + label + "\n").encode("utf-8"))
    print("成功划分数据集!")

 

 

def read_image(path):
    """Load the image at ``path`` into a NumPy array.

    A two-dimensional (single-channel) array is expanded to three identical
    channels with ``cv.merge``; an array of any other dimensionality is
    returned exactly as decoded.
    """
    pixels = np.array(Image.open(path))
    if pixels.ndim != 2:
        return pixels
    # Replicate the single plane three times along a new last axis.
    return cv.merge([pixels] * 3)

 

 

def test_read():
    """Show one random image/mask pair from ``./test.data`` as a visual check.

    Opens two OpenCV windows ("image" and "mask") and blocks until a key is
    pressed, then closes them.

    Raises:
        ValueError: If ``./test.data`` contains no entries.
    """
    list_file = "./test.data"
    with open(list_file, 'rb') as f:
        # Skip blank lines; every remaining line must be "image<TAB>mask".
        rows = [line.decode('utf-8').strip('\n').split('\t') for line in f if line.strip()]
    datalist = [(k, v) for k, v in rows]
    if not datalist:
        raise ValueError("no image/mask pairs found in " + list_file)

    # Draw the index from the real list length so lists of any size work and
    # every entry can be sampled.
    item = datalist[np.random.randint(len(datalist))]
    image = read_image(item[0])
    mask = read_image(item[1])

    # PIL decodes colour images as RGB, whereas cv.imshow interprets a
    # 3-channel array as BGR; convert so colours are displayed correctly.
    if image.ndim == 3 and image.shape[2] == 3:
        image = cv.cvtColor(image, cv.COLOR_RGB2BGR)

    cv.imshow("image", image)
    cv.imshow("mask", mask)
    cv.waitKey(0)
    cv.destroyAllWindows()

 

 

# When executed as a script: create the train/test list files, then display
# one randomly chosen pair from the test list.
if __name__ == '__main__':

    split_dataset()

    test_read()

派生 Dataset

class MyDataset(Dataset):
    """Image/mask dataset driven by a tab-separated list file.

    Every non-blank line of the list file holds an image path and a mask path
    separated by a single tab; both are joined onto ``data_dir`` when a sample
    is loaded. ``__getitem__`` returns a ``{'image': ..., 'mask': ...}`` dict,
    passed through the transform that matches the current stage (if any).
    """

    def __init__(
        self, data_file, data_dir, transform_trn=None, transform_val=None
        ):
        """
        Args:
            data_file (string): Path to the data file with annotations.
            data_dir (string): Directory with all the images.
            transform_{trn, val} (callable, optional): Optional transform to be applied
                on a sample.

        Raises:
            ValueError: If a non-blank line of ``data_file`` does not contain
                exactly two tab-separated fields.
        """
        self.datalist = []
        with open(data_file, 'rb') as f:
            for line_no, raw in enumerate(f, start=1):
                # Strip CR as well as LF so that a list saved with Windows
                # line endings does not leave '\r' glued to the mask path.
                line = raw.decode('utf-8').strip('\r\n')
                if not line:
                    continue  # tolerate blank lines (e.g. a trailing newline)
                fields = line.split('\t')
                if len(fields) != 2:
                    raise ValueError(
                        '{}:{}: expected "image<TAB>mask", got {!r}'.format(
                            data_file, line_no, line))
                self.datalist.append((fields[0], fields[1]))
        self.root_dir = data_dir
        self.transform_trn = transform_trn
        self.transform_val = transform_val
        # New datasets start in the training stage; see set_stage().
        self.stage = 'train'

    def set_stage(self, stage):
        """Select which transform __getitem__ applies: 'train' or 'val'.

        Any other value makes __getitem__ return the sample untransformed.
        """
        self.stage = stage

    def __len__(self):
        """Return the number of image/mask pairs in the list file."""
        return len(self.datalist)

    @staticmethod
    def _read_image(path):
        """Load ``path`` as an array, replicating a 2-D image into 3 channels."""
        # The context manager closes the underlying file deterministically.
        with Image.open(path) as img:
            img_arr = np.array(img)
        if img_arr.ndim == 2:  # grayscale
            img_arr = np.tile(img_arr, [3, 1, 1]).transpose(1, 2, 0)
        return img_arr

    def __getitem__(self, idx):
        """Load the ``idx``-th pair and return it as an image/mask dict.

        Returns:
            dict: ``{'image': array, 'mask': array}``, or whatever the active
            stage transform returns for that dict.
        """
        img_name = os.path.join(self.root_dir, self.datalist[idx][0])
        msk_name = os.path.join(self.root_dir, self.datalist[idx][1])
        image = self._read_image(img_name)
        with Image.open(msk_name) as msk:
            mask = np.array(msk)
        # The 2-D check is skipped when image and mask are the very same file.
        if img_name != msk_name:
            assert len(mask.shape) == 2, 'Masks must be encoded without colourmap'
        sample = {'image': image, 'mask': mask}
        if self.stage == 'train':
            if self.transform_trn:
                sample = self.transform_trn(sample)
        elif self.stage == 'val':
            if self.transform_val:
                sample = self.transform_val(sample)
        return sample

构造DataLoader

# Define the transforms
composed_trn = transforms.Compose([
    ResizeShorterScale(shorter_side, low_scale, high_scale),
    # NOTE(review): the constants look like the per-channel ImageNet mean on a
    # 0-255 scale, presumably used as the image fill value - confirm against Pad.
    Pad(crop_size, [123.675, 116.28, 103.53], ignore_label),
    RandomMirror(),
    RandomCrop(crop_size),
    Normalise(*normalise_params),
    ToTensor(),
])
composed_val = transforms.Compose([
    Normalise(*normalise_params),
    ToTensor(),
])

# Load the datasets
trainset = MyDataset(data_file=train_list,
                     data_dir=train_dir,
                     transform_trn=composed_trn,
                     transform_val=composed_val)
valset = MyDataset(data_file=val_list,
                   data_dir=val_dir,
                   transform_trn=None,
                   transform_val=composed_val)
# MyDataset starts in the 'train' stage, where only transform_trn is applied.
# valset has no training transform, so switch it to 'val'; otherwise
# composed_val would never run and the loader would receive raw arrays.
valset.set_stage('val')

# Build the loaders
train_loader = DataLoader(trainset,
                          batch_size=batch_size,
                          shuffle=True,
                          num_workers=num_workers,
                          pin_memory=True,
                          drop_last=True)
val_loader = DataLoader(valset,
                        batch_size=1,
                        shuffle=False,
                        num_workers=num_workers,
                        pin_memory=True)

训练

for i, sample in enumerate(train_loader):
    # Move the batch to the GPU.
    image = sample['image'].cuda()
    target = sample['mask'].cuda()
    # Tensors take part in autograd directly; the deprecated
    # torch.autograd.Variable wrapper is no longer needed. Only the dtype
    # casts remain: float input for the network, long targets.
    image_var = image.float()
    target_var = target.long()
    # Compute output
    output = net(image_var)
    ...

 

原文链接:https://blog.csdn.net/Augurlee/article/details/103652444

猜你喜欢

转载自www.cnblogs.com/benming/p/12091155.html