pyTorch-迁移学习-图片数据增强-四种天气图片的多分类问题

目录

1.导包

 2.加载数据、拼接训练与测试数据的文件夹路径

3数据预处理 

3.1数据增强 

3.2用分类存储的图片数据创建dataloader

4.加载预训练好的模型 (迁移学习)

4.1固定、修改预训练好的模型 

5.将模型拷到GPU上 

6.定义优化器与损失函数 

7.学习率衰减 

8.定义训练过程 

9.运行测试 

10.可视化:训练与测试的损失函数、准确率对比 


#模型优化,不一定非要改模型的参数,也可以通过学习率衰减(超参数设置)、数据增强等方法进行优化
#在这个项目中,使用数据增强,减少了过拟合 

1.导包

import torch 
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt

import torchvision
from torchvision import transforms

import os

 2.加载数据、拼接训练与测试数据的文件夹路径

base_dir = './dataset'
train_dir = os.path.join(base_dir, 'train')
test_dir = os.path.join(base_dir, 'test')

3数据预处理 

3.1数据增强 

#图片数据增强的常用方法
transforms.RandomCrop    # 随机位置的裁剪 , CenterCrop 中间位置裁剪
transforms.RandomRotation # 随机旋转
transforms.RandomHorizontalFlip() # 水平翻转
transforms.RandomVerticalFlip() # 垂直翻转
transforms.ColorJitter(brightness) # 亮度
transforms.ColorJitter(contrast) # 对比度
transforms.ColorJitter(saturation) # 饱和度
transforms.ColorJitter(hue)   #图像抖动
transforms.RandomGrayscale() # 随机灰度化.

# 数据增强只会加在训练数据上.  不一定使用了数据增强,训练效果就一定好!!!
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),   #原论文中的统一的尺寸参数要求
    transforms.RandomCrop(192),      #从原图中切出来的尺寸大小
    transforms.RandomHorizontalFlip(), #水平翻转
    transforms.RandomVerticalFlip(),   #垂直翻转
    transforms.RandomRotation(0.4),    #随机旋转   0.4是旋转的角度比例
#     transforms.ColorJitter(brightness=0.5),
#     transforms.ColorJitter(contrast=0.5),
    transforms.ToTensor(),
    # 正则化
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
#测试数据不进行图片的数据增强
test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    # 正则化
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

3.2用分类存储的图片数据创建dataloader

train_ds = torchvision.datasets.ImageFolder(train_dir, transform=train_transform)
test_ds = torchvision.datasets.ImageFolder(test_dir, transform=test_transform)

batch_size = 32
train_dl = torch.utils.data.DataLoader(train_ds, batch_size=batch_size, shuffle=True, drop_last=True)
test_dl = torch.utils.data.DataLoader(test_ds, batch_size=batch_size)

4.加载预训练好的模型 (迁移学习)

# 加载预训练好的模型
model 

猜你喜欢

转载自blog.csdn.net/Hiweir/article/details/147041712
今日推荐