【深度学习】使用PyTorch实现图像分类+所有代码+详细注释


使用PyTorch实现图像分类

本文将介绍如何使用PyTorch实现利用神经网络在图像数据集上进行训练和如何利用训练好的模型对图像进行分类


创建文件夹,用于保存训练好的网络

import os
if not os.path.exists("./save_model_rs_dataset"):
    os.mkdir("./save_model_rs_dataset")
复制代码

1. 定义模型


1.1 一个小的神经网络

在这里插入图片描述

import torch
from torch import nn
class MyNet(nn.Module):

    def __init__(self, num_classes=10) -> None:
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, 32, 5, padding=2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 32, 5, padding=2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 5, padding=2),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(1024, 64),
            nn.Linear(64, class_nums),
            nn.Softmax(dim=1)
        )

    def forward(self, x):
        x = self.model(x)
        return x
复制代码

1.2 AlxeNet网络结构

在这里插入图片描述

import torch
import torch.nn as nn
class MyNet(nn.Module):
    def __init__(self, num_classes=10):
        super(MyNet, self).__init__()
        self.feature_extraction = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=96, kernel_size=11, stride=4, padding=2, bias=False),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=0),
            nn.Conv2d(in_channels=96, out_channels=192, kernel_size=5, stride=1, padding=2, bias=False),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=0),
            nn.Conv2d(in_channels=192, out_channels=384, kernel_size=3, stride=1, padding=1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=384, out_channels=256, kernel_size=3, stride=1, padding=1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1, bias=False),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=0),
        )

        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(in_features=256 * 6 * 6, out_features=4096),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(in_features=4096, out_features=4096),
            nn.ReLU(inplace=True),
            nn.Linear(in_features=4096, out_features=num_classes),
        )

    def forward(self, x):
        x = self.feature_extraction(x)
        x = x.view(x.size(0), 256 * 6 * 6)
        x = self.classifier(x)
        return x
复制代码

1.3 VGG16网络结构

在这里插入图片描述

# 作者 : 冷芝士鸭
import torch
from torch import nn


class MyNet(nn.Module):
    def __init__(self, num_classes):
        super(MyNet, self).__init__()
        self.block1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        )
        self.block2 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1)
        )
        self.block3 = nn.Sequential(
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1),
        )
        self.block4 = nn.Sequential(
            nn.Conv2d(in_channels=256, out_channels=512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1)
        )
        self.block5 = nn.Sequential(
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1),
        )
        self.block6 = nn.Sequential(
            nn.Flatten(),
            # 使用自适应池化
            
            nn.Linear(in_features=512 * 7 * 7, out_features=4096),
            nn.ReLU(),
            nn.Dropout(p=0.5, inplace=False),
            nn.Linear(in_features=4096, out_features=4096),
            nn.ReLU(),
            nn.Dropout(p=0.5, inplace=False),
            nn.Linear(in_features=4096, out_features=num_classes),
        )

    def forward(self, input):
        output = self.block1(input)
        output = self.block2(output)
        output = self.block3(output)
        output = self.block4(output)
        output = self.block5(output)
        output = self.block6(output)
        return output
复制代码

2. 加载数据集

import torchvision.datasets
import numpy as np
from torchvision import datasets
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.optim import lr_scheduler
from torchvision.transforms import transforms


data_transform = transforms.Compose([
    transforms.Resize([224, 224]),    # 缩放图像大小为 224*224,第一个网络需要的输入尺寸是32*32
    transforms.ToTensor()     # 仅对数据做转换为 tensor 格式操作
])

# 每次取多少张图象进行训练
Batch_size = 128

# 使用自己的数据集
train_dataset = datasets.ImageFolder(root='../input/satellite-image-classification/train',transform=data_transform)
# 使用官方数据集
# train_dataset = torchvision.datasets.CIFAR10("dataset", train=True, transform=data_transform, download=True)
train_dataloader = DataLoader(dataset=train_dataset,batch_size=Batch_size,shuffle=True,num_workers=2)

test_dataset = datasets.ImageFolder(root='../input/satellite-image-classification/test',transform=data_transform)
# test_dataset = torchvision.datasets.CIFAR10("dataset", train=False, transform=data_transform, download=True)
test_dataloader = DataLoader(dataset=test_dataset,batch_size=Batch_size,shuffle=True,num_workers=2)

# 长度 = 数据集个数 / batch_size
# print(len(train_dataloader))

# 获取数据集类别数量
classes = test_dataset.classes

# 初始化混淆矩阵
cnf_matrix = np.zeros([len(classes), len(classes)])
复制代码

==说明:自己的数据集结构应该和下面一致(val可以不用),每个文件夹下是各个类别的图像,文件夹名即为类别==

在这里插入图片描述 数据集问题可参考7.2 将图像数据划分为训练集、测试集、验证集


图中数据集可以自行下载:Satellite Image Classification数据集


设置设备

# 如果GPU可用,利用GPU进行训练
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
复制代码

创建网络

# 实例化网络
net = MyNet(num_classes=len(classes)).to(device)
复制代码

3. 定义训练参数

from torch.optim import lr_scheduler

# 4. 损失函数
loss_fn = nn.CrossEntropyLoss()


# 学习率
learning_rate = 0.001
# 5. 优化器
# 定义优化器(SGD:随机梯度下降)
# optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate)
optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)

# 学习率衰减⽅法:学习率每隔 step_size 个 epoch 变为原来的 gamma
lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=15, gamma=0.1)



# 训练轮数
epoch = 100

# 保存训练过程中的loss和精度
train_acc_lst, test_acc_lst = [], []
train_loss_lst, tset_loss_lst = [], []

# 记录训练过程中最大的精度
max_train_acc = 0
max_test_acc = 0
复制代码

通道转换函数

import numpy as np
# 单通道转为三通道
def transfer_channel(image):
    image = np.array(image)
    image = image.transpose((1, 0, 2, 3))             # array 转置
    image = np.concatenate((image, image, image), axis=0)
    image = image.transpose((1, 0, 2, 3))     # array 转置回来
    image = torch.tensor(image)               # 将 numpy 数据格式转为 tensor
    return image
复制代码

计算精度和loss函数

def compute_accuracy_and_loss(model, dataset, data_loader, device):
    correct, total = .0, .0
    for i, (features, targets) in enumerate(data_loader):
        # 通道转换
        if features.size(1) == 1:
            features = transfer_channel(features)
        features = features.to(device)
        targets = targets.to(device)
        output = model(features)
        currnet_loss = loss_fn(output, targets)
        # 求预测结果精确度之和
        # argmax:求最大值的下标,1按行求,0按列求
#         correct += (output.argmax(1) == targets).sum()
        
        _, predicted_labels = torch.max(output, 1)
        correct += (predicted_labels == targets).sum()
        
        # 更新混淆矩阵数据
        for idx in range(len(targets)):
            cnf_matrix[targets[idx]][predicted_labels[idx]] += 1
        
        total += targets.size(0)
        
    return float(correct) * 100 / len(dataset), currnet_loss.item()
复制代码

4. 训练

import time
start_time = time.time()

print(net)

for i in range(epoch):
    print("---------开始第{}轮训练,本轮学习率为:{}---------".format((i + 1), lr_scheduler.get_last_lr()))
    # 记录每轮训练批次数,每100次进行一次输出
    count_train = 0
    
    # 训练步骤开始
    net.train() # 将网络设置为训练模式,当网络包含 Dropout, BatchNorm时必须设置,其他时候无所谓
    for (features, targets) in train_dataloader:
        # 通道转换
        if features.size(1) == 1:
            features = transfer_channel(features)
        # 将图像和标签移动到指定设备上
        features = features.to(device)
        targets = targets.to(device)
        
        # 梯度清零,也就是把loss关于weight的导数变成0.
        # 进⾏下⼀次batch梯度计算的时候,前⼀个batch的梯度计算结果,没有保留的必要了。所以在下⼀次梯度更新的时候,先使⽤optimizer.zero_grad把梯度信息设置为0。
        optimizer.zero_grad()
        
        # 获取网络输出
        output = net(features)
        
        # 获取损失
        loss = loss_fn(output, targets)
        
        # 反向传播
        loss.backward()
        # 训练
        optimizer.step()
        # 纪录训练次数
        count_train += 1
        # item()函数会直接输出值,比如tensor(5),会输出5
        if count_train % 100 == 0:
            # 记录时间
            end_time = time.time()
            print(f"训练批次{count_train}/{len(train_dataloader)},loss:{loss.item():.3f},用时:{(end_time - start_time):.2f}" )

    # 将网络设置为测试模式,当网络包含 Dropout, BatchNorm时必须设置,其他时候无所谓
    net.eval()
    with torch.no_grad():
        # 计算训练精度
        train_accuracy, train_loss = compute_accuracy_and_loss(net, train_dataset, train_dataloader, device=device)
        # 更新最高精度
        if train_accuracy > max_train_acc[1]:
            max_train_acc[0] = i
            max_train_acc[1] = train_accuracy
        
        # 计算测试精度
        test_accuracy, test_loss = compute_accuracy_and_loss(net, test_dataset, test_dataloader, device=device)
        # 更新最高精度
        if test_accuracy > max_test_acc[1]:
            max_test_acc[0] = i
            max_test_acc[1] = test_accuracy
        
        # 收集训练过程精度和loss
        train_loss_lst.append(train_loss)
        train_acc_lst.append(train_accuracy)
        tset_loss_lst.append(test_loss)
        test_acc_lst.append(test_accuracy)
        
        print(f'Epoch: {i + 1:03d}/{epoch:03d}')
        print(f'Train Loss.: {train_loss:.2f}' f' | Validation Loss.: {test_loss:.2f}')
        print(f'Train Acc.: {train_accuracy:.2f}%' f' | Validation Acc.: {test_accuracy:.2f}%')

    # 训练计时
    elapsed = (time.time() - start_time) / 60
    print(f'本轮训练累计用时: {elapsed:.2f} min')

    # 保存达标的训练的模型
    if test_accuracy > 80:
        torch.save(net.state_dict(), "save_model_rs_dataset/train_model_{}.pth".format(i))
        print("第{}次训练模型已保存".format(i + 1))
    
    # 更新学习率
    lr_scheduler.step()

print('DONE!')
复制代码

部分内容参考自:VGG16识别MNIST数据集(Pytorch实战)


输出(以下均以AlexNet为例)

# 网络结构
MyNet(
  (feature_extraction): Sequential(
    (0): Conv2d(3, 96, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2), bias=False)
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(96, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), bias=False)
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (7): ReLU(inplace=True)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (9): ReLU(inplace=True)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (11): ReLU(inplace=True)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (classifier): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=9216, out_features=4096, bias=True)
    (2): ReLU(inplace=True)
    (3): Dropout(p=0.5, inplace=False)
    (4): Linear(in_features=4096, out_features=4096, bias=True)
    (5): ReLU(inplace=True)
    (6): Linear(in_features=4096, out_features=4, bias=True)
  )
)
复制代码

训练过程输出

---------开始第1轮训练,本轮学习率为:[0.001]---------
Epoch: 001/050
Train Loss.: 0.64 | Validation Loss.: 0.60
Train Acc.: 62.09% | Validation Acc.: 63.93%
本轮训练累计用时: 0.61 min
---------开始第2轮训练,本轮学习率为:[0.001]---------
Epoch: 002/050
Train Loss.: 0.76 | Validation Loss.: 0.64
Train Acc.: 66.24% | Validation Acc.: 66.79%
本轮训练累计用时: 1.03 min
---------开始第3轮训练,本轮学习率为:[0.001]---------
Epoch: 003/050
Train Loss.: 0.63 | Validation Loss.: 0.68
Train Acc.: 57.81% | Validation Acc.: 60.71%
本轮训练累计用时: 1.44 min
......
复制代码

5. 显示Loss和Acc

5.1 使用plot

import matplotlib.pyplot as plt


plt.figure(dpi=480,figsize=(12,5))

# 训练损失和测试损失关系图
plt.plot(range(1, epoch + 1), train_loss_lst, label='Training loss')
plt.plot(range(1, epoch + 1), tset_loss_lst, label='Validation loss')
plt.legend(loc='upper right')
plt.ylabel('Cross entropy')
plt.xlabel('Epoch')
plt.show()


plt.figure(dpi=480,figsize=(12,5))
# 训练精度和测试精度关系图
plt.plot(range(1, epoch + 1), train_acc_lst, label='Training accuracy')
plt.plot(range(1, epoch + 1), test_acc_lst, label='Validation accuracy')
plt.legend(loc='upper left')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.show()


print("最大训练精度为:", max_train_acc)
print("最大测试精度为:", max_test_acc)
复制代码

部分内容参考自:VGG16识别MNIST数据集(Pytorch实战)

在这里插入图片描述

在这里插入图片描述

最大训练精度: [48, 87.82165039929015] 最大测试精度: [28, 89.64285714285714]


5.2 使用混淆矩阵

import itertools
import matplotlib.pyplot as plt
import numpy as np


# 绘制混淆矩阵
def plot_confusion_matrix(cm, classes, normalize=False, title='Confusion matrix', cmap=plt.cm.Blues):
    """
    - cm : 计算出的混淆矩阵的值
    - classes : 混淆矩阵中每一行每一列对应的列
    - normalize : True:显示百分比, False:显示个数
    """
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
#         print("显示百分比:")
        np.set_printoptions(formatter={'float': '{: 0.2f}'.format})
#         print(cm)
#     else:
#         print('显示具体数字:')
#         print(cm)
    plt.figure(dpi=320,figsize=(16,16))
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)
    # matplotlib版本问题,如果不加下面这行代码,则绘制的混淆矩阵上下只能显示一半,有的版本的matplotlib不需要下面的代码,分别试一下即可
    plt.ylim(len(classes) - 0.5, -0.5)
    # fmt = '.2f' if normalize else 'd'
    fmt = '.2f'
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt), horizontalalignment="center", color="white" if cm[i, j] > thresh else "black")
    
    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.show()


# 第一种情况:显示百分比
plot_confusion_matrix(cnf_matrix, classes=classes, normalize=True, title='Normalized confusion matrix')

# 第二种情况:显示数字
plot_confusion_matrix(cnf_matrix, classes=classes, normalize=False, title='Normalized confusion matrix')
复制代码

参考自:Matplotlib绘制混淆矩阵

输出

在这里插入图片描述 在这里插入图片描述
百分比形式 数字形式

6. 验证训练的模型

加载上述训练过程中效果较好的一个网络进行验证

# 时间 : 2022/5/14 19:59
# 作者 : 冷芝士鸭
from PIL import features
from torch.utils.data import DataLoader

import torch
import torchvision
from torchvision import datasets

from torchvision.transforms import transforms

import matplotlib.pyplot as plt

# 对图像进行尺寸变换,因为网络要求的输入是64*64,并且是tensor类型
custom_transform = transforms.Compose([transforms.Resize([224, 224]),
                                       transforms.ToTensor()])

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = torchvision.models.vgg16().to(device)
# map_location:指定设备,cpu或者GPU
model.load_state_dict(torch.load("./save_model_rs_dataset/vgg16_train_model_38.pth", map_location="cpu"))

val_dataset = datasets.ImageFolder(
    root=r'E:\machine learning\Deep_learning\deep_learning\PyTorch\code\some_models\vgg-demo\VGG16\satelite\Satellite_Image_Classification\val',
    transform=custom_transform
)
classes = val_dataset.classes
val_loader = DataLoader(dataset=val_dataset,
                        batch_size=16,
                        shuffle=True)

for features, targets in val_loader:
    predictions = model.forward(features.to(device))
    predictions = torch.argmax(predictions, dim=1)
    plt.figure(figsize=(15, 15))  # 设置窗口大小

    for i in range(len(features)):
        plt.subplot(4, 4, i + 1)
        plt.title("Prediction:{}\nTarget:{}".format(classes[predictions[i]], classes[targets[i]]))
        # 解决报错:Invalid shape (3, 224, 224) for image data
        # 问题产生的原因是由于matplotlib.pyplot 使用时传入的数组型或Tensor型参数应为 img=(224,224,3)这种类型。
        # 其中img[0],img[1]为数组或张量的长与宽,img[2]为维度,如‘RPG’为3
        img = features[i].swapaxes(0, 1)
        img = img.swapaxes(1, 2)
        plt.imshow(img)
        # 关闭坐标轴
        plt.axis('off')

    plt.show()
    break
复制代码

验证结果 在这里插入图片描述


7. 问题与解决

7.1 图像尺寸问题

引自:CNN02:Pytorch实现VGG16的CIFAR10分类

一直以来进入了一个误区,一直以为数据图像的大小要匹配/适应网络的输入大小。在LeNet中,网络输入大小为32x32,而MNIST数据集中的图像大小为28x28,当时认为要使两者的大小匹配,将padding设置为2即解决了这个问题。然而,当用VGG训练CIFAR10数据集时,网络输入大小为224x224,而数据大小是32x32,这两者该怎么匹配呢?试过将32用padding的方法填充到224x224,但是运行之后显示内存不足 (笑哭.jpg)。也百度到将数据图像resize成224x224。

这个问题一直困扰了好久,看着代码里没有改动数据尺寸和网络的尺寸,不知道是怎么解决的这个匹配/适应的问题。最后一步步调试才发现在第一个全连接处报错,全连接的输入尺寸和设定的尺寸不一致,再回过头去一步步推数据的尺寸变化,发现原来的VGG网络输入是224x224的,由于卷积层不改变图像的大小,只有池化层才使图像大小缩小一半,所以经过5层卷积池化之后,图像大小缩小为原来的1/32。卷积层的最终输出是7x7x512=25088,所以全连接层的输入设为25088。

当输入图像大小为32x32时,经过5层卷积之后,图像大小缩小为1x1x512,全连接的输入大小就变为了512,所以不匹配的地方在这里,而不是网络的输入处。所以输入的训练图像的大小不必要与网络原始的输入大小一致,只需要计算经过卷积池化后最终的输出(也即全连接层的输入),然后改以下全连接的输入即可。


7.2 将图像数据划分为训练集、测试集、验证集

现有数据集如下图,但是没有划分为训练集和测试集,使用下面代码可以进行数据集划分 在这里插入图片描述

dataset
├─cloudy
├─desert
├─green_area
└─water
复制代码
import os
import random
import shutil
from shutil import copy2


def data_set_split(src_data_folder, target_data_folder, train_scale=0.8, val_scale=0.1, test_scale=0.1):
   '''
   读取源数据文件夹,生成划分好的文件夹,分为trian、val、test三个文件夹进行
   :param src_data_folder: 源文件夹 E:/biye/gogogo/note_book/torch_note/data/utils_test/data_split/src_data
   :param target_data_folder: 目标文件夹 E:/biye/gogogo/note_book/torch_note/data/utils_test/data_split/target_data
   :param train_scale: 训练集比例
   :param val_scale: 验证集比例
   :param test_scale: 测试集比例
   :return:
   '''
   print("开始数据集划分")
   class_names = os.listdir(src_data_folder)
   # 在目标目录下创建文件夹
   split_names = ['train', 'val', 'test']
   for split_name in split_names:
       split_path = os.path.join(target_data_folder, split_name)
       if os.path.isdir(split_path):
           pass
       else:
           os.mkdir(split_path)
       # 然后在split_path的目录下创建类别文件夹
       for class_name in class_names:
           class_split_path = os.path.join(split_path, class_name)
           if os.path.isdir(class_split_path):
               pass
           else:
               os.mkdir(class_split_path)

   # 按照比例划分数据集,并进行数据图片的复制
   # 首先进行分类遍历
   for class_name in class_names:
       current_class_data_path = os.path.join(src_data_folder, class_name)
       current_all_data = os.listdir(current_class_data_path)
       current_data_length = len(current_all_data)
       current_data_index_list = list(range(current_data_length))
       random.shuffle(current_data_index_list)

       train_folder = os.path.join(os.path.join(target_data_folder, 'train'), class_name)
       val_folder = os.path.join(os.path.join(target_data_folder, 'val'), class_name)
       test_folder = os.path.join(os.path.join(target_data_folder, 'test'), class_name)
       train_stop_flag = current_data_length * train_scale
       val_stop_flag = current_data_length * (train_scale + val_scale)
       current_idx = 0
       train_num = 0
       val_num = 0
       test_num = 0
       for i in current_data_index_list:
           src_img_path = os.path.join(current_class_data_path, current_all_data[i])
           if current_idx <= train_stop_flag:
               copy2(src_img_path, train_folder)
               # print("{}复制到了{}".format(src_img_path, train_folder))
               train_num = train_num + 1
           elif (current_idx > train_stop_flag) and (current_idx <= val_stop_flag):
               copy2(src_img_path, val_folder)
               # print("{}复制到了{}".format(src_img_path, val_folder))
               val_num = val_num + 1
           else:
               copy2(src_img_path, test_folder)
               # print("{}复制到了{}".format(src_img_path, test_folder))
               test_num = test_num + 1

           current_idx = current_idx + 1

       print("*********************************{}*************************************".format(class_name))
       print(
           "{}类按照{}:{}:{}的比例划分完成,一共{}张图片".format(class_name, train_scale, val_scale, test_scale, current_data_length))
       print("训练集{}:{}张".format(train_folder, train_num))
       print("验证集{}:{}张".format(val_folder, val_num))
       print("测试集{}:{}张".format(test_folder, test_num))


if __name__ == '__main__':
   src_data_folder = r"原始数据集路径" # 如E:\深度学习\猫狗数据集下有dog和cat两个分好类的文件夹路径写为 'E:\深度学习\猫狗数据集'
   target_data_folder = r"划分好要放在那个文件夹下" # 如 'E:\深度学习\划分后的猫狗数据集'
   data_set_split(src_data_folder, target_data_folder)
复制代码

划分完后,E:\深度学习\划分后的猫狗数据集下会自动生成三个划分后的文件夹

在这里插入图片描述

dataset_split
├─test
│  ├─cloudy
│  ├─desert
│  ├─green_area
│  └─water
├─train
│  ├─cloudy
│  ├─desert
│  ├─green_area
│  └─water
└─val
    ├─cloudy
    ├─desert
    ├─green_area
    └─water
复制代码

参考自:数据集切分(训练,验证,测试)


猜你喜欢

转载自juejin.im/post/7111484032549912584