API_Net官方代码之utils工具

导入包

import torch
import shutil

二、模块

1)保存模型参数,保存模型状态,状态中可以有模型参数,优化器参数,epoch等。如果是在验证集上表现比之前好,那么就是is_best=True,使用shutil.copyfile(src, des)将src文件直接拷贝到des,如果已经存在,就直接覆盖掉。

def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):#state是一个字典,包含优化器、网络等参数
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')

2)计算相关统计值, 如时间、top1、top5等,需要有四个值,分别是当前的值val、累计值sum(用于求取平均)、所有数量count、平均值avg。同时设置了归零的函数。


class AverageMeter(object):
    """
    Keeps track of most recent, average, sum, and count of a metric.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

3)计算准确率, 有top1、以及top5。top1是正常的错误率计算,也就是选取最大的概率的标签作为对应的标签,若与真实标签不同,则error。top5是前五大概率中都没有真实标签,才算错误,相当于放宽了标准。

def accuracy(scores, targets, k):
    """
    Computes top-k accuracy, from predicted and true labels.

    :param scores: scores from the model
    :param targets: true labels
    :param k: k in top-k accuracy
    :return: top-k accuracy
    """

    batch_size = targets.size(0)
    _, ind = scores.topk(k, 1, True, True) 
    correct = ind.eq(targets.view(-1, 1).expand_as(ind)) #每一行最多有一个True或者没有
    correct_total = correct.view(-1).float().sum()  # 0D tensor
    return correct_total.item() * (100.0 / batch_size)

'''
topk(input, dim, replace, p)参数分别表示 批量概率向量、维度、是否有放回(也就是是否可以重复,True可重复)、p为输入向量各个元素的概率
'''



猜你喜欢

转载自blog.csdn.net/YJYS_ZHX/article/details/113538424