自定义数据集Dataset

目录

1.导包 

2. 获取图片路径

3.建立图片类别和索引之间的映射关系

3. 1建立图片类别的数字映射关系第二种方法:生成所有图片的label 

4.借助ndarray的索引取值的方法, 打乱数据 

 5.手动的划分一下训练数据和测试数据

6.transforms.Compose数据预处理、转化 

7.定义类:实现自定义重写数据集Dataset 

8.破损图片检测显示 

9.创建训练与测试数据的dataloader  (调用 自定义封装的数据集MyDataset )

10.定义神经网络模型 

11.定义设备 

12.创建模型并拷到GPU上,定义优化器、损失函数 

13.定义训练过程 

14.运行测试 


# Dataset: __len__, __getitem__都可以成为Dataset

# __len__: 获取长度:可迭代对象的元素数量

#__getitem__:按索引获取元素

# dataset[0] = dataset.__getitem__[0]

1.导包 

import torch
import numpy as np
from torchvision import transforms
import torch.nn as nn
import torch.nn.functional as F


import glob  #获取路径下所有图片路径

from PIL import Image   #pyTorch中处理图片数据的模块, 根据路径打开 显示图片

2. 获取图片路径

all_img_path = glob.glob(r'E:\PYTHON学习资料2023-7-5\10深度学习\1深度学习\02TyTorch\day56_dropout和bn\代码\dataset\*.jpg')
all_img_path[:5]  #查看 获取的前五个图片路径
['E:\\PYTHON学习资料2023-7-5\\10深度学习\\1深度学习\\02TyTorch\\day56_dropout和bn\\代码\\dataset\\cloudy1.jpg',
 'E:\\PYTHON学习资料2023-7-5\\10深度学习\\1深度学习\\02TyTorch\\day56_dropout和bn\\代码\\dataset\\cloudy10.jpg',
 'E:\\PYTHON学习资料2023-7-5\\10深度学习\\1深度学习\\02TyTorch\\day56_dropout和bn\\代码\\dataset\\cloudy100.jpg',
 'E:\\PYTHON学习资料2023-7-5\\10深度学习\\1深度学习\\02TyTorch\\day56_dropout和bn\\代码\\dataset\\cloudy101.jpg',
 'E:\\PYTHON学习资料2023-7-5\\10深度学习\\1深度学习\\02TyTorch\\day56_dropout和bn\\代码\\dataset\\cloudy102.jpg']

3.建立图片类别和索引之间的映射关系

# 建立图片类别和索引之间的映射关系
species = ['cloudy', 'rain', 'shine', 'sunrise']
#建立映射关系的第一种方法
#enumerate()使可迭代的对象species列表 生成 每个元素对应的索引编号
species_to_idx = dict((c, i) for i, c in enumerate(species))
species_to_idx  #查看生成的字典映射关系
{'cloudy': 0, 'rain': 1, 'shine': 2, 'sunrise': 3}
species_to_idx.items()   #是一个可迭代的列表对象
dict_items([('cloudy', 0), ('rain', 1), ('shine', 2), ('sunrise', 3)])
# 调换一下key和value的顺序
#species_to_idx是一个字典
#从原字典species_to_idx中取出键与值,再调换位置重新组成一个新的字典
idx_to_species = dict((v, k) for k, v in species_to_idx.items())
idx_to_species  #查看调换之后的字典映射关系
{0: 'cloudy', 1: 'rain', 2: 'shine', 3: 'sunrise'}
'cloudy' in all_img_path[0]  #可判断字符串'cloudy' 是否在all_img_path[0]这张图片的路径中
True

3. 1建立图片类别的数字映射关系第二种方法:生成所有图片的label 

# 生成所有图片的label 
all_labels = []

for img in all_img_path:             #遍历每一张图的路径
    for i, c in enumerate(species):  #枚举图片的每一个类别
        if c in img:                 #若字符串“类别名称” 在 遍历的这张图片路径中
            all_labels.append(i)     #将位置索引i添加到标签列表中all_labels
all_img_path[:5]   #查看 前五张 图片的路径
['E:\\PYTHON学习资料2023-7-5\\10深度学习\\1深度学习\\02TyTorch\\day56_dropout和bn\\代码\\dataset\\cloudy1.jpg',
 'E:\\PYTHON学习资料2023-7-5\\10深度学习\\1深度学习\\02TyTorch\\day56_dropout和bn\\代码\\dataset\\cloudy10.jpg',
 'E:\\PYTHON学习资料2023-7-5\\10深度学习\\1深度学习\\02TyTorch\\day56_dropout和bn\\代码\\dataset\\cloudy100.jpg',
 'E:\\PYTHON学习资料2023-7-5\\10深度学习\\1深度学习\\02TyTorch\\day56_dropout和bn\\代码\\dataset\\cloudy101.jpg',
 'E:\\PYTHON学习资料2023-7-5\\10深度学习\\1深度学习\\02TyTorch\\day56_dropout和bn\\代码\\dataset\\cloudy102.jpg']
all_labels
[0,
 0,
 0,
 0,
 0,
...
 1,
 1,
 1,
 ...
 2,
 2,
 2,
 2,
...,
3,
3,
3]
all_labels[:5]  #查看标签列表 的前五个
[0, 0, 0, 0, 0]

4.借助ndarray的索引取值的方法, 打乱数据 

# 借助ndarray的索引取值的方法, 打乱数据
index = np.random.permutation(len(all_img_path))   ##生成 图片数据总量的 随机数 ,作为索

猜你喜欢

转载自blog.csdn.net/Hiweir/article/details/147062539