图像分类pytorch dataset 代码

from torchvision.datasets import ImageFolder
from torchvision import transforms
import os
import glob
from PIL import Image
from torch.utils.data import Dataset
"""
# 方式1

# 加上transforms
normalize = transforms.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5])
transform = transforms.Compose([
    transforms.RandomCrop(180),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),  # 将图片转换为Tensor,归一化至[0,1]
    normalize
])

dataset = ImageFolder('./data', transform=transform)

print(dataset.classes)  #根据分的文件夹的名字来确定的类别
print(dataset.class_to_idx) #按顺序为这些类别定义索引为0,1...
print(dataset.imgs) #返回从所有文件夹中得到的图片的路径以及其类别

"""

"""
方式2
"""
class_name=['drawings', 'hentai', 'netural', 'porn', 'sexy']
"""
# 生成txt文件
# with open('./train.txt','w',encoding='utf-8') as f:
#     for index, path in enumerat

猜你喜欢

转载自blog.csdn.net/qq_40107571/article/details/131617345