Pytorch教程入门系列6----数据处理

系列文章目录



前言

相信大家通过前几篇博文的阅读已经能看懂、了解Pytorch中主要API的作用和使用了,今天开始我们另起一锅,来看看Pytorch是如何勾画一副神经网络图。


一、神经网络的流程

需要声明的是,我们所介绍的是神经网络在计算机视觉方面,具体是物体检测方面的应用,故整个流程基本遵从一下几点:

数据处理(主要是图片视频等数据)
模型搭建
损失计算
模型训练
模型部署

二、为什么要数据处理

对深度学习来说,就是从海量的数据中去预测未知数据,因此数据是很重要的,数据处理的好坏也决定了深度学习的上限,因此数据处理是必要的!!!

三、拿来主义的公开数据集

一般数据处理都做成数据集,里面包含数据的image(原图)、label(标签)、annotation(标注)等
公认的三大数据集:ImageNet数据集、PASCAL VOC数据集、COCO数据集
随着自动驾驶领域的快速发展,也出现了众多自动驾驶领域的数据集,如KITTI、Cityscape和Udacity等

三、数据处理的流程

from torchvision import datasets, transforms, models

1.制作数据集 Dataset

from torch.utils.data import Dataset
代码如下(示例):

	# 建立data类,即可以方便地进行数据集的迭代
    class my_data(Dataset):
          def __init__(self, image_path, annotation_path, transform):
              # 初始化,读取数据集
          def __len__(self):
              # 获取数据集的总大小
          def __getitem__(self, id):
              # 对于指定的id,读取该数据并返回
    # 实例化
    dataset = my_data("your image path", "your annotation path",data_transforms) # 实例化该类
    dataset.classes #数据集包含种类名
    #迭代获取每一组数据
    for data in dataset:
        print(data)

内部的三个魔法函数的意义
init

在__init__函数中读取你的数据集文件,可以使用各种方法(如 pandas、numpy 等)进行读取。你也可以在这里进行一些预处理,如将图像调整到统一的大小。

len

在__len__函数中,你需要返回数据集的总大小,这样 PyTorch 就可以通过 len(dataset) 获取数据集的大小。

getitem

在__getitem__函数中,你需要根据给定的索引 id 读取数据,并返回一个包含所有数据的元组,如 (image, label)。

csv文件
import pandas as pd

class my_data(Dataset):
    def __init__(self, image_path, annotation_path, transform):
        # 读取数据集文件
        self.images = pd.read_csv(image_path)
        self.annotations = pd.read_csv(annotation_path)
        # 初始化 transform
        self.transform = transform

    def __len__(self):
        # 返回数据集大小
        return len(self.images)

    def __getitem__(self, id):
        # 读取图像和标签
        image = self.images.iloc[id, :]
        annotation = self.annotations.iloc[id, :]
        # 对图像进行 transform
        image = self.transform(image)
        # 返回数据
        return (image, annotation)
文件类
import os

class my_data(Dataset):
    def __init__(self, image_folder, annotation_folder, transform):
        # 读取数据集文件
        self.image_filenames = os.listdir(image_folder)
        self.annotation_filenames = os.listdir(annotation_folder)
        # 初始化 transform
        self.transform = transform

    def __len__(self):
        # 返回数据集大小
        return len(self.image_filenames)

    def __getitem__(self, id):
        # 读取图像和标签
        image_filename = self.image_filenames[id]
        annotation_filename = self.annotation_filenames[id]
        # 使用 os.path.join 来组合路径和文件名
        image_path = os.path.join('data', 'images', image_filename)
        annotation_path = os.path.join('data', 'annotations', annotation_filename)
        # 读取文件
        image = read_file(image_path)
        annotation = read_file(annotation_path)
        # 对图像进行 transform
        image = self.transform(image)
        # 返回数据
        return (image, annotation)

2.数据增强torchvision.transforms

注意:我们个人去拿一个具体场景的图片资料训练神经网络,首先一个突出问题就是,数据量不够。较少的数据一、不满足神经网络训练的海量要求。二、不具备代表性。三、数据集中的图片有可能存在大小不一的情况,并且原始图片像素RGB值较大(0~255),这些都不利于神经网络的训练收敛,因此还需要进行一些图像变换工作。为此数据增强是非常必要的

数据增强的方法主要有以下几种:
图片缩放、旋转、遮挡、裁剪、翻转等
transforms.Compose()用来把一系列的增强操作集合起来按顺序执行

data_transforms = transforms.Compose([
        transforms.Resize([96, 96]), #缩放
        transforms.RandomRotation(45),#随机旋转,-4545度之间随机选
        transforms.CenterCrop(64),#从中心开始裁剪
        Cutout(0.4),#随机遮挡的概率
        transforms.RandomHorizontalFlip(p=0.5),#随机水平翻转 选择一个概率概率
        transforms.RandomVerticalFlip(p=0.5),#随机垂直翻转
        transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),#参数1为亮度,参数2为对比度,参数3为饱和度,参数4为色相
        transforms.RandomGrayscale(p=0.025),#概率转换成灰度率,3通道就是R=G=B
        transforms.ToTensor(),#矩阵转换为张量必须的数据转换
        transforms.Normalize([0.535, 0.473, 0.572], [0.189, 0.276, 0.209])#均值,标准差进行归一化
    ])

3.数据批处理DataLoader

DataLoader模块直接读取batch数据
from torch.utils.data import Dataloader

 # 使用Dataloader进一步封装Dataset
 dataloader = Dataloader(dataset, batch_size=10, shuffle=True, num_workers=8)
 #参数含义 (Dataset的实例,批量batch的大小,是否打乱数据参数,使用几个线程来加载数据)

总结

以上就是今天介绍的有关数据处理方面的内容,从当前较为主流的公开数据集,然后介绍数据处理流程,从制作数据集、数据增强、数据批处理3个方面介绍PyTorch中相关的使用方法。数据已准备妥当,搭建合适的模型才能让安静的数据说话,接下来我们一起期待Pytorch系列入门7----网络结构

猜你喜欢

转载自blog.csdn.net/weixin_46417939/article/details/128208615