完善后的train.py:
import os
import time

import torch
from torch import nn

from load_imags import train_loader, train_num, test_loader, test_num
from nets import *
def _choose_network():
    """Ask the user which network to train and build it.

    Keeps prompting until '1' (ResNet18) or '2' (VGG) is entered.

    Returns:
        A ``(net, net_name)`` tuple: the model moved to ``device`` and the
        short name used in the saved checkpoint's filename.
    """
    print('Please choose a network:')
    print('1. ResNet18')
    print('2. VGG')
    while True:
        # Strip so that stray spaces/newlines around the number are accepted.
        net_choose = input('').strip()
        if net_choose == '1':
            net = resnet18_model().to(device)
            print('You have chosen the ResNet18 network, start training.')
            return net, 'ResNet18'
        if net_choose == '2':
            net = vgg_model().to(device)
            print('You have chosen the VGG network, start training.')
            return net, 'VGG-simple'
        print('Please input a correct number!')


def _train_one_epoch(net, loader, loss_func, optimizer, device, total_num, log_every=50):
    """Run one training pass over ``loader``.

    Args:
        net: model to train (switched to train mode here).
        loader: iterable of ``(images, labels)`` batches.
        loss_func: loss taking ``(outputs, labels)``.
        optimizer: optimizer over ``net``'s parameters.
        device: device the batches are moved to.
        total_num: dataset size, only used in progress messages.
        log_every: print batch statistics every this many batches.

    Returns:
        Accuracy over all samples seen in this epoch, in percent.
    """
    net.train()
    trained_num = 0
    total_correct = 0
    for i, (images, labels) in enumerate(loader, start=1):
        images = images.to(device)
        labels = labels.to(device)
        batch_size = images.size(0)
        trained_num += batch_size

        outputs = net(images)
        loss = loss_func(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Plain Python ints avoid accumulating tensors across the epoch.
        correct = outputs.argmax(dim=1).eq(labels).sum().item()
        total_correct += correct
        if i % log_every == 0:
            print('trained: {}/{}'.format(trained_num, total_num))
            print('Loss: {:.4f}, Accuracy: {:.2f}%'.format(loss.item(), 100 * correct / batch_size))
            print('-' * 30)
    # Divide by the samples actually seen (differs from total_num if batches are dropped).
    return 100 * total_correct / max(trained_num, 1)


def _evaluate(net, loader, loss_func, device, total_num, log_every=10):
    """Evaluate ``net`` on ``loader`` without tracking gradients.

    Args:
        net: model to evaluate (switched to eval mode here).
        loader: iterable of ``(images, labels)`` batches.
        loss_func: loss taking ``(outputs, labels)``; only used for logging.
        device: device the batches are moved to.
        total_num: dataset size, only used in progress messages.
        log_every: print batch statistics every this many batches.

    Returns:
        Accuracy over all samples seen, in percent.
    """
    net.eval()
    tested_num = 0
    total_correct = 0
    # no_grad: evaluation needs no autograd graph; building one wastes memory and time.
    with torch.no_grad():
        for i, (images, labels) in enumerate(loader, start=1):
            images = images.to(device)
            labels = labels.to(device)
            batch_size = images.size(0)
            tested_num += batch_size

            outputs = net(images)
            loss = loss_func(outputs, labels)
            correct = outputs.argmax(dim=1).eq(labels).sum().item()
            total_correct += correct
            if i % log_every == 0:
                print('tested: {}/{}'.format(tested_num, total_num))
                print('Loss: {:.4f}, Accuracy: {:.2f}%'.format(loss.item(), 100 * correct / batch_size))
                print('-' * 30)
    return 100 * total_correct / max(tested_num, 1)


def main():
    """Interactively pick a network, then train it for ``num_epoches`` epochs.

    After every epoch the model is evaluated on the test set. At the end its
    ``state_dict`` is saved to ``model_path`` under a timestamped name.

    NOTE(review): ``device``, ``learning_rate``, ``num_epoches``, ``model_path``,
    ``resnet18_model`` and ``vgg_model`` are not defined in this file. They
    presumably come from ``from nets import *`` — TODO confirm.
    """
    net, net_name = _choose_network()

    loss_func = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)
    # Learning-rate decay: multiply the LR by 0.9 every 5 epochs.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.9)

    for epoch in range(num_epoches):
        print('-' * 100)
        print('Epoch {}/{}'.format(epoch + 1, num_epoches))

        begin_time = time.time()
        train_acc = _train_one_epoch(net, train_loader, loss_func, optimizer, device, train_num)
        # StepLR counts epochs, so step once per epoch after the optimizer updates.
        scheduler.step()
        end_time = time.time()
        print('Each train_epoch take time: {} s'.format(end_time - begin_time))
        print('This train_epoch accuracy: {:.2f}%'.format(train_acc))
        print('-' * 60)

        begin_time = time.time()
        test_acc = _evaluate(net, test_loader, loss_func, device, test_num)
        end_time = time.time()
        print('Each test_epoch take time: {} s'.format(end_time - begin_time))
        print('This test_epoch accuracy: {:.2f}%'.format(test_acc))

    # Create the output directory up front; otherwise torch.save would fail
    # only after every epoch has already run.
    if model_path:
        os.makedirs(model_path, exist_ok=True)
    # File name encodes the finish time and the network type.
    save_name = time.strftime("%Y%m%d-%H-%M-", time.localtime()) + net_name + '.pkl'
    torch.save(net.state_dict(), os.path.join(model_path, save_name))
    print('Finished Training')
if __name__ == '__main__':
    # Start the interactive training session only when run as a script, not on import.
    main()
运行截图
C:\Users\DY\.conda\envs\torch\python.exe E:\AI_test\image_classification\lib\train.py
Please choose a network:
1. ResNet18
2. VGG
2
You have chosen the VGG network, start training.
----------------------------------------------------------------------------------------------------
Epoch 1/100
trained: 6400/50000
Loss: 2.3968, Accuracy: 11.72%
------------------------------
trained: 12800/50000
Loss: 2.2981, Accuracy: 15.62%
------------------------------
trained: 19200/50000
Loss: 2.2859, Accuracy: 20.31%
------------------------------
trained: 25600/50000
Loss: 2.0186, Accuracy: 21.09%
------------------------------
trained: 32000/50000
Loss: 2.0359, Accuracy: 19.53%
------------------------------
trained: 38400/50000
Loss: 1.8556, Accuracy: 28.12%
------------------------------
trained: 44800/50000
Loss: 1.9798, Accuracy: 28.91%
------------------------------
Each train_epoch take time: 70.0820574760437 s
This train_epoch accuracy: 20.70%
------------------------------------------------------------
tested: 1280/10000
Loss: 1.6848, Accuracy: 17.19%
------------------------------
tested: 2560/10000
Loss: 2.0882, Accuracy: 10.16%
------------------------------
tested: 3840/10000
Loss: 2.6106, Accuracy: 0.00%
------------------------------
tested: 5120/10000
Loss: 2.2245, Accuracy: 13.28%
------------------------------
tested: 6400/10000
Loss: 1.9835, Accuracy: 31.25%
------------------------------
tested: 7680/10000
Loss: 2.0740, Accuracy: 20.31%
------------------------------
tested: 8960/10000
Loss: 1.6959, Accuracy: 45.31%
------------------------------
Each test_epoch take time: 58.377259969711304 s
This test_epoch accuracy: 28.67%
----------------------------------------------------------------------------------------------------
Epoch 2/100
trained: 6400/50000
Loss: 1.9043, Accuracy: 32.03%
------------------------------
trained: 12800/50000
Loss: 1.8649, Accuracy: 34.38%
------------------------------
trained: 19200/50000
Loss: 1.7617, Accuracy: 32.81%
------------------------------
trained: 25600/50000
Loss: 1.7857, Accuracy: 34.38%
------------------------------
trained: 32000/50000
Loss: 1.8286, Accuracy: 32.81%
------------------------------
trained: 38400/50000
Loss: 1.7217, Accuracy: 35.94%
------------------------------
trained: 44800/50000
Loss: 1.6121, Accuracy: 35.94%
------------------------------
Each train_epoch take time: 64.38722896575928 s
This train_epoch accuracy: 32.62%
------------------------------------------------------------
tested: 1280/10000
Loss: 1.2206, Accuracy: 53.91%
------------------------------
tested: 2560/10000
Loss: 1.7993, Accuracy: 42.97%
------------------------------
tested: 3840/10000
Loss: 2.2702, Accuracy: 16.41%
------------------------------
tested: 5120/10000
Loss: 2.1621, Accuracy: 28.12%
------------------------------
tested: 6400/10000
Loss: 2.2219, Accuracy: 23.44%
------------------------------
tested: 7680/10000
Loss: 2.9674, Accuracy: 6.25%
------------------------------
tested: 8960/10000
Loss: 1.2704, Accuracy: 64.84%
------------------------------
Each test_epoch take time: 58.528542280197144 s
This test_epoch accuracy: 32.62%
----------------------------------------------------------------------------------------------------
Epoch 3/100
trained: 6400/50000
Loss: 1.6951, Accuracy: 39.06%
------------------------------
trained: 12800/50000
Loss: 1.6095, Accuracy: 46.88%
------------------------------
trained: 19200/50000
Loss: 1.7926, Accuracy: 32.03%
------------------------------
trained: 25600/50000
Loss: 1.7182, Accuracy: 33.59%
------------------------------
trained: 32000/50000
Loss: 1.5659, Accuracy: 46.09%
------------------------------
trained: 38400/50000
Loss: 1.6522, Accuracy: 35.94%
------------------------------
trained: 44800/50000
Loss: 1.5192, Accuracy: 46.88%
------------------------------
Each train_epoch take time: 65.94979453086853 s
This train_epoch accuracy: 40.94%
------------------------------------------------------------
tested: 1280/10000
Loss: 1.1895, Accuracy: 46.88%
------------------------------
tested: 2560/10000
Loss: 1.7004, Accuracy: 37.50%
------------------------------
tested: 3840/10000
Loss: 1.7696, Accuracy: 24.22%
------------------------------
tested: 5120/10000
Loss: 1.4304, Accuracy: 57.03%
------------------------------
tested: 6400/10000
Loss: 1.8734, Accuracy: 43.75%
------------------------------
tested: 7680/10000
Loss: 2.4333, Accuracy: 14.06%
------------------------------
tested: 8960/10000
Loss: 1.4078, Accuracy: 59.38%
------------------------------
Each test_epoch take time: 61.93984866142273 s
This test_epoch accuracy: 44.08%
----------------------------------------------------------------------------------------------------
Epoch 4/100
trained: 6400/50000
Loss: 1.7454, Accuracy: 35.94%
------------------------------
trained: 12800/50000
Loss: 1.5639, Accuracy: 41.41%
------------------------------
trained: 19200/50000
Loss: 1.5122, Accuracy: 35.16%
------------------------------
trained: 25600/50000
Loss: 1.4079, Accuracy: 52.34%
------------------------------
trained: 32000/50000
Loss: 1.5722, Accuracy: 43.75%
------------------------------
trained: 38400/50000
Loss: 1.3767, Accuracy: 49.22%
------------------------------
trained: 44800/50000
Loss: 1.3190, Accuracy: 50.00%
------------------------------
Each train_epoch take time: 67.17029404640198 s
This train_epoch accuracy: 46.30%
------------------------------------------------------------
tested: 1280/10000
Loss: 1.1679, Accuracy: 63.28%
------------------------------
tested: 2560/10000
Loss: 1.5420, Accuracy: 41.41%
------------------------------
tested: 3840/10000
Loss: 1.7648, Accuracy: 33.59%
------------------------------
tested: 5120/10000
Loss: 1.6371, Accuracy: 35.94%
------------------------------
tested: 6400/10000
Loss: 0.9881, Accuracy: 68.75%
------------------------------
tested: 7680/10000
Loss: 1.6172, Accuracy: 48.44%
------------------------------
tested: 8960/10000
Loss: 1.0041, Accuracy: 72.66%
------------------------------
Each test_epoch take time: 59.76149272918701 s
This test_epoch accuracy: 49.93%
----------------------------------------------------------------------------------------------------
Epoch 5/100
trained: 6400/50000
Loss: 1.3538, Accuracy: 51.56%
------------------------------
trained: 12800/50000
Loss: 1.5964, Accuracy: 40.62%
------------------------------
trained: 19200/50000
Loss: 1.3894, Accuracy: 47.66%
------------------------------
trained: 25600/50000
Loss: 1.3949, Accuracy: 46.88%
------------------------------
trained: 32000/50000
Loss: 1.6700, Accuracy: 46.88%
------------------------------
trained: 38400/50000
Loss: 1.2755, Accuracy: 59.38%
------------------------------
trained: 44800/50000
Loss: 1.4239, Accuracy: 47.66%
------------------------------
Each train_epoch take time: 66.46004009246826 s
This train_epoch accuracy: 49.61%
------------------------------------------------------------
tested: 1280/10000
Loss: 0.8314, Accuracy: 75.00%
------------------------------
tested: 2560/10000
Loss: 1.7140, Accuracy: 33.59%
------------------------------
tested: 3840/10000
Loss: 1.5265, Accuracy: 25.78%
------------------------------
tested: 5120/10000
Loss: 1.2060, Accuracy: 65.62%
------------------------------
tested: 6400/10000
Loss: 1.2976, Accuracy: 60.16%
------------------------------
tested: 7680/10000
Loss: 2.1083, Accuracy: 25.78%
------------------------------
tested: 8960/10000
Loss: 1.2824, Accuracy: 62.50%
------------------------------
Each test_epoch take time: 58.374892711639404 s
This test_epoch accuracy: 49.11%
----------------------------------------------------------------------------------------------------
Epoch 6/100
可以看出，经过 5 个 epoch 的训练，训练集准确率从 20.70% 提升到 49.61%，测试集准确率从 28.67% 提升到约 49%，模型正在逐步收敛。