[Pytorch]将自己的数据集载入dataloader

一、概述

        初始化DataLoader类时必须注入一个参数dataset,而dataset为自己定义。DataSet类可以继承,但是必须重载__len__()__getitem__

        使用Pytoch封装的DataLoader有以下好处:

                ①可以自动实现多进程加载

                ②自动惰性加载,不会占用过多内存

                ③封装有数据预处理和数据增强等操作,避免重复造轮子

二、自定义DataSet

        以Faster R-CNN为例,一般建议至少传入以下参数,方便后续使用:

class FRCNNDataset(Dataset):
    def __init__(self, annotation_lines, input_shape = [600, 600], train = True):
        self.annotation_lines   = annotation_lines        #数据集列表
        self.length             = len(annotation_lines)   #数据集大小
        self.input_shape        = input_shape             #输出尺寸
        self.train              = train                   #是否训练

        然后重载__len__()__getitem__

def __len__(self):
    return self.length    #直接返回长度
def __getitem__(self, index):
    index = index % self.length

    #训练时候对数据进行随机增强,但验证时不进行
    image, y = self.get_random_data(self.annotation_lines[index], self.input_shape[0:2], random = self.train)
    #将图片转换成矩阵
    image = np.transpose(preprocess_input(np.array(image, dtype=np.float32)), (2, 0, 1))
    
    #编码先验框
    box_data = np.zeros((len(y), 5))
    if len(y) > 0:
        box_data[:len(y)] = y

    box = box_data[:, :4]
    label = box_data[:, -1]
    return image, box, label

        关于数据增强函数get_random_data(),其中还包含了对图片的无变形缩放功能

def get_random_data(self, annotation_line, input_shape, jitter=.3, hue=.1, sat=0.7, val=0.4, random=True):
    # 数据经过处理后格式为:地址——(空格)——预测框,使用split函数即可切割出地址和先验框
    line = annotation_line.split()
    # 读取图像并转换为RGB格式
    image = Image.open(line[0])
    image = cvtColor(image)
    # 获得图像的高宽与目标高宽
    iw, ih = image.size
    h, w = input_shape
    # 读取先验框
    box = np.array([np.array(list(map(int,box.split(',')))) for box in line[1:]])

                仅缩放的无变形缩放功(非训练模式)

# 在不进行随机数据增强的情况下(非训练模式),直接变形后输出
if not random:
    #获取变形比例
    scale = min(w/iw, h/ih)
    nw = int(iw*scale)
    nh = int(ih*scale)
    dx = (w-nw)//2
    dy = (h-nh)//2
    #   将图像多余的部分加上灰条
    image       = image.resize((nw,nh), Image.BICUBIC)
    new_image   = Image.new('RGB', (w,h), (128,128,128))
    new_image.paste(image, (dx, dy))
    image_data  = np.array(new_image, np.float32)
    #   对真实框进行调整
    if len(box)>0:
        np.random.shuffle(box)
        box[:, [0,2]] = box[:, [0,2]]*nw/iw + dx
        box[:, [1,3]] = box[:, [1,3]]*nh/ih + dy
        box[:, 0:2][box[:, 0:2]<0] = 0
        box[:, 2][box[:, 2]>w] = w
        box[:, 3][box[:, 3]>h] = h
        box_w = box[:, 2] - box[:, 0]
        box_h = box[:, 3] - box[:, 1]
        box = box[np.logical_and(box_w>1, box_h>1)] # discard invalid box
    #返回图片和先验框
    return image_data, box

                带数据增强的无变形缩放(训练模式)

        #   对图像进行缩放并且进行长和宽的扭曲
        new_ar = iw/ih * self.rand(1-jitter,1+jitter) / self.rand(1-jitter,1+jitter)
        scale = self.rand(.25, 2)
        if new_ar < 1:
            nh = int(scale*h)
            nw = int(nh*new_ar)
        else:
            nw = int(scale*w)
            nh = int(nw/new_ar)
        image = image.resize((nw,nh), Image.BICUBIC)

        #   将图像多余的部分加上灰条
        dx = int(self.rand(0, w-nw))
        dy = int(self.rand(0, h-nh))
        new_image = Image.new('RGB', (w,h), (128,128,128))
        new_image.paste(image, (dx, dy))
        image = new_image

        #   翻转图像
        flip = self.rand()<.5
        if flip: image = image.transpose(Image.FLIP_LEFT_RIGHT)

        image_data      = np.array(image, np.uint8)

        #   对图像进行色域变换
        #   计算色域变换的参数
        r               = np.random.uniform(-1, 1, 3) * [hue, sat, val] + 1

        #   将图像转到HSV上
        hue, sat, val   = cv2.split(cv2.cvtColor(image_data, cv2.COLOR_RGB2HSV))
        dtype           = image_data.dtype

        #   应用变换
        x       = np.arange(0, 256, dtype=r.dtype)
        lut_hue = ((x * r[0]) % 180).astype(dtype)
        lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)
        lut_val = np.clip(x * r[2], 0, 255).astype(dtype)

        image_data = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val)))
        image_data = cv2.cvtColor(image_data, cv2.COLOR_HSV2RGB)

        #   对真实框进行调整
        if len(box)>0:
            np.random.shuffle(box)
            box[:, [0,2]] = box[:, [0,2]]*nw/iw + dx
            box[:, [1,3]] = box[:, [1,3]]*nh/ih + dy
            if flip: box[:, [0,2]] = w - box[:, [2,0]]
            box[:, 0:2][box[:, 0:2]<0] = 0
            box[:, 2][box[:, 2]>w] = w
            box[:, 3][box[:, 3]>h] = h
            box_w = box[:, 2] - box[:, 0]
            box_h = box[:, 3] - box[:, 1]
            box = box[np.logical_and(box_w>1, box_h>1)] 
        
        return image_data, box

                关于collate_fn参数

                        __getitem__一般返回(image,label)样本对,而DataLoader需要一个batch_size用于处理batch样本,以便于批量训练。

                        默认的default_collate(batch)函数仅能对尺寸一致且batch_size相同的image进行整理,如将(img0,lbl0),(img1,lbl1),(img2,lbl2)整合为([img0,img1,img2],[lbl0,lbl1,lbl2]),如图像中含有box等参数则需要自定义处理

def frcnn_dataset_collate(batch):
    images = []
    bboxes = []
    labels = []
    for img, box, label in batch:
        images.append(img)
        bboxes.append(box)
        labels.append(label)
    images = torch.from_numpy(np.array(images))
    return images, bboxes, labels

三、语义分割与目标检测DataSet的区别

        ①在__getitem__中不需要获取box值,转而获取标志图png

    def __getitem__(self, index):
        annotation_line = self.annotation_lines[index]
        name            = annotation_line.split()[0]

        #   从文件中读取图像
        jpg         = Image.open(os.path.join(os.path.join(self.dataset_path, "VOC2007/JPEGImages"), name + ".jpg"))
        png         = Image.open(os.path.join(os.path.join(self.dataset_path, "VOC2007/SegmentationClass"), name + ".png"))

        #   数据增强
        jpg, png    = self.get_random_data(jpg, png, self.input_shape, random = self.train)

        jpg         = np.transpose(preprocess_input(np.array(jpg, np.float64)), [2,0,1])
        png         = np.array(png)
        png[png >= self.num_classes] = self.num_classes

        #   转化成one_hot的形式
        #   在这里需要+1是因为voc数据集有些标签具有白边部分
        seg_labels  = np.eye(self.num_classes + 1)[png.reshape([-1])]
        seg_labels  = seg_labels.reshape((int(self.input_shape[0]), int(self.input_shape[1]), self.num_classes + 1))

        return jpg, png, seg_labels

        ②get_random_data变形时需要对两张图做同样的变换

        if not random:
            iw, ih  = image.size
            scale   = min(w/iw, h/ih)
            nw      = int(iw*scale)
            nh      = int(ih*scale)

            image       = image.resize((nw,nh), Image.BICUBIC)
            new_image   = Image.new('RGB', [w, h], (128,128,128))
            new_image.paste(image, ((w-nw)//2, (h-nh)//2))

            label       = label.resize((nw,nh), Image.NEAREST)
            new_label   = Image.new('L', [w, h], (0))
            new_label.paste(label, ((w-nw)//2, (h-nh)//2))
            return new_image, new_label

        ③collate_fn需要进行修改

def deeplab_dataset_collate(batch):
    images      = []
    pngs        = []
    seg_labels  = []
    for img, png, labels in batch:
        images.append(img)
        pngs.append(png)
        seg_labels.append(labels)
    images      = torch.from_numpy(np.array(images)).type(torch.FloatTensor)
    pngs        = torch.from_numpy(np.array(pngs)).long()
    seg_labels  = torch.from_numpy(np.array(seg_labels)).type(torch.FloatTensor)
    return images, pngs, seg_labels

四、在训练过程中的调用

         ①读取文件集(经处理的txt文件)

with open(train_annotation_path, encoding='utf-8') as f:
    train_lines = f.readlines()
with open(val_annotation_path, encoding='utf-8') as f:
    val_lines   = f.readlines()
#获取数据集长度
num_train   = len(train_lines)
num_val     = len(val_lines)  

        ②检查数据集是否符合要求

                这里一般检查数据集是否足够大,也可不检查

        ③将数据集装入DataSet中

train_dataset   = MyDataset(train_lines, input_shape, anchors, batch_size, num_classes, train = True)
val_dataset     = MyDataset(val_lines, input_shape, anchors, batch_size, num_classes, train = False)

        ④将DataSet放入DataLoader中

                关于dataloader:一般有以下5个参数:

                        1.dataset:数据集对象,dataset型

                        2.batch_size:批大小,int型

                        3.shuffe:每一轮epoch是否重新洗牌,bool型

                        4.num_workers:多进程读取

                        5.drop_last:当样本不能被batch_size取整时,是否丢弃最后一批数据,bool型

gen = DataLoader(train_dataset, shuffle = shuffle, batch_size = batch_size, num_workers = num_workers, pin_memory=True,
                                    drop_last=True, collate_fn=ssd_dataset_collate, sampler=train_sampler)
gen_val = DataLoader(val_dataset  , shuffle = shuffle, batch_size = batch_size, num_workers = num_workers, pin_memory=True, 
                                    drop_last=True, collate_fn=ssd_dataset_collate, sampler=val_sampler)

猜你喜欢

转载自blog.csdn.net/weixin_37878740/article/details/128711790