利用训练好的模型预测,lib.predict.py:
从训练集和测试集随机各选择100张图片测试
import random
import torchvision.transforms as transforms
from load_imags import train_list, test_list
from script.setting import *
from nets import *
def main():
models_list = os.listdir(model_path) # 获取训练好的网络列表
print('Please choose a network to predict:')
for i, model in enumerate(models_list):
print('input ', i, ':', model)
# 选择网络
while True:
try:
net_choose = int(input(''))
_ = models_list[net_choose] # 获取网络名称
if net_choose < len(models_list):
last_dash_index = _.rfind('-') # 获取最后一个'-'的位置
dot_index = _.find('.') # 获取最后一个'.'的位置
net_name = _[last_dash_index+1:dot_index] # 获取网络名称
print('You have chosen the ' + net_name + ' network, start predicting.')
break
except:
print('Please input a correct number!')
# 加载网络
net = eval(net_name + '()')
net.load_state_dict(torch.load(os.path.join(model_path, models_list[net_choose])))
net.to(device)
net.eval() # 转为测试模式
# 准备图片
pre_imgs = []
pre_imgs.extend(random.sample(train_list, 100)) # 随机选择100张图片作为测试图片
pre_imgs.extend(random.sample(test_list, 100))
# 预测
correct = 0
for img in pre_imgs:
im_label_name = img.split('\\')[-2] # 获取标签名称
# 定义预处理的方法
predict_transform = transforms.Compose([
transforms.ToTensor(), # 将图片转换为Tensor
transforms.Normalize(normalize_mean, # 标准化
normalize_std)
])
image = Image.open(img) # 读取图片
image = predict_transform(image) # 预处理
image = image.unsqueeze(0) # 增加一维,因为输入到网络中是4维的,[batch_size, channel, height, width]
image = image.to(device) # 转为GPU
output = net(image) # 预测
_, predicted = torch.max(output.data, 1) # 获取预测结果
im_label_pred = label_name[predicted.item()]
print('Label of image is:', im_label_name)
print('The predicted label is: ', im_label_pred)
print('-'*50)
if im_label_pred == im_label_name:
correct += 1
while True:
x = input('Input any key to continue: ')
if x:
break
print('The accuracy of the network is: ', 100 * correct/len(pre_imgs), '%.')
if __name__ == '__main__':
main()
运行结果:
Input any key to continue:
Label of image is: deer
The predicted label is: deer
--------------------------------------------------
Input any key to continue:
Label of image is: frog
The predicted label is: frog
--------------------------------------------------
Input any key to continue:
Input any key to continue:
Label of image is: dog
The predicted label is: dog
--------------------------------------------------
Input any key to continue:
Label of image is: cat
The predicted label is: dog
--------------------------------------------------
Input any key to continue:
The accuracy of the network is: 85.5 %.
暂时告一段落,从头到尾做了两遍CIFAR,收获还是有的,算是半只脚踏入了深度学习的门槛,如果有了闲暇,进一步完善,提高识别率和加入前端界面(有了闲暇做事情的flag太多了,基本实现不了 :( )