目录
#模型优化,不一定非要改模型的参数,也可以通过学习率衰减(超参数设置)、数据增强等方法进行优化
#在这个项目中,使用数据增强,减少了过拟合
1.导包
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import torchvision
from torchvision import transforms
import os
2.加载数据、拼接训练与测试数据的文件夹路径
base_dir = './dataset'
train_dir = os.path.join(base_dir, 'train')
test_dir = os.path.join(base_dir, 'test')
3数据预处理
3.1数据增强
#图片数据增强的常用方法
transforms.RandomCrop # 随机位置的裁剪 , CenterCrop 中间位置裁剪
transforms.RandomRotation # 随机旋转
transforms.RandomHorizontalFlip() # 水平翻转
transforms.RandomVerticalFlip() # 垂直翻转
transforms.ColorJitter(brightness) # 亮度
transforms.ColorJitter(contrast) # 对比度
transforms.ColorJitter(saturation) # 饱和度
transforms.ColorJitter(hue) #图像抖动
transforms.RandomGrayscale() # 随机灰度化.
# 数据增强只会加在训练数据上. 不一定使用了数据增强,训练效果就一定好!!!
train_transform = transforms.Compose([
transforms.Resize((224, 224)), #原论文中的统一的尺寸参数要求
transforms.RandomCrop(192), #从原图中切出来的尺寸大小
transforms.RandomHorizontalFlip(), #水平翻转
transforms.RandomVerticalFlip(), #垂直翻转
transforms.RandomRotation(0.4), #随机旋转 0.4是旋转的角度比例
# transforms.ColorJitter(brightness=0.5),
# transforms.ColorJitter(contrast=0.5),
transforms.ToTensor(),
# 正则化
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
#测试数据不进行图片的数据增强
test_transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
# 正则化
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
3.2用分类存储的图片数据创建dataloader
train_ds = torchvision.datasets.ImageFolder(train_dir, transform=train_transform)
test_ds = torchvision.datasets.ImageFolder(test_dir, transform=test_transform)
batch_size = 32
train_dl = torch.utils.data.DataLoader(train_ds, batch_size=batch_size, shuffle=True, drop_last=True)
test_dl = torch.utils.data.DataLoader(test_ds, batch_size=batch_size)
4.加载预训练好的模型 (迁移学习)
# 加载预训练好的模型
model