【onnxruntime】onnx模型推理imagenet数据集验证精度

1 背景

onnx模型推理单张图片,网上的教程非常多,我自己以前也写了很多这些内容,但如何推理整个数据集来验证精度呢?

如果你只是为了验证导出的onnx模型精度如何,可以参考这篇文章。

为了保证模型前后处理完全一致,前后处理都直接复用原本的代码,输入输出数据涉及到tensor和numpy转换时直接用torch.from_numpy和.numpy实现。

到嵌入式开发板上跑的话,前后处理都是需要自己写的,而且无法依赖torch。

2 评测Imagenet数据集

imagenet 验证集val,内部有1000个文件夹,每个文件夹下对应有50张图片。
pytorch默认使用PIL读取,刚读取的图片,像素顺序RGB,layout:HWC(单张图片没有batch维)
经过transforms.ToTensor()后,像素顺序RGB,layout变为CHW(经DataLoader组批后为NCHW)。当然,transforms.ToTensor()还有数据归一化(除以255)的作用,具体细节可参考另一篇博客《不使用torchvision.transforms对图片预处理的python实现》。

主程序如下,主要修改该代码即可:

import torch
import torch.nn as nn
import sys
import os
import time
import numpy as np
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import utils.common as utils		# 下面给出代码
from tqdm import tqdm


class Data:
    """Build the ImageNet validation DataLoader used for ONNX accuracy evaluation.

    Expects ``data_path`` to contain a ``val/`` directory laid out for
    ``torchvision.datasets.ImageFolder`` (one sub-folder per class, 50
    images each for the standard 1000-class validation split).
    """

    def __init__(self, data_path, scale_size=224):
        # scale_size: final square side fed to the network (default 224,
        # matching the original hard-coded crop).
        valdir = os.path.join(data_path, 'val')
        # Standard ImageNet statistics, applied after ToTensor's /255 scaling.
        normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

        # Resize the shorter side to 256, then center-crop to scale_size.
        # Fix: the original pipeline had an extra Resize(scale_size) AFTER
        # CenterCrop(224); resizing a 224x224 image to 224 is a no-op, so
        # the redundant transform is removed (and the crop is parameterized).
        testset = datasets.ImageFolder(
            valdir,
            transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(scale_size),
                transforms.ToTensor(),  # PIL HWC uint8 RGB -> CHW float in [0, 1]
                normalize,
            ]))

        self.loader_test = DataLoader(
            testset,
            batch_size=1,
            shuffle=False,
            num_workers=2,
            pin_memory=True)

def test_onnxruntime(ort_session, testLoader, logger, topk=(1,)):
    """Run the ONNX model over the whole loader and log top-k accuracy.

    Args:
        ort_session: an ``onnxruntime.InferenceSession`` for the model.
        testLoader: DataLoader yielding ``(image_tensor, target)`` batches.
        logger: logger that receives the final summary line.
        topk: tuple of k values to evaluate, e.g. ``(1, 5)``.

    Returns:
        ``(top5_avg, top1_avg)``; ``top5_avg`` stays 0 when ``topk`` has a
        single entry.
    """
    accuracy = utils.AverageMeter('Acc@1', ':6.2f')
    top5_accuracy = utils.AverageMeter('Acc@5', ':6.2f')

    # Fix: query the graph's actual input name instead of hard-coding
    # "input1", so any single-input ONNX model works.
    input_name = ort_session.get_inputs()[0].name

    start_time = time.time()
    testLoader = tqdm(testLoader, file=sys.stdout)
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testLoader):
            batch_size = inputs.size(0)
            # onnxruntime consumes numpy arrays; wrap the result back into a
            # tensor so utils.accuracy can reuse the torch post-processing.
            outputs = ort_session.run(None, {input_name: inputs.numpy()})
            outputs = torch.from_numpy(outputs[0])

            predicted = utils.accuracy(outputs, targets, topk=topk)
            accuracy.update(predicted[0], batch_size)
            # Fix: the original indexed predicted[1] unconditionally, which
            # raised IndexError with the default topk=(1,).
            if len(predicted) > 1:
                top5_accuracy.update(predicted[1], batch_size)

        current_time = time.time()
        logger.info(
            'Test Top1 {:.2f}%\tTop5 {:.2f}%\tTime {:.2f}s\n'
                .format(float(accuracy.avg), float(top5_accuracy.avg), (current_time - start_time))
        )

    return top5_accuracy.avg, accuracy.avg

def onnx_inference_imagenet():
    """Evaluate an ONNX ResNet-50 on the ImageNet validation set and log accuracy."""
    job_dir = './experiment'
    # Fix: the original called os.path.join(job_dir + 'logger.log'), string
    # concatenation that wrote './experimentlogger.log' instead of a file
    # inside job_dir. Create the directory so FileHandler can open the log.
    os.makedirs(job_dir, exist_ok=True)
    logger = utils.get_logger(os.path.join(job_dir, 'logger.log'))

    # Data
    print('==> Preparing data..')
    data_path = '/home/users/dataset/imagenet/'
    # data_path = '/data/horizon_j5/data/imagenet/'
    loader = Data(data_path)
    testLoader = loader.loader_test

    onnx_path = "./weights/resnet50/resnet50_pruned.onnx"
    #---------------------------------------------------------#
    #   Run inference with onnxruntime
    #---------------------------------------------------------#
    import onnxruntime
    ort_session = onnxruntime.InferenceSession(onnx_path)
    #---------------------------------------------------------#
    #   Evaluate the whole validation set
    #---------------------------------------------------------#
    test_onnxruntime(ort_session, testLoader, logger, topk=(1, 5))

if __name__ == '__main__':
    onnx_inference_imagenet()

在utils文件夹下,有common.py文件,其中代码如下:

import os
import sys
import shutil
import time, datetime
import logging
import numpy as np
from PIL import Image
from pathlib import Path

import torch
import torch.nn as nn
import torch.utils


'''record configurations'''
class record_config():
    """Write every attribute of ``args`` to ``<job_dir>/config.txt``.

    Appends a timestamped section when ``args.resume`` is truthy so earlier
    runs' configurations are preserved; otherwise overwrites the file.
    """

    def __init__(self, args):
        # Timestamp header for this run (the unused ``today`` variable from
        # the original was removed).
        now = datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S')

        self.args = args
        self.job_dir = Path(args.job_dir)
        # Path.mkdir replaces the hand-rolled _make_dir helper.
        self.job_dir.mkdir(parents=True, exist_ok=True)

        config_dir = self.job_dir / 'config.txt'
        # The two original branches differed only in the open() mode;
        # collapse the duplicated write loop into one.
        mode = 'a' if args.resume else 'w'
        with open(config_dir, mode) as f:
            f.write(now + '\n\n')
            for arg in vars(args):
                f.write('{}: {}\n'.format(arg, getattr(args, arg)))
            f.write('\n')


def get_logger(file_path):
    """Return the shared 'gal' logger writing to ``file_path`` and to stderr.

    Fix: the original added a fresh FileHandler and StreamHandler on every
    call; because ``logging.getLogger('gal')`` returns a singleton, repeated
    calls duplicated every log line. Handlers are now attached only once.
    """
    logger = logging.getLogger('gal')
    if not logger.handlers:
        log_format = '%(asctime)s | %(message)s'
        formatter = logging.Formatter(log_format, datefmt='%m/%d %I:%M:%S %p')

        file_handler = logging.FileHandler(file_path)
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)

        stream_handler = logging.StreamHandler()
        stream_handler.setFormatter(formatter)
        logger.addHandler(stream_handler)

    logger.setLevel(logging.INFO)
    return logger

#label smooth
class CrossEntropyLabelSmooth(nn.Module):
    """Cross-entropy loss with label smoothing.

    The one-hot target distribution is softened: the true class keeps
    ``1 - epsilon`` and every class additionally receives
    ``epsilon / num_classes`` of probability mass.
    """

    def __init__(self, num_classes, epsilon):
        super(CrossEntropyLabelSmooth, self).__init__()
        self.num_classes = num_classes  # number of output classes
        self.epsilon = epsilon          # smoothing strength in [0, 1]
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, inputs, targets):
        """Return the smoothed cross-entropy of ``inputs`` (logits) vs ``targets`` (class indices)."""
        log_probs = self.logsoftmax(inputs)
        one_hot = torch.zeros_like(log_probs)
        one_hot.scatter_(1, targets.unsqueeze(1), 1)
        smoothed = one_hot * (1 - self.epsilon) + self.epsilon / self.num_classes
        # Mean over the batch, sum over classes.
        return (-smoothed * log_probs).mean(0).sum()


class AverageMeter(object):
    """Tracks the most recent value and a running (count-weighted) average."""

    def __init__(self, name, fmt=':f'):
        self.name = name  # label used by __str__
        self.fmt = fmt    # format spec fragment, e.g. ':6.2f'
        self.reset()

    def reset(self):
        """Zero out every statistic."""
        self.val, self.sum, self.count, self.avg = 0, 0, 0, 0

    def update(self, val, n=1):
        """Record ``val`` observed ``n`` times and refresh the average."""
        self.val = val
        self.sum = self.sum + val * n
        self.count = self.count + n
        self.avg = self.sum / self.count

    def __str__(self):
        # Build e.g. '{name} {val:6.2f} ({avg:6.2f})' then fill it in.
        pattern = '{{name}} {{val{0}}} ({{avg{0}}})'.format(self.fmt)
        return pattern.format(**vars(self))


class ProgressMeter(object):
    """Formats and prints a '[batch/total] meter1 meter2 ...' progress line."""

    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters  # iterable of AverageMeter-like objects
        self.prefix = prefix  # text prepended to every line

    def display(self, batch):
        """Print the progress line for batch number ``batch``."""
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        print(' '.join(entries))

    def _get_batch_fmtstr(self, num_batches):
        # Pad the current batch number to the width of the total, e.g.
        # '[{:3d}/250]'. (The original computed ``num_batches // 1``, a
        # no-op integer division; removed.)
        num_digits = len(str(num_batches))
        fmt = '{:' + str(num_digits) + 'd}'
        return '[' + fmt + '/' + fmt.format(num_batches) + ']'


def save_checkpoint(state, is_best, save):
    """Serialize ``state`` to ``save/checkpoint.pth.tar``.

    When ``is_best`` is truthy the checkpoint is additionally duplicated
    as ``save/model_best.pth.tar``.
    """
    if not os.path.exists(save):
        os.makedirs(save)
    ckpt_path = os.path.join(save, 'checkpoint.pth.tar')
    torch.save(state, ckpt_path)
    if is_best:
        shutil.copyfile(ckpt_path, os.path.join(save, 'model_best.pth.tar'))


def adjust_learning_rate(optimizer, epoch, args):
    """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
    decay_steps = epoch // 30
    new_lr = args.lr * (0.1 ** decay_steps)
    for group in optimizer.param_groups:
        group['lr'] = new_lr


def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        # One top-maxk pass covers every requested k; transpose so row k
        # holds the (k+1)-th best prediction per sample.
        pred = output.topk(maxk, 1, True, True)[1].t()
        hits = pred.eq(target.view(1, -1).expand_as(pred))

        results = []
        for k in topk:
            # Count samples whose target appears among the top-k rows.
            num_correct = hits[:k].contiguous().view(-1).float().sum(0, keepdim=True)
            results.append(num_correct.mul_(100.0 / batch_size))
        return results



def progress_bar(current, total, msg=None):
    """Render an in-place console progress bar for batch ``current`` of ``total``.

    NOTE(review): ``last_time`` and ``begin_time`` are function locals reset on
    every call, so ``step_time``/``tot_time`` only measure time spent inside
    this single call (upstream variants keep them as module globals) — confirm
    whether per-step timing was intended.
    Requires a real terminal: ``stty size`` fails when stdout is not a TTY.
    """
    # Query terminal size via stty; the first field (rows) is discarded.
    _, term_width = os.popen('stty size', 'r').read().split()
    term_width = int(term_width)

    TOTAL_BAR_LENGTH = 65.
    last_time = time.time()
    begin_time = last_time

    if current == 0:
        begin_time = time.time()  # Reset for new bar.

    # Split the bar into a filled '=' region and a remaining '.' region.
    cur_len = int(TOTAL_BAR_LENGTH*current/total)
    rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1

    sys.stdout.write(' [')
    for i in range(cur_len):
        sys.stdout.write('=')
    sys.stdout.write('>')
    for i in range(rest_len):
        sys.stdout.write('.')
    sys.stdout.write(']')

    cur_time = time.time()
    step_time = cur_time - last_time
    last_time = cur_time
    tot_time = cur_time - begin_time

    # Assemble the right-hand status text: timings plus the optional message.
    L = []
    L.append('  Step: %s' % format_time(step_time))
    L.append(' | Tot: %s' % format_time(tot_time))
    if msg:
        L.append(' | ' + msg)

    msg = ''.join(L)
    sys.stdout.write(msg)
    # Pad with spaces so leftovers from a longer previous line are erased.
    for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3):
        sys.stdout.write(' ')

    # Go back to the center of the bar.
    for i in range(term_width-int(TOTAL_BAR_LENGTH/2)+2):
        sys.stdout.write('\b')
    sys.stdout.write(' %d/%d ' % (current+1, total))

    # Carriage-return so the next call overwrites this line; newline at the end.
    if current < total-1:
        sys.stdout.write('\r')
    else:
        sys.stdout.write('\n')
    sys.stdout.flush()


def format_time(seconds):
    """Render a duration in seconds as a compact string such as '1h12m'.

    At most the two most significant non-zero units are emitted, drawn in
    order from days (D), hours (h), minutes (m), seconds (s) and
    milliseconds (ms); '0ms' is returned for a sub-millisecond duration.
    """
    days = int(seconds / 3600 / 24)
    remaining = seconds - days * 3600 * 24
    hours = int(remaining / 3600)
    remaining -= hours * 3600
    minutes = int(remaining / 60)
    remaining -= minutes * 60
    whole_secs = int(remaining)
    millis = int((remaining - whole_secs) * 1000)

    parts = []
    for amount, unit in ((days, 'D'), (hours, 'h'), (minutes, 'm'),
                         (whole_secs, 's'), (millis, 'ms')):
        if amount > 0 and len(parts) < 2:
            parts.append('%d%s' % (amount, unit))
    return ''.join(parts) if parts else '0ms'

猜你喜欢

转载自blog.csdn.net/weixin_45377629/article/details/126729376