API_Net Official Code: Training the Network

Import packages:

import argparse
import os
import time
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
import numpy as np
from models import API_Net
from datasets import RandomDataset, BatchDataset, BalancedBatchSampler
from utils import accuracy, AverageMeter, save_checkpoint

**1) Relevant parameters:** if GPU memory is not sufficient, you need to adjust (reduce) n_classes and n_samples.

Commonly used parameters:
num_workers: number of worker processes (roughly like threads) that load each batch of samples; set it according to the number of CPU cores and the available RAM
batch_size: number of samples in one batch
epochs: total number of passes over the training set
start_epoch: the epoch to restart from when training resumes after an unexpected interruption
learning_rate: learning rate
momentum
weight_decay
resume: whether to restore existing model parameters from a checkpoint
parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
parser.add_argument('--exp_name', default=None, type=str,
                    help='name of experiment')
parser.add_argument('--data', metavar='DIR', default='',
                    help='path to dataset')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', # used when building the data loaders
                    help='number of data loading workers (default: 4)')
parser.add_argument('--epochs', default=150, type=int, metavar='N', # total number of epochs
                    help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N', # epoch to resume from after an interrupted run
                    help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch-size', default=6, type=int,
                    # this batch_size only applies to the validation loader; the training batch size
                    # is set through the batch sampler (n_classes * n_samples)
                    metavar='N', help='mini-batch size (default: 6)')
parser.add_argument('--lr', '--learning-rate', default=0.01, type=float,
                    metavar='LR', help='initial learning rate')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--weight-decay', '--wd', default=5e-4, type=float,
                    metavar='W', help='weight decay (default: 5e-4)')
parser.add_argument('--print-freq', '-p', default=1, type=int, # print frequency
                    metavar='N', help='print frequency (default: 1)')
parser.add_argument('--evaluate-freq', default=10, type=int,
                    help='the evaluation frequence')
parser.add_argument('--resume', default='./model_best.pth.tar', type=str, metavar='PATH', # resume from a checkpoint
                    help='path to latest checkpoint (default: ./model_best.pth.tar)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                    help='evaluate model on validation set')
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                    help='use pre-trained model')
parser.add_argument('--n_classes', default=4, type=int,
                    help='the number of classes')
parser.add_argument('--n_samples', default=4, type=int,
                    help='the number of samples per class')
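
Parsing these arguments yields the training configuration. Note that the effective training batch size is n_classes * n_samples (4 * 4 = 16 by default), drawn by the balanced batch sampler, while --batch-size only controls the validation loader. A minimal sketch of how the parsed values relate (the dataset path and the other values here are purely illustrative):

args = parser.parse_args(['--data', './CUB_200_2011', '--n_classes', '4', '--n_samples', '4'])
print(args.n_classes * args.n_samples)  # images per training batch yielded by the balanced sampler: 16
print(args.batch_size)                  # validation mini-batch size: 6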

**2) Global variables and hardware device:**

best_prec1 = 0  # stores the best top-1 accuracy achieved so far
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

**3) Training:**
The training loop proceeds as follows: define the meters; define the loss functions and call model.train(); run the network; update the running statistics; update the parameters; print the training results; validate at the end of each epoch; save the model parameters.

def train(train_loader, model, criterion, optimizer_conv, scheduler_conv, optimizer_fc, scheduler_fc, epoch, step):
# meter definitions
    global best_prec1
    batch_time = AverageMeter()
    data_time = AverageMeter()
    softmax_losses = AverageMeter()
    rank_losses = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

# loss function definitions
    # switch to train mode
    end = time.time()
    rank_criterion = nn.MarginRankingLoss(margin=0.05)
    softmax_layer = nn.Softmax(dim=1).to(device)

    for i, (input, target) in enumerate(train_loader):
        model.train()
# run the network
        # measure data loading time
        data_time.update(time.time() - end)  # time spent loading the data for this step
        input_var = input.to(device)
        target_var = target.to(device).squeeze()

        # compute output
        logit1_self, logit1_other, logit2_self, logit2_other, labels1, labels2 = model(
            input_var, target_var, flag='train')  # logit1_self and logit1_other both use labels1 as targets; each logit tensor is (8, 200) here
        batch_size = logit1_self.shape[0]
        labels1 = labels1.to(device)
        labels2 = labels2.to(device)

        self_logits = torch.zeros(2 * batch_size, 200).to(device)  # (16, 200): the first 8 rows come from logit1 (features1), the last 8 from logit2 (features2)
        other_logits = torch.zeros(2 * batch_size, 200).to(device)

        self_logits[:batch_size] = logit1_self
        self_logits[batch_size:] = logit2_self
        other_logits[:batch_size] = logit1_other
        other_logits[batch_size:] = logit2_other

        # compute loss
        #softmax_loss
        logits = torch.cat([self_logits, other_logits], dim=0)
        targets = torch.cat([labels1, labels2, labels1, labels2], dim=0)
        softmax_loss = criterion(logits, targets)  # cross-entropy over the self and other logits of every image in the batch
		
        # rank_loss
        # from the softmax of self_logits (16, 200), take each row's probability at its ground-truth label
        self_scores = softmax_layer(self_logits)[torch.arange(2 * batch_size).to(device).long(),
                                                 torch.cat([labels1, labels2], dim=0)]
        other_scores = softmax_layer(other_logits)[torch.arange(2 * batch_size).to(device).long(),
                                                   torch.cat([labels1, labels2], dim=0)]
        flag = torch.ones([2 * batch_size, ]).to(device)
        rank_loss = rank_criterion(self_scores, other_scores, flag)
        # if target = 1, the first input is assumed to rank higher (have a larger value) than the second input

        loss = softmax_loss + rank_loss  # both losses are already averaged over the batch
# update running statistics
        # measure accuracy and record loss
        prec1 = accuracy(logits, targets, 1)  # top-1 accuracy of the current batch
        prec5 = accuracy(logits, targets, 5)

        losses.update(loss.item(), 2 * batch_size)  # running total loss
        softmax_losses.update(softmax_loss.item(), 4 * batch_size)  # running classification (softmax) loss
        rank_losses.update(rank_loss.item(), 2 * batch_size)  # running rank loss
        top1.update(prec1, 4 * batch_size)  # running top-1 accuracy
        top5.update(prec5, 4 * batch_size)  # running top-5 accuracy
# update parameters
        # compute gradient and do SGD step
        optimizer_conv.zero_grad()
        optimizer_fc.zero_grad()
        loss.backward()
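        # for the first 8 epochs only the fc parameters are updated; the conv backbone starts being updated from epoch 8 onward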
        if epoch >= 8:
            optimizer_conv.step()
        optimizer_fc.step()
        scheduler_conv.step()
        scheduler_fc.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
		
# print training results
        if i % args.print_freq == 0:
            print('Time: {time}\nStep: {step}\t Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'SoftmaxLoss {softmax_loss.val:.4f} ({softmax_loss.avg:.4f})\t'
                  'RankLoss {rank_loss.val:.4f} ({rank_loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                epoch, i, len(train_loader), batch_time=batch_time,
                data_time=data_time, loss=losses, softmax_loss=softmax_losses, rank_loss=rank_losses,
                top1=top1, top5=top5, step=step, time=time.asctime(time.localtime(time.time()))))
        # step counts the total number of batches processed; epoch counts passes over the dataset
		
# validation
        # validate at the end of every epoch
        if i == len(train_loader) - 1:
            val_dataset = RandomDataset(transform=transforms.Compose([
                transforms.Resize([512, 512]),
                transforms.CenterCrop([448, 448]),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=(0.485, 0.456, 0.406),
                    std=(0.229, 0.224, 0.225)
                )]))
            val_loader = torch.utils.data.DataLoader(
                val_dataset, batch_size=args.batch_size, shuffle=False,
                num_workers=args.workers, pin_memory=True)
            prec1 = validate(val_loader, model, criterion)
# save model parameters
            # remember the best prec@1 and save a checkpoint: save after every epoch to guard against interruptions, and also keep the best one
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer_conv': optimizer_conv.state_dict(),
                'optimizer_fc': optimizer_fc.state_dict(),
            }, is_best)

        step = step + 1
    return step
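
Two details of the rank loss above are worth spelling out. The advanced indexing softmax_layer(self_logits)[torch.arange(2 * batch_size), torch.cat([labels1, labels2])] picks, for every row, the softmax probability assigned to that sample's ground-truth class. nn.MarginRankingLoss with target 1 then computes max(0, -(x1 - x2) + margin) per element, so the loss is zero only when the "self" score beats the "other" score by at least the margin of 0.05. A minimal sketch on toy values (the numbers are illustrative only):

import torch
import torch.nn as nn

self_scores = torch.tensor([0.90, 0.40])   # probability of the true class, computed from the self logits
other_scores = torch.tensor([0.80, 0.45])  # probability of the true class, computed from the other logits
flag = torch.ones(2)                       # target = 1: the first argument should rank higher

rank_criterion = nn.MarginRankingLoss(margin=0.05)
loss = rank_criterion(self_scores, other_scores, flag)
# per element: max(0, -(0.90 - 0.80) + 0.05) = 0.0 and max(0, -(0.40 - 0.45) + 0.05) = 0.10
# the default 'mean' reduction averages them, so loss = 0.05
print(loss.item())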

**4) Validation:**

def validate(val_loader, model, criterion):
    batch_time = AverageMeter()
    softmax_losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()
    end = time.time()

    with torch.no_grad():  # no gradients are needed during evaluation, which saves memory
        for i, (input, target) in enumerate(val_loader):

            input_var = input.to(device)
            target_var = target.to(device).squeeze()

            # compute output
            logits = model(input_var, targets=None, flag='val')
            softmax_loss = criterion(logits, target_var)

            prec1 = accuracy(logits, target_var, 1)
            prec5 = accuracy(logits, target_var, 5)
            softmax_losses.update(softmax_loss.item(), logits.size(0))
            top1.update(prec1, logits.size(0))
            top5.update(prec5, logits.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                print('Time: {time}\nTest: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'SoftmaxLoss {softmax_loss.val:.4f} ({softmax_loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                    i, len(val_loader), batch_time=batch_time, softmax_loss=softmax_losses,
                    top1=top1, top5=top5, time=time.asctime(time.localtime(time.time()))))
        print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'.format(top1=top1, top5=top5))

    return top1.avg
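
The helpers accuracy, AverageMeter and save_checkpoint imported from utils are not shown in this post; accuracy(logits, targets, k) is used as if it returned the top-k accuracy of the batch. A minimal AverageMeter consistent with how .val and .avg are read in the print statements above might look like the following sketch (the actual implementation in the repository may differ):

class AverageMeter(object):
    """Keeps the most recent value and a running average."""
    def __init__(self):
        self.val = 0.0    # last recorded value
        self.avg = 0.0    # running average
        self.sum = 0.0    # weighted sum of all recorded values
        self.count = 0    # total weight seen so far

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count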

**5) Main function**

def main():
    # two global variables: args carries the parsed arguments and best_prec1 the best accuracy so far;
    # best_prec1 is made global because it is restored from the checkpoint and also updated in train(),
    # which would otherwise require passing it in and returning it
    global args, best_prec1
    args = parser.parse_args()
    
    torch.manual_seed(2)  # fix the random seeds so that every run produces the same random numbers as the previous one
    torch.cuda.manual_seed_all(2)
    np.random.seed(2)

    # create model
    model = API_Net()
    model = model.to(device)
    model.conv = nn.DataParallel(model.conv)  # only the conv backbone is wrapped for multi-GPU data parallelism

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().to(device)
    optimizer_conv = torch.optim.SGD(model.conv.parameters(), args.lr,
                                     momentum=args.momentum,
                                     weight_decay=args.weight_decay)

    fc_parameters = [value for name, value in model.named_parameters() if 'conv' not in name]  # a way to select the parameters of specific layers
    optimizer_fc = torch.optim.SGD(fc_parameters, args.lr,
                                   momentum=args.momentum,
                                   weight_decay=args.weight_decay)
    if args.resume: 
        if os.path.isfile(args.resume):  # checkpoint file with the saved parameters
            print('loading checkpoint {}'.format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer_conv.load_state_dict(checkpoint['optimizer_conv'])
            optimizer_fc.load_state_dict(checkpoint['optimizer_fc'])
            print('loaded checkpoint {}(epoch {})'.format(args.resume, checkpoint['epoch']))
        else:
            print('no checkpoint found at {}'.format(args.resume))

    cudnn.benchmark = True  # speeds things up when the input sizes barely vary between batches
    # Data loading code
    train_dataset = BatchDataset(transform=transforms.Compose([
        transforms.Resize([512, 512]),
        transforms.RandomCrop([448, 448]),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.485, 0.456, 0.406),
            std=(0.229, 0.224, 0.225)
        )]))

    train_sampler = BalancedBatchSampler(train_dataset, args.n_classes, args.n_samples)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_sampler=train_sampler,
        num_workers=args.workers, pin_memory=True)
      
    # the schedulers are stepped once per batch; T_max = 100 * len(train_loader) is half a cosine period,
    # so the learning rate decays from its maximum to its minimum and changes fastest in the middle
    scheduler_conv = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_conv, 100 * len(train_loader))
    scheduler_fc = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_fc, 100 * len(train_loader))

    step = 0
    print('START TIME:', time.asctime(time.localtime(time.time())))

    for epoch in range(args.start_epoch, args.epochs):
        step = train(train_loader, model, criterion, optimizer_conv, scheduler_conv, optimizer_fc, scheduler_fc, epoch,
                     step)


if __name__ == '__main__':
    main()  # standard entry point so the script can be run directly
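
One detail of the schedulers deserves emphasis: scheduler_conv.step() and scheduler_fc.step() are called after every batch inside train(), so T_max = 100 * len(train_loader) spreads half a cosine period over roughly 100 epochs' worth of batches. CosineAnnealingLR follows lr_t = eta_min + (base_lr - eta_min) * (1 + cos(pi * t / T_max)) / 2 with eta_min = 0 by default; the small sketch below evaluates this at a few points (the 1000 batches per epoch are an assumed, illustrative figure):

import math

base_lr = 0.01
T_max = 100 * 1000  # 100 epochs of an assumed 1000 batches each

def cosine_lr(t, eta_min=0.0):
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * t / T_max)) / 2

print(cosine_lr(0), cosine_lr(T_max // 2), cosine_lr(T_max))  # 0.01, ~0.005, ~0.0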


Reposted from blog.csdn.net/YJYS_ZHX/article/details/113540825