一个简单的图像分类项目(八)编写脚本:使用训练好的模型预测

 

利用训练好的模型预测,lib.predict.py: 

从训练集和测试集随机各选择100张图片测试

import random
import torchvision.transforms as transforms
from load_imags import train_list, test_list
from script.setting import *
from nets import *


def main():
    models_list = os.listdir(model_path)   # 获取训练好的网络列表
    print('Please choose a network to predict:')
    for i, model in enumerate(models_list):
        print('input ', i, ':', model)
    # 选择网络
    while True:
        try:
            net_choose = int(input(''))
            _ = models_list[net_choose]   # 获取网络名称
            if net_choose < len(models_list):
                last_dash_index = _.rfind('-')  # 获取最后一个'-'的位置
                dot_index = _.find('.')   # 获取最后一个'.'的位置
                net_name = _[last_dash_index+1:dot_index]  # 获取网络名称
                print('You have chosen the ' + net_name + ' network, start predicting.')
                break
        except:
            print('Please input a correct number!')


    # 加载网络
    net = eval(net_name + '()')
    net.load_state_dict(torch.load(os.path.join(model_path, models_list[net_choose])))
    net.to(device)
    net.eval()   # 转为测试模式

    # 准备图片
    pre_imgs = []
    pre_imgs.extend(random.sample(train_list, 100))  # 随机选择100张图片作为测试图片
    pre_imgs.extend(random.sample(test_list, 100))

    # 预测
    correct = 0
    for img in pre_imgs:
        im_label_name = img.split('\\')[-2]   # 获取标签名称
        # 定义预处理的方法
        predict_transform = transforms.Compose([
            transforms.ToTensor(),  # 将图片转换为Tensor
            transforms.Normalize(normalize_mean,  # 标准化
                                 normalize_std)
        ])
        image = Image.open(img)  # 读取图片
        image = predict_transform(image)  # 预处理
        image = image.unsqueeze(0)  # 增加一维,因为输入到网络中是4维的,[batch_size, channel, height, width]
        image = image.to(device)  # 转为GPU
        output = net(image)  # 预测
        _, predicted = torch.max(output.data, 1)  # 获取预测结果

        im_label_pred = label_name[predicted.item()]
        print('Label of image is:', im_label_name)
        print('The predicted label is: ', im_label_pred)
        print('-'*50)
        if im_label_pred == im_label_name:
            correct += 1

        while True:
            x = input('Input any key to continue: ')
            if x:
                break

    print('The accuracy of the network is: ', 100 * correct/len(pre_imgs), '%.')




if __name__ == '__main__':
    main()
    

运行结果:

Input any key to continue:  
Label of image is: deer
The predicted label is:  deer
--------------------------------------------------
Input any key to continue:  
Label of image is: frog
The predicted label is:  frog
--------------------------------------------------
Input any key to continue: 
Input any key to continue:  
Label of image is: dog
The predicted label is:  dog
--------------------------------------------------
Input any key to continue:  
Label of image is: cat
The predicted label is:  dog
--------------------------------------------------
Input any key to continue:  
The accuracy of the network is:  85.5 %.

         暂时告一段落,从头到尾做了两遍CIFAR,收获还是有的,算是半只脚踏入了深度学习的门槛,如果有了闲暇,进一步完善,提高识别率和加入前端界面(有了闲暇做事情的flag太多了,基本实现不了 :(   )