Pytorch实战入门(三):迁移学习

Pytorch实战入门(一):MLP
Pytorch实战入门(二):CNN与MNIST
Pytorch实战入门(三):迁移学习

前言

数据集下载地址,提取码:6smh
数据集格式满足 torchvision.datasets.ImageFolder 读取要求(根目录/类别名/图像名.jpg

主要涉及

  • 获取pytorch定义好的模型结构
  • 读取网络、冻结网络、保存网络
  • 简单修改模型满足任务需求

1. 代码

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms, models

import numpy as np
import matplotlib.pyplot as plt
import os
import copy

def train(model, dataloader, loss_fn, optimizer, epoch):
    model.train()
    train_loss = 0.
    train_corrects = 0.
    for inputs, labels in dataloader:
        inputs, labels = inputs.to(device), labels.to(device)
        with torch.autograd.set_grad_enabled(True):
            outputs = model(inputs)
            loss = loss_fn(outputs, labels) 

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        preds = outputs.argmax(dim=1)
        train_loss += loss.item() * inputs.size(0)
        train_corrects += torch.sum(preds.view(-1) == labels.view(-1)).item()
                
    epoch_loss = train_loss / len(dataloader.dataset)
    epoch_acc = train_corrects / len(dataloader.dataset)
    print("epoch {} train loss: {}, acc: {}".format(epoch, epoch_loss, epoch_acc))
    return epoch_loss, epoch_acc

def test(model, dataloader):
    model.eval()
    test_corrects = 0.
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            preds = outputs.argmax(dim=1)
            test_corrects += torch.sum(preds.view(-1) == labels.view(-1)).item()
            
        epoch_acc = test_corrects / len(dataloader.dataset)
        print("test acc: {}".format(epoch_acc))
    return epoch_acc
    
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

data_path = "./hymenoptera_data"  # 数据集路径
input_size = 224  # 输入图像 224*224
num_classes = 2   # 类别数量 2 ants和bees
batch_size = 32
epochs = 20
lr = 0.001

pretrained = False
feature_extract = False

data_transforms = {
    
    
    "train": transforms.Compose([
        transforms.RandomResizedCrop(input_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.515, 0.469, 0.341], [0.271, 0.255, 0.281])
    ]),
    "val": transforms.Compose([
        transforms.Resize(input_size),
        transforms.CenterCrop(input_size),
        transforms.ToTensor(),
        transforms.Normalize([0.515, 0.469, 0.341], [0.271, 0.255, 0.281])
    ])
}

# 数据集
image_datasets = {
    
    x: datasets.ImageFolder(os.path.join(data_path, x), data_transforms[x]) for x in ["train", "val"]}
dataloaders = {
    
    x: torch.utils.data.DataLoader(image_datasets[x], 
                                              batch_size=batch_size,
                                              shuffle=True, num_workers=1) for x in ["train", "val"]
              }
train_dataloader = dataloaders["train"]
val_dataloader = dataloaders["val"]
# 模型
# 拿到 pytorch定义的 resnet18
# pretrained=True则自动下载在ImageNet上训练好的模型并读取参数
model = models.resnet18(pretrained=pretrained)
# feature_extract==True 则网络已有的参数不参与训练
if feature_extract:
    for param in model.parameters():
        param.requires_grad = False
        
in_features = model.fc.in_features  # 拿到全连接层的输入维度
# 原本网络是1000类分类,即最后一层是 nn.Linear(in_features, 1000)
# 而我们的任务是二分类,因此要重新定义一个 nn.Linear
# 如果feature_extract==True,则整个网络只训练这个重新定义的全连接层
model.fc = nn.Linear(in_features, num_classes)
# param = torch.load("./models/best_model3.pt")
# model.load_state_dict(param)
model = model.to(device)

# 优化器
optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), 
                            lr=lr, momentum=0.9)
# 损失函数
loss_fn = nn.CrossEntropyLoss()

train_loss_history = []
train_acc_history = []
val_acc_history = []
for epoch in range(epochs):
    best_model = copy.deepcopy(model.state_dict())
    best_acc = 0.
    
    train_loss, train_acc = train(model, train_dataloader, loss_fn, optimizer, epoch)
    train_loss_history.append(train_loss)
    train_acc_history.append(train_acc)
    
    val_acc = test(model, val_dataloader)
    val_acc_history.append(val_acc)
    
    if val_acc > best_acc:
        best_acc = val_acc
        best_model = copy.deepcopy(model.state_dict())
        
torch.save(best_model, "./models/best_model1.pt")

对数据集的 transform 可见 图像变换 torchvision.transforms 笔记

2. 测试

  • 直接训练
      只拿现成的网络结构,从头训练整个网络。
    pretrained = False,不使用预训练模型
    feature_extract = False,不冻结除最后一个全连接层以外的网络参数
    epochs = 20

  • 预训练
      除最后一个全连接层以外的网络参数使用预训练模型参数,并且整个网络一起训练。
    pretrained = True
    feature_extract = False
    epochs = 20

  • 预训练 + 微调
      使用预训练模型参数,但只训练最后一个分类层,保存 best_model3.pt
    pretrained = True
    feature_extract = True
    epochs = 10
      之后读取 best_model3.pt,训练整个网络。
    pretrained = False
    feature_extract = False
    epochs = 10
    param = torch.load("./models/best_model3.pt")
    model.load_state_dict(param)

  • 训练和测试结果
    依次为每个 epoch 的训练损失,训练集上准确率,验证集上准确率
    在这里插入图片描述在这里插入图片描述在这里插入图片描述

3. 调试

3.1 数据

  transforms.Normalize() 参数获取,由于网络是在ImageNet上预训练的,理论上用 ImageNet的标准化参数 比较好 mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]

# 对未使用 transforms.Normalize() 的 image_datasets
data1 = [d[0][0].cpu().numpy() for d in image_datasets["train"]]
data2 = [d[0][1].cpu().numpy() for d in image_datasets["train"]]
data3 = [d[0][2].cpu().numpy() for d in image_datasets["train"]]
print(np.mean(data1), np.mean(data2), np.mean(data3))
print(np.std(data1), np.std(data2), np.std(data3))

  看一看数据图像

img = image_datasets["train"]
unloader = transforms.ToPILImage()  # reconvert into PIL image
plt.ion()

def imshow(tensor, title=None):
    image = tensor.cpu().clone()  # we clone the tensor to not do changes on it
    image = unloader(image)
    plt.axis('off')
    plt.imshow(image)
    if title is not None:
        plt.title(title)
    plt.pause(0.001) # pause a bit so that plots are updated

plt.figure()
imshow(img[1][0], title='Image')

  未使用 transforms.Normalize() 和使用后的两张图像
在这里插入图片描述在这里插入图片描述

3.2 模型

  拿到模型以后若要修改首先要知道模型原本的结构。

model = models.resnet18(pretrained=pretrained)
print(model)

  直接打印模型,就可以看出 resnet-18 由 conv1 + bn1 + relu + maxpool + layer1 + layer2 + layer3 + layer4 + avgpool + fc 构成,用 model.fc 可以直接拿到最后一个全连接层的信息:
Linear(in_features=512, out_features=1000, bias=True)
  任务需求是二分类问题,简单修改 model.fc 即可。

猜你喜欢

转载自blog.csdn.net/weixin_43605641/article/details/111144401