A Simple Image Classification Project (Part 7) Writing the Scripts: The Complete Training Script

The completed train.py:

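import os
import time

import torch
import torch.nn as nn

from load_imags import train_loader, train_num, test_loader, test_num
from nets import *  # project module; assumed to also provide device, learning_rate, num_epoches and model_path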


def main():
    # Ask which network to train
    print('Please choose a network:')
    print('1. ResNet18')
    print('2. VGG')

    # Keep asking until a valid choice is entered
    while True:
        net_choose = input('')
        if net_choose == '1':
            net = resnet18_model().to(device)
            net_name = 'ResNet18'
            print('You have chosen the ResNet18 network, start training.')
            break
        elif net_choose == '2':
            net = vgg_model().to(device)
            net_name = 'VGG-simple'
            print('You have chosen the VGG network, start training.')
            break
        else:
            print('Please input a correct number!')

    # Loss function, optimizer and learning-rate schedule
    loss_func = nn.CrossEntropyLoss()  # cross-entropy loss
    optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)  # Adam optimizer
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=5,
                                                gamma=0.9)  # LR decay: multiply the learning rate by 0.9 every 5 epochs

    # Train the model
    for epoch in range(num_epoches):
        trained_num = 0  # number of training images seen in this epoch
        total_correct = 0  # number of correct predictions in this epoch
        print('-' * 100)
        print('Epoch {}/{}'.format(epoch + 1, num_epoches))
        begin_time = time.time()  # start time of the training pass
        net.train()  # training mode (affects layers such as Dropout and BatchNorm)
        for i, (images, labels) in enumerate(train_loader):
            images = images.to(device)  # one batch of images
            labels = labels.to(device)  # labels of this batch
            trained_num += images.size(0)  # count the images in this batch
            outputs = net(images)  # forward pass
            loss = loss_func(outputs, labels)  # compute the loss
            optimizer.zero_grad()  # clear old gradients
            loss.backward()  # backpropagation
            optimizer.step()  # update the parameters

            _, predicted = torch.max(outputs, 1)  # predicted class = index of the largest output
            correct = predicted.eq(labels).sum().item()  # correct predictions in this batch (Python int)
            total_correct += correct  # accumulate over the epoch
            if (i + 1) % 50 == 0:  # print every 50 batches
                print('trained: {}/{}'.format(trained_num, train_num))
                print('Loss: {:.4f}, Accuracy: {:.2f}%'.format(loss.item(), 100 * correct / images.size(0)))
                print('-' * 30)

        # Step the scheduler once per epoch; StepLR applies the decay every 5 epochs
        scheduler.step()
        end_time = time.time()  # end time of the training pass
        print('Each train_epoch take time: {} s'.format(end_time - begin_time))
        print('This train_epoch accuracy: {:.2f}%'.format(100 * total_correct / train_num))
        print('-' * 60)

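        tested_num = 0  # number of test images evaluated in this epoch
        total_correct = 0  # number of correct predictions on the test set
        begin_time = time.time()  # start time of the test pass
        net.eval()  # evaluation mode (affects layers such as Dropout and BatchNorm)
        with torch.no_grad():  # no gradients are needed for evaluation, which saves memory and time
            for i, (images, labels) in enumerate(test_loader):
                images = images.to(device)  # one batch of images
                labels = labels.to(device)  # labels of this batch
                tested_num += images.size(0)  # count the images in this batch
                outputs = net(images)  # forward pass
                loss = loss_func(outputs, labels)  # compute the loss

                _, predicted = torch.max(outputs, 1)  # predicted class = index of the largest output
                correct = predicted.eq(labels).sum().item()  # correct predictions in this batch (Python int)
                total_correct += correct  # accumulate over the test set
                if (i + 1) % 10 == 0:  # print every 10 batches
                    print('tested: {}/{}'.format(tested_num, test_num))
                    print('Loss: {:.4f}, Accuracy: {:.2f}%'.format(loss.item(), 100 * correct / images.size(0)))
                    print('-' * 30)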

        end_time = time.time()  # end time of the test pass
        print('Each test_epoch take time: {} s'.format(end_time - begin_time))
        print('This test_epoch accuracy: {:.2f}%'.format(100 * total_correct / test_num))

    # Save the model
    torch.save(net.state_dict(),
               os.path.join(model_path,
                            time.strftime("%Y%m%d-%H-%M-", time.localtime()) +
                            net_name + '.pkl'))  # file name: end time + network name
    print('Finished Training')

if __name__ == '__main__':
    main()
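In the script, `scheduler.step()` is called once at the end of every epoch. With `StepLR(step_size=5, gamma=0.9)`, the learning rate used during epoch e (counting from 0) is therefore `learning_rate * 0.9 ** (e // 5)`. The short pure-Python sketch below evaluates that closed form to show how far the rate decays over the 100 epochs of the run. `step_lr` is a hypothetical helper. The base rate of 1e-3 is an assumed value, because the post does not show what `learning_rate` is set to.

```python
def step_lr(base_lr: float, epoch: int, step_size: int = 5, gamma: float = 0.9) -> float:
    """Learning rate in effect during `epoch` (0-based) under StepLR,
    assuming scheduler.step() is called exactly once at the end of every epoch."""
    return base_lr * gamma ** (epoch // step_size)


if __name__ == "__main__":
    base_lr = 1e-3  # assumed value; the real script imports learning_rate from another module
    num_epoches = 100  # matches the "Epoch x/100" lines in the run output
    for epoch in (0, 4, 5, 10, num_epoches - 1):
        lr = step_lr(base_lr, epoch)
        print(f"epoch {epoch + 1:3d}: lr = {lr:.6g} ({lr / base_lr:.2%} of the initial rate)")
```

Over 100 epochs the rate is reduced 19 times, so the last epoch trains at roughly 13.5% of the initial learning rate (0.9 ** 19 ≈ 0.135). The decay in this configuration is fairly gentle.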

Sample run output (VGG network, first 5 of 100 epochs shown):

C:\Users\DY\.conda\envs\torch\python.exe E:\AI_test\image_classification\lib\train.py 
Please choose a network:
1. ResNet18
2. VGG
2
You have chosen the VGG network, start training.
----------------------------------------------------------------------------------------------------
Epoch 1/100
trained: 6400/50000
Loss: 2.3968, Accuracy: 11.72%
------------------------------
trained: 12800/50000
Loss: 2.2981, Accuracy: 15.62%
------------------------------
trained: 19200/50000
Loss: 2.2859, Accuracy: 20.31%
------------------------------
trained: 25600/50000
Loss: 2.0186, Accuracy: 21.09%
------------------------------
trained: 32000/50000
Loss: 2.0359, Accuracy: 19.53%
------------------------------
trained: 38400/50000
Loss: 1.8556, Accuracy: 28.12%
------------------------------
trained: 44800/50000
Loss: 1.9798, Accuracy: 28.91%
------------------------------
Each train_epoch take time: 70.0820574760437 s
This train_epoch accuracy: 20.70%
------------------------------------------------------------
tested: 1280/10000
Loss: 1.6848, Accuracy: 17.19%
------------------------------
tested: 2560/10000
Loss: 2.0882, Accuracy: 10.16%
------------------------------
tested: 3840/10000
Loss: 2.6106, Accuracy: 0.00%
------------------------------
tested: 5120/10000
Loss: 2.2245, Accuracy: 13.28%
------------------------------
tested: 6400/10000
Loss: 1.9835, Accuracy: 31.25%
------------------------------
tested: 7680/10000
Loss: 2.0740, Accuracy: 20.31%
------------------------------
tested: 8960/10000
Loss: 1.6959, Accuracy: 45.31%
------------------------------
Each test_epoch take time: 58.377259969711304 s
This test_epoch accuracy: 28.67%
----------------------------------------------------------------------------------------------------
Epoch 2/100
trained: 6400/50000
Loss: 1.9043, Accuracy: 32.03%
------------------------------
trained: 12800/50000
Loss: 1.8649, Accuracy: 34.38%
------------------------------
trained: 19200/50000
Loss: 1.7617, Accuracy: 32.81%
------------------------------
trained: 25600/50000
Loss: 1.7857, Accuracy: 34.38%
------------------------------
trained: 32000/50000
Loss: 1.8286, Accuracy: 32.81%
------------------------------
trained: 38400/50000
Loss: 1.7217, Accuracy: 35.94%
------------------------------
trained: 44800/50000
Loss: 1.6121, Accuracy: 35.94%
------------------------------
Each train_epoch take time: 64.38722896575928 s
This train_epoch accuracy: 32.62%
------------------------------------------------------------
tested: 1280/10000
Loss: 1.2206, Accuracy: 53.91%
------------------------------
tested: 2560/10000
Loss: 1.7993, Accuracy: 42.97%
------------------------------
tested: 3840/10000
Loss: 2.2702, Accuracy: 16.41%
------------------------------
tested: 5120/10000
Loss: 2.1621, Accuracy: 28.12%
------------------------------
tested: 6400/10000
Loss: 2.2219, Accuracy: 23.44%
------------------------------
tested: 7680/10000
Loss: 2.9674, Accuracy: 6.25%
------------------------------
tested: 8960/10000
Loss: 1.2704, Accuracy: 64.84%
------------------------------
Each test_epoch take time: 58.528542280197144 s
This test_epoch accuracy: 32.62%
----------------------------------------------------------------------------------------------------
Epoch 3/100
trained: 6400/50000
Loss: 1.6951, Accuracy: 39.06%
------------------------------
trained: 12800/50000
Loss: 1.6095, Accuracy: 46.88%
------------------------------
trained: 19200/50000
Loss: 1.7926, Accuracy: 32.03%
------------------------------
trained: 25600/50000
Loss: 1.7182, Accuracy: 33.59%
------------------------------
trained: 32000/50000
Loss: 1.5659, Accuracy: 46.09%
------------------------------
trained: 38400/50000
Loss: 1.6522, Accuracy: 35.94%
------------------------------
trained: 44800/50000
Loss: 1.5192, Accuracy: 46.88%
------------------------------
Each train_epoch take time: 65.94979453086853 s
This train_epoch accuracy: 40.94%
------------------------------------------------------------
tested: 1280/10000
Loss: 1.1895, Accuracy: 46.88%
------------------------------
tested: 2560/10000
Loss: 1.7004, Accuracy: 37.50%
------------------------------
tested: 3840/10000
Loss: 1.7696, Accuracy: 24.22%
------------------------------
tested: 5120/10000
Loss: 1.4304, Accuracy: 57.03%
------------------------------
tested: 6400/10000
Loss: 1.8734, Accuracy: 43.75%
------------------------------
tested: 7680/10000
Loss: 2.4333, Accuracy: 14.06%
------------------------------
tested: 8960/10000
Loss: 1.4078, Accuracy: 59.38%
------------------------------
Each test_epoch take time: 61.93984866142273 s
This test_epoch accuracy: 44.08%
----------------------------------------------------------------------------------------------------
Epoch 4/100
trained: 6400/50000
Loss: 1.7454, Accuracy: 35.94%
------------------------------
trained: 12800/50000
Loss: 1.5639, Accuracy: 41.41%
------------------------------
trained: 19200/50000
Loss: 1.5122, Accuracy: 35.16%
------------------------------
trained: 25600/50000
Loss: 1.4079, Accuracy: 52.34%
------------------------------
trained: 32000/50000
Loss: 1.5722, Accuracy: 43.75%
------------------------------
trained: 38400/50000
Loss: 1.3767, Accuracy: 49.22%
------------------------------
trained: 44800/50000
Loss: 1.3190, Accuracy: 50.00%
------------------------------
Each train_epoch take time: 67.17029404640198 s
This train_epoch accuracy: 46.30%
------------------------------------------------------------
tested: 1280/10000
Loss: 1.1679, Accuracy: 63.28%
------------------------------
tested: 2560/10000
Loss: 1.5420, Accuracy: 41.41%
------------------------------
tested: 3840/10000
Loss: 1.7648, Accuracy: 33.59%
------------------------------
tested: 5120/10000
Loss: 1.6371, Accuracy: 35.94%
------------------------------
tested: 6400/10000
Loss: 0.9881, Accuracy: 68.75%
------------------------------
tested: 7680/10000
Loss: 1.6172, Accuracy: 48.44%
------------------------------
tested: 8960/10000
Loss: 1.0041, Accuracy: 72.66%
------------------------------
Each test_epoch take time: 59.76149272918701 s
This test_epoch accuracy: 49.93%
----------------------------------------------------------------------------------------------------
Epoch 5/100
trained: 6400/50000
Loss: 1.3538, Accuracy: 51.56%
------------------------------
trained: 12800/50000
Loss: 1.5964, Accuracy: 40.62%
------------------------------
trained: 19200/50000
Loss: 1.3894, Accuracy: 47.66%
------------------------------
trained: 25600/50000
Loss: 1.3949, Accuracy: 46.88%
------------------------------
trained: 32000/50000
Loss: 1.6700, Accuracy: 46.88%
------------------------------
trained: 38400/50000
Loss: 1.2755, Accuracy: 59.38%
------------------------------
trained: 44800/50000
Loss: 1.4239, Accuracy: 47.66%
------------------------------
Each train_epoch take time: 66.46004009246826 s
This train_epoch accuracy: 49.61%
------------------------------------------------------------
tested: 1280/10000
Loss: 0.8314, Accuracy: 75.00%
------------------------------
tested: 2560/10000
Loss: 1.7140, Accuracy: 33.59%
------------------------------
tested: 3840/10000
Loss: 1.5265, Accuracy: 25.78%
------------------------------
tested: 5120/10000
Loss: 1.2060, Accuracy: 65.62%
------------------------------
tested: 6400/10000
Loss: 1.2976, Accuracy: 60.16%
------------------------------
tested: 7680/10000
Loss: 2.1083, Accuracy: 25.78%
------------------------------
tested: 8960/10000
Loss: 1.2824, Accuracy: 62.50%
------------------------------
Each test_epoch take time: 58.374892711639404 s
This test_epoch accuracy: 49.11%
----------------------------------------------------------------------------------------------------
Epoch 6/100

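As the output shows, the model converges gradually over the first 5 epochs. Epoch-level training accuracy rose from 20.70% to 49.61%, and test accuracy rose from 28.67% to 49.11%. The per-batch "Accuracy" lines fluctuate much more, from 0.00% to 75.00% on the test set, because each one is measured on a single batch of 128 images (6400 images after 50 batches). The end-of-epoch figures are the ones to watch.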
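The end-of-epoch accuracies in the output above can be lined up to check this trend. The sketch below copies those numbers from the log and prints the change from one epoch to the next. `summarize` is a hypothetical helper written for illustration and is not part of the project.

```python
# Epoch-level accuracies (%) copied from the run output above (epochs 1-5).
train_acc = [20.70, 32.62, 40.94, 46.30, 49.61]
test_acc = [28.67, 32.62, 44.08, 49.93, 49.11]


def summarize(history: list[float]) -> list[float]:
    """Return the change in accuracy from each epoch to the next, in percentage points."""
    return [round(b - a, 2) for a, b in zip(history, history[1:])]


if __name__ == "__main__":
    print("train gains:", summarize(train_acc))
    print("test gains: ", summarize(test_acc))
    best_epoch = max(range(len(test_acc)), key=test_acc.__getitem__) + 1
    print(f"best test accuracy so far: {max(test_acc):.2f}% (epoch {best_epoch})")
```

With these numbers, the training gains shrink from about 11.9 to 3.3 points per epoch. Test accuracy slips slightly, from 49.93% in epoch 4 to 49.11% in epoch 5. Five epochs are enough to see the trend, but not to say where accuracy will level off.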


Reposted from blog.csdn.net/xulibo5828/article/details/143287174