关于pytorch图像处理模块的数据处理

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/shaoyou223/article/details/84561462

文章参考:chsasank.github.io

from __future__ import print_function, division

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy
#搭建图像处理的框架
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}
#图像保存路径
data_dir = 'data'
#遍历图像
#
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),data_transforms[x]) for x in ['train', 'val']}
#查看对应文件夹的labels  输出结果是:{'bees': 1, 'ants': 0}
print(datasets.ImageFolder('data/train').class_to_idx)
#查看对应文件夹的labels  输出结果是一个列表,比如其中某个元素为如下形式 ('data/train/ants/0013035.jpg', 0),
#print(datasets.ImageFolder('data/train').imgs)

#output class_names  and the result is  ['ants', 'bees']
class_names = image_datasets['train'].classes

#输出dataset_sizes为{train的图片数与val的图片数}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}

#变成张量数据 模型可以直接调用,具体可参考莫烦pytorch教程
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,shuffle=True, num_workers=4) for x in ['train', 'val']}

for inputs,labels in dataloaders['train']:
    #这里输出的就是batch图片的内容,维度为4*3*228*228

    print(inputs)

print(dataloaders['train'][0])
for i in range(len(image_datasets['train'].imgs)):
    if i < 10:
        #显示输出为单个的元素  ('data/train/ants/1030023514_aad5c608f9.jpg', 0)
        #print(image_datasets['train'].imgs[i])
        #如果需要显示单张图像的类别
        print(image_datasets['train'].classes)               
print(len(image_datasets['train'].imgs))

猜你喜欢

转载自blog.csdn.net/shaoyou223/article/details/84561462