# 0018 - PyTorch transfer-learning template (迁移学习范本) 01

# -*- coding: utf-8 -*-
# @Time    : 2021/1/19 10:39
# @Author  : Johnson

import os
import numpy as np
import torch
import torch.nn as nn
from torch.optim import lr_scheduler
import torchvision
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader
from sklearn.metrics import roc_auc_score, classification_report
from sklearn.model_selection import KFold
from torch.autograd import Variable
import torch.optim as optim
import time
import copy
import shutil
import sys
# import scikitplot as skplt
import matplotlib.pyplot as plt
import pandas as pd

# Render with the non-interactive Agg backend; figures are written to files with savefig.
plt.switch_backend('agg')

# Integer class label -> class name, used to turn numeric labels back into readable names.
LABEL_DICT = dict(enumerate(('class_1', 'class_2')))
N_CLASSES = 2        # number of outputs of the replacement classifier head
BATCH_SIZE = 8       # mini-batch size for both the train and val loaders
DATA_DIR = './data'  # folder holding the images and the generated train/val split folders

def imshow(inp, title=None, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225),
           pause=100):
    """Draw a normalized image tensor with matplotlib.

    Args:
        inp: image tensor with channels first (the transpose below assumes a
            3-D (C, H, W) tensor), normalized with ``mean``/``std``.
        title: optional figure title.
        mean: per-channel mean used for normalization; the default matches the
            ``transforms.Normalize`` values used in this script.
        std: per-channel standard deviation used for normalization.
        pause: seconds handed to ``plt.pause`` after drawing.
    """
    # detach()/cpu() allow tensors that require grad or live on the GPU.
    img = inp.detach().cpu().numpy().transpose((1, 2, 0))
    # Undo Normalize: x_norm = (x - mean) / std  =>  x = std * x_norm + mean.
    img = np.asarray(std) * img + np.asarray(mean)
    img = np.clip(img, 0, 1)
    plt.imshow(img)
    if title is not None:
        plt.title(title)
    plt.pause(pause)

def _run_phase(model, criterion, optimizer, phase):
    """Run ``model`` once over every batch of ``dataloaders[phase]``.

    In the 'train' phase the model is in training mode, gradients are computed
    and ``optimizer`` is stepped. In any other phase the model is in eval mode,
    autograd is disabled, and the ``dataset.imgs`` entry of every sample is
    recorded so misclassified images can be reported by file name.

    Reads the module-level globals ``dataloaders``, ``dataset_sizes`` and
    ``use_gpu`` created in the ``__main__`` section.

    Returns:
        tuple ``(loss, acc, labels, preds, probs, img_rows)``: mean loss and
        accuracy over ``dataset_sizes[phase]`` samples, 1-D arrays of true and
        predicted class indices, an ``(N, n_outputs)`` array of softmax
        probabilities, and the ``dataset.imgs`` entries aligned with those
        arrays (empty for 'train').
    """
    is_train = phase == 'train'
    model.train(is_train)
    loader = dataloaders[phase]
    if is_train:
        batches = ((batch, None) for batch in loader)
    else:
        # Draw the batch order once and feed exactly that order to a DataLoader,
        # so file names stay aligned with samples even if the sampler shuffles.
        order = [list(idx) for idx in loader.batch_sampler]
        frozen = DataLoader(loader.dataset, batch_sampler=order,
                            num_workers=loader.num_workers,
                            collate_fn=loader.collate_fn,
                            pin_memory=loader.pin_memory)
        batches = zip(frozen, order)

    labels_all, preds_all, probs_all, img_rows = [], [], [], []
    running_loss = 0.0
    running_corrects = 0
    for (inputs, labels), index in batches:
        if use_gpu:
            inputs, labels = inputs.cuda(), labels.cuda()
        optimizer.zero_grad()
        with torch.set_grad_enabled(is_train):
            outputs = model(inputs)
            loss = criterion(outputs, labels)
        if is_train:
            loss.backward()
            optimizer.step()
        else:
            img_rows.extend(loader.dataset.imgs[i] for i in index)
        _, preds = torch.max(outputs, 1)
        # Softmax turns the outputs CrossEntropyLoss treats as logits into probabilities.
        probs_all.append(torch.softmax(outputs.detach(), dim=1).cpu().numpy())
        preds_all.append(preds.cpu().numpy())
        labels_all.append(labels.cpu().numpy())
        running_loss += loss.item() * inputs.size(0)
        running_corrects += (preds == labels).sum().item()

    size = dataset_sizes[phase]
    return (running_loss / size, running_corrects / size,
            np.concatenate(labels_all), np.concatenate(preds_all),
            np.concatenate(probs_all), img_rows)


def _phase_auc(labels, probs):
    """Return the ROC AUC of softmax ``probs`` against ``labels``, or NaN if undefined.

    Two outputs score the probability of class index 1; more outputs use
    one-vs-rest averaging. ``roc_auc_score`` raises ValueError when, e.g., only
    one class is present; that case is reported and mapped to NaN.
    """
    try:
        if probs.shape[1] == 2:
            return roc_auc_score(labels, probs[:, 1])
        return roc_auc_score(labels, probs, multi_class='ovr')
    except ValueError as err:
        print(f'AUC undefined for this phase: {err}')
        return float('nan')


def _save_roc_plot(labels, probs, path):
    """Save one one-vs-rest ROC curve per output column of ``probs`` to ``path``."""
    from sklearn.metrics import auc, roc_curve

    fig, ax = plt.subplots()
    for idx in range(probs.shape[1]):
        fpr, tpr, _ = roc_curve(labels, probs[:, idx], pos_label=idx)
        ax.plot(fpr, tpr, lw=2,
                label=f'ROC curve of class {LABEL_DICT.get(idx, idx)} (area = {auc(fpr, tpr):.2f})')
    ax.plot([0, 1], [0, 1], 'k--', lw=2)
    ax.set_xlim(0.0, 1.0)
    ax.set_ylim(0.0, 1.05)
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.set_title('ROC Curves')
    ax.legend(loc='lower right')
    fig.savefig(path, dpi=600)
    # Close the figure so curves do not pile up across folds.
    plt.close(fig)


def _report_wrong_images(img_rows, preds):
    """Print each misclassified image and append it to the global ``label_file``.

    Args:
        img_rows: ``(path, class_index)`` entries aligned with ``preds``.
        preds: predicted class indices.

    Returns:
        Number of misclassified images.
    """
    count = 0
    for (path, actual_label), pred_label in zip(img_rows, preds):
        actual_label, pred_label = int(actual_label), int(pred_label)
        if actual_label != pred_label:
            tmp_word = f'{path.split("/")[-1]}, actual: {LABEL_DICT[actual_label]}, ' \
                       f'pred: {LABEL_DICT[pred_label]}'
            print(tmp_word)
            label_file.write(tmp_word + '\n')
            count += 1
    return count


def train_model(model, criterion, optimizer, scheduler, fold, name, num_epochs=25):
    """Fine-tune ``model`` and return it with its best-validation-AUC weights loaded.

    Each epoch runs one training pass and one validation pass. The validation
    epoch with the highest ROC AUC is kept. The AUC is computed from softmax
    probabilities, not hard predictions. For that epoch this function writes:
    the softmax probabilities and true labels to
    ``./prob_result/{name}_fold_{fold}_porb.csv``; the ROC plot to
    ``./roc_img/{name}_fold_{fold}_roc.png``; a summary to the global
    ``report_file``; and misclassified image names to the global ``label_file``.

    Args:
        model: network to train (the caller moves it to the GPU when ``use_gpu``).
        criterion: loss applied to the raw model outputs.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: LR scheduler, stepped once per epoch after the training pass.
        fold: fold number, used in output file names.
        name: model name, used in output file names.
        num_epochs: number of epochs to train.

    Returns:
        ``model`` with the best validation weights loaded.
    """
    since = time.time()
    # Deep-copy the current weights; replaced whenever a better val epoch is found.
    best_model_wts = copy.deepcopy(model.state_dict())
    best_score = -np.inf
    best = None  # statistics of the best validation epoch so far

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('- ' * 50)

        for phase in ['train', 'val']:
            epoch_loss, epoch_acc, labels, preds, probs, img_rows = _run_phase(
                model, criterion, optimizer, phase)
            if phase == 'train':
                # Since PyTorch 1.1 the scheduler must step after optimizer.step().
                scheduler.step()
            print()
            epoch_auc = _phase_auc(labels, probs)
            print('{} Loss: {:.4f} Acc: {:.4f} Auc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc, epoch_auc))
            # Explicit labels keep the report valid when a class is absent.
            report = classification_report(labels, preds,
                                           labels=list(range(len(class_names))),
                                           target_names=class_names)
            print(report)

            # An undefined (NaN) AUC ranks below every real value.
            score = -np.inf if np.isnan(epoch_auc) else epoch_auc
            if phase == 'val' and (best is None or score > best_score):
                best_score = score
                best = {'acc': epoch_acc, 'auc': epoch_auc, 'report': report,
                        'labels': labels, 'preds': preds, 'probs': probs,
                        'img_rows': img_rows}
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    if best is None:
        print('No validation epoch was run; nothing to report.')
        model.load_state_dict(best_model_wts)
        return model

    print(best['labels'].shape, best['probs'].shape)
    columns = [LABEL_DICT.get(i, str(i)) for i in range(best['probs'].shape[1])]
    csv_file = pd.DataFrame(best['probs'], columns=columns)
    csv_file['true_label'] = [LABEL_DICT[int(x)] for x in best['labels']]
    csv_file.to_csv(f'./prob_result/{name}_fold_{fold}_porb.csv', index=False)
    _save_roc_plot(best['labels'], best['probs'], f'./roc_img/{name}_fold_{fold}_roc.png')
    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    reports = ('The Desc according to the Best val Auc: \nACC -> {:.4f}\nAUC -> {:.4f}\n\n{}'
               .format(best['acc'], best['auc'], best['report']))
    report_file.write(reports)
    print(reports)
    print('List the wrong judgement img ...')
    count = _report_wrong_images(best['img_rows'], best['preds'])
    print(f'This fold has {count} wrong records ...')

    # Load the best weights found during validation.
    model.load_state_dict(best_model_wts)
    return model

def plot_img():
    """Show every training mini-batch as an image grid titled with its class names.

    Uses the module-level ``dataloaders`` and ``class_names`` globals created in
    the ``__main__`` section and delegates the drawing to ``imshow``.
    """
    for batch_images, batch_targets in dataloaders['train']:
        grid = torchvision.utils.make_grid(batch_images)
        titles = [class_names[target] for target in batch_targets]
        imshow(grid, title=titles)


# Adapt the label prefixes in move_file to match your own project's image file names.
def move_file(data, file_path, dir_path, root_path):
    """Copy the images named in ``data`` into per-class folders of a split folder.

    The split folder ``<dir_path>/<file_path>`` (e.g. ``./data/train``) is deleted
    first if it exists, then recreated. Each file name is matched by prefix
    against the known labels and copied from ``dir_path`` into
    ``<dir_path>/<file_path>/<label>/``. File names matching no label are skipped.

    Side effect: the working directory is changed to ``dir_path``; callers in
    this script ``os.chdir`` back afterwards.

    Args:
        data: iterable of file names located directly in ``dir_path``.
        file_path: split folder name to (re)create, e.g. 'train' or 'val'.
        dir_path: folder containing the source images.
        root_path: accepted for backward compatibility. It is no longer needed:
            the stale split folder is removed via the same path that was checked.
    """
    label_0 = 'class_2'
    label_1 = 'class_1'
    print(f'start copy the {file_path} file ...')
    os.chdir(dir_path)
    src_dir = os.getcwd()
    split_dir = os.path.join(src_dir, file_path)
    if os.path.exists(split_dir):
        print(f'Find exist {file_path} file, the file will be dropped.')
        shutil.rmtree(split_dir)
        print(f'Finish drop the {file_path} file.')

    os.mkdir(split_dir)
    for d in data:
        # Match the whole label prefix; a fixed-length slice such as d[:2]
        # can never equal a longer label like 'class_2'.
        for label in (label_0, label_1):
            if d.startswith(label):
                label_dir = os.path.join(split_dir, label)
                os.makedirs(label_dir, exist_ok=True)
                shutil.copyfile(os.path.join(src_dir, d), os.path.join(label_dir, d))
                break
    print('finish this work ...')


def _build_model(arch):
    """Return an ImageNet-pretrained backbone whose head outputs ``N_CLASSES`` logits.

    Args:
        arch: 'resnet' (ResNet-50) or 'densenet' / 'desnet' (DenseNet-121; the
            misspelling 'desnet' is still accepted for existing commands).

    Raises:
        ValueError: for any other architecture name.
    """
    if arch == 'resnet':
        model = models.resnet50(pretrained=True)
        head_in = model.fc.in_features
        # Keep the Sequential wrapper so parameter names stay 'fc.0.*'. No Sigmoid:
        # nn.CrossEntropyLoss expects raw, unnormalized logits.
        model.fc = nn.Sequential(nn.Linear(head_in, N_CLASSES))
        return model
    if arch in ('densenet', 'desnet'):
        model = models.densenet121(pretrained=True)
        head_in = model.classifier.in_features
        model.classifier = nn.Sequential(nn.Linear(head_in, N_CLASSES))
        return model
    # An inception branch (models.inception_v3(pretrained=True)) could be added here.
    raise ValueError(f'unsupported model {arch!r}; expected resnet or densenet')


if __name__ == "__main__":
    if len(sys.argv) < 2 or sys.argv[1] not in ('resnet', 'densenet', 'desnet'):
        sys.exit(f'usage: python {sys.argv[0]} resnet|densenet')
    arch = sys.argv[1]

    for folder in ('roc_img', 'prob_result', 'report', 'error_record', 'model'):
        os.makedirs(folder, exist_ok=True)

    # move_file() changes the working directory; remember the launch directory
    # (all relative paths here resolve against it) and return to it after each call.
    origin_path = os.getcwd()
    kf = KFold(n_splits=5, shuffle=True, random_state=1)
    dd_list = np.array([o for o in os.listdir(DATA_DIR) if os.path.isfile(os.path.join(DATA_DIR, o))])

    normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    data_transforms = {
        'train': transforms.Compose([
            # Random crop, resized to 224x224.
            transforms.RandomResizedCrop(224),
            # Flip the PIL image horizontally with probability 0.5.
            transforms.RandomHorizontalFlip(),
            # transforms.ColorJitter(0.05, 0.05, 0.05, 0.05),  # optional HSV/contrast jitter
            # PIL image or HxWxC array in [0, 255] -> CxHxW float tensor in [0, 1].
            transforms.ToTensor(),
            normalize,
        ]),
        'val': transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]),
    }
    use_gpu = torch.cuda.is_available()

    with open(f'./error_record/{arch}_img_name_actual_pred.txt', 'w') as label_file:
        for m, (train_idx, val_idx) in enumerate(kf.split(dd_list), start=1):
            with open(f'./report/{arch}_fold_{m}_report.txt', 'w') as report_file:
                print(f'The {m} fold for copy file and training ...')
                move_file(dd_list[train_idx], 'train', DATA_DIR, origin_path)
                os.chdir(origin_path)
                move_file(dd_list[val_idx], 'val', DATA_DIR, origin_path)
                os.chdir(origin_path)

                image_datasets = {x: datasets.ImageFolder(os.path.join(DATA_DIR, x),
                                                          data_transforms[x])
                                  for x in ['train', 'val']}
                # Shuffle only the training split. The validation order must be
                # deterministic so that a separate pass over its batch_sampler (used
                # to recover file names) sees the same order as the loaded data.
                dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=BATCH_SIZE,
                                                              shuffle=(x == 'train'), num_workers=8,
                                                              pin_memory=False)
                               for x in ['train', 'val']}
                dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}

                class_names = image_datasets['train'].classes
                print('label mapping: ')
                print(image_datasets['train'].class_to_idx)

                model_ft = _build_model(arch)
                if use_gpu:
                    model_ft = model_ft.cuda()

                criterion = nn.CrossEntropyLoss()
                optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)
                # Multiply the learning rate by 0.1 every 7 epochs.
                exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
                model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler,
                                       m, arch, num_epochs=25)
                print('Start save the model ...')
                torch.save(model_ft.state_dict(), f'./model/fold_{m}_{arch}.pkl')
                print(f'The mission of the fold {m} finished.')
                print('# ' * 50)

# 猜你喜欢 ("You may also like" — residue from the source blog page)

# 转载自 (reposted from) blog.csdn.net/zhonglongshen/article/details/112801790