PyTorch从零构建网络实现supervisely 发布的人像分割数据集完美教程

写在最前面的最重要的一段话:

       经过调研大部分图像语义分割教程发现基本所有的教程内容都不够全面,有的教程只有搭建简单的网络实现语义分割训练功能;有的教程只有训练和预测部分;有的教程训练过程中只有损失函数没有精确度评价指标等等。鉴于此,本人经过调研多个图像语义分割教程并进行了详细的梳理总结,将图像语义分割所需要的知识点、数据集增强和训练技巧、多种评价指标、模型保存和转换等重要内容都进行了整合到该教程中,能够让初学者只需要学习这一个教程就能够完全掌握图像语义分割的基本功底。

整个工程文件和supervisely数据集已上传放到以下链接,需要的请自行下载学习使用:

https://mbd.pub/o/bread/Zp2ZkpZx

一、训练图像分割网络主要流程

  1. 构建图像语义分割数据训练集、验证集和测试集
  2. 数据预处理、包括数据增强、数据标准化和归一化
  3. 从零构建图像语义分割网络模型
  4. 设置训练超参数,学习率、优化器、损失函数等超参数
  5. 模型训练、验证、预测、保存模型、模型onnx转化

二、各个流程简要说明

1. 构建图像语义分割数据训练集、验证集、测试集

本文使用supervisely 发布的人像分割数据集,在上面链接已打包供下载学习使用:,部分数据及对应的mask如下图所示:

部分图像和mask标签示例

在工程目录下,①新建datasets文件夹用于存放训练集、验证集和测试集;②在datasets文件夹内分别新建images和labels文件夹,用来放图片和对应的mask图片;③在images和labels文件夹内分别新建train、val、test文件夹用于存放图像语义分割的训练集、验证集和测试集对应的图像和mask标签,结构如下:

图像语义分割数据集保存结构

2.加载数据集,包括数据预处理、数据增强、数据标准化和归一化

(1)导入相应的库和文件:

import os
import torch
import torch.nn as nn
from models.Simplify_Net import Simplify_Net
from utils.engine import train_and_val, plot_pix_acc, plot_miou, plot_loss, plot_lr
import argparse
import numpy as np
from utils.datasets import SegData
from utils.config import ALL_CLASSES, LABEL_COLORS_LIST
import albumentations as A

(2)数据预处理,将图像Resize到统一大小,之后训练集的数据增强再进行标准化,预处理之后的图片可以正常输入网络,对于训练集可以采取一些数据增强手段来增强网络的泛化能力,验证集不做数据增强。

train_transform = A.Compose([A.Resize(args.input_size[0], args.input_size[1]),
                                 A.HorizontalFlip(0.5),
                                 A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

    val_transform = A.Compose([A.Resize(args.input_size[0], args.input_size[1]),
                               A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

    train_dataset = SegData(image_path=os.path.join(args.data_path, 'images/train'),
                            mask_path=os.path.join(args.data_path, 'labels/train'),
                            all_classes=ALL_CLASSES,
                            label_colors_list=LABEL_COLORS_LIST,
                            data_transforms=train_transform)
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=args.batch_size, shuffle=True,
                                               num_workers=args.num_workers)

    val_dataset = SegData(image_path=os.path.join(args.data_path, 'images/val'),
                          mask_path=os.path.join(args.data_path, 'labels/val'),
                          all_classes=ALL_CLASSES,
                          label_colors_list=LABEL_COLORS_LIST,
                          data_transforms=val_transform)
    val_loader = torch.utils.data.DataLoader(dataset=val_dataset, batch_size=args.batch_size, shuffle=False,
                                             num_workers=args.num_workers)

(3)加载数据集完整代码:

class SegData(Dataset):
    def __init__(self, image_path, mask_path, all_classes, label_colors_list, data_transforms=None):
        self.image_path = image_path
        self.mask_path = mask_path

        self.images = os.listdir(self.image_path)
        self.masks = os.listdir(self.mask_path)
        self.transform = data_transforms
        self.all_classes = all_classes
        self.label_colors_list = label_colors_list
        self.classes_to_train = all_classes
        # Convert string names to class values for masks.
        self.class_values = set_class_values(self.all_classes, self.classes_to_train)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image_filename = self.images[idx]
        mask_filename = image_filename.replace('jpeg', 'png')
        image = np.array(Image.open(os.path.join(self.image_path, image_filename)).convert('RGB'))
        mask = np.array(Image.open(os.path.join(self.mask_path, mask_filename)).convert('RGB'))

        if self.transform is not None:
            transformed = self.transform(image=image, mask=mask)
        image = transformed['image']
        mask = transformed['mask']
        # Get colored label mask.
        mask = get_label_mask(np.array(mask), self.class_values, self.label_colors_list)
        image = np.transpose(image, (2, 0, 1))
        image = torch.tensor(image, dtype=torch.float)
        mask = torch.tensor(mask, dtype=torch.long)

        return image, mask

3.图像语义分割网络模型从零构建

该语义分割网络模型采用了3层卷积下采样和3层上采样,并使用了conv1和conv2两层特征进行了两个简单的通道跳跃连接。网络模型结构如下图所示。

class Simplify_Net(nn.Module):
    def __init__(self, num_classes=2):
        super(Simplify_Net, self).__init__()

        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1, stride=2)
        self.bn1 = nn.BatchNorm2d(16)
        self.relu1 = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, padding=1, stride=2)
        self.bn2 = nn.BatchNorm2d(16)
        self.relu2 = nn.ReLU(inplace=True)

        self.conv3 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, padding=1, stride=2)
        self.bn3 = nn.BatchNorm2d(16)
        self.relu3 = nn.ReLU(inplace=True)

        self.upconv1 = nn.ConvTranspose2d(in_channels=16, out_channels=16, kernel_size=4, padding=1, stride=2)
        self.bn4 = nn.BatchNorm2d(16)
        self.relu4 = nn.ReLU(inplace=True)

        self.upconv2 = nn.ConvTranspose2d(in_channels=32, out_channels=16, kernel_size=4, padding=1, stride=2)
        self.bn5 = nn.BatchNorm2d(16)
        self.relu5 = nn.ReLU(inplace=True)

        self.conv_last = nn.Conv2d(in_channels=32, out_channels=num_classes, kernel_size=1, stride=1)

    def forward(self, x):
        x1 = self.relu1(self.bn1(self.conv1(x)))
        x2 = self.relu2(self.bn2(self.conv2(x1)))
        x3 = self.relu3(self.bn3(self.conv3(x2)))

        up1 = torch.cat([x2, self.relu4(self.bn4(self.upconv1(x3)))], dim=1)
        up2 = torch.cat([x1, self.relu5(self.bn5(self.upconv2(up1)))], dim=1)
        up3 = self.conv_last(up2)

        out = interpolate(up3, scale_factor=2, mode='bilinear', align_corners=False)

        return out


if __name__ == '__main__':
    img = torch.randn(2, 3, 224, 224)
    net = Simplify_Net()
    sample = net(img)
    print(sample.shape)
网络模型结构

4.图像语义分割网络模型训练

训练过程中根据训练损失值保存当前最优的模型。

    model = Simplify_Net(args.nb_classes)
    loss_function = nn.CrossEntropyLoss()

    optimizer = torch.optim.AdamW(model.parameters(), lr=args.init_lr, weight_decay=args.weight_decay)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, args.max_lr, total_steps=args.epochs, verbose=True)

    history = train_and_val(args.epochs, model, train_loader, val_loader, loss_function, optimizer, scheduler, args.output_dir, device, args.nb_classes)

    plot_loss(np.arange(0, args.epochs), args.output_dir, history)
    plot_pix_acc(np.arange(0, args.epochs), args.output_dir, history)
    plot_miou(np.arange(0, args.epochs), args.output_dir, history)
    plot_lr(np.arange(0, args.epochs), args.output_dir, history)

训练结束保存训练损失曲线图、pix Acc、MIoU、学习率等重要指标

5.进行单张图片预测

首先加载保存好的训练模型到GPU,然后通过transforms制作单张图片测试数据格式作为模型预测输入,最后将模型预测输出进行mask判断和保存为图像格式。从模型预测的分割结果可以看出这个简单的图像语义分割模型对明显的人物像素区域可以分割正确,对背景一致的像素区域难以区分,需要对分割模型进行进一步的优化设计以实现更复杂像素的正确分割。

    model = Simplify_Net(args.nb_classes)
    checkpoint = torch.load(args.weights, map_location='cpu')
    msg = model.load_state_dict(checkpoint, strict=True)
    print(msg)
    model.to(device)
    model.eval()

    input_tensor = transforms(image).unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(input_tensor)
        pred = output.argmax(1).squeeze(0).cpu().numpy().astype(np.uint8)

    mask = Image.fromarray(pred*255)
    out = mask.resize(img_size)
    out.save("result.png")
单张图片预测结果

6.模型整体评价

首先制作需要评价模型的数据集,即提前准备好的测试集作为模型的整体评价。

    classes = ALL_CLASSES
    val_transform = A.Compose([A.Resize(args.input_size[0],args.input_size[1]),
                               A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

    val_dataset = SegData(image_path=os.path.join(args.data_path, 'images/test'),
                          mask_path=os.path.join(args.data_path, 'labels/test'),
                          all_classes=classes,
                          label_colors_list=LABEL_COLORS_LIST,
                          data_transforms=val_transform)
    val_loader = torch.utils.data.DataLoader(dataset=val_dataset, batch_size=args.batch_size, shuffle=False,
                                             num_workers=args.num_workers)

加载保存好的训练模型并加载到GPU,平均MiOU 0.65655,Pixel Accuracy 0.8444

    model = Simplify_Net(args.nb_classes)
    checkpoint = torch.load(args.weights, map_location='cpu')
    msg = model.load_state_dict(checkpoint, strict=True)
    print(msg)

    model.to(device)
    model.eval()

    with torch.no_grad():
        with tqdm(total=len(val_loader)) as pbar:
            for image, label in val_loader:
                output = model(image.to(device))
                pred = output.data.cpu().numpy()
                label = label.cpu().numpy()
                pred = np.argmax(pred, axis=1)
                segmetric.add_batch(label, pred)
                pbar.update(1)

    pix_acc = segmetric.Pixel_Accuracy()
    every_iou, miou = segmetric.Mean_Intersection_over_Union()
模型整体评价结果

7.模型onnx转换

将训练好的模型转换为onnx格式,以进行后续其他深度学习框架和平台的应用。ONNX(Open Neural Network Exchange,开放式神经网络交换格式)是一种模型文件格式,它在模型训练和模型推理中间提供了中间桥梁,使得上游不同的训练框架都能导出ONNX格式的模型,给到下游不同的推理框架都可以读取ONNX进行部署。

onnx模型中间件示意图

这种基于ONNX的模型训练,中间件,再到模型推理的方式使得:

  • ONNX将模型训练和推理解耦,任意上游训练框架和下游推理框架都可以组合搭配,而不需要用同一种框架既进行训练又进行推理
  • ONNX是通用的模型格式,不同训练框架输出的模型可以用ONNX作为桥梁进行转换,使得模型更方便迁移
  • ONNX部署兼容性极强,支持多种推理框架,支持CPU/GPU推理,支持跨语言推理
  • ONNX格式配合上类似ONNXRumtime等推理框架,相比于模型在原生环境的推理性能会有大幅的提升
    def main(args):
    x = torch.randn(1, 3, args.input_size[0], args.input_size[1])  # 随机生成一幅图像数据 3表示通道 图像尺寸args.input_size[0]✖args.input_size[1]
    input_names = ["input"]
    out_names = ["output"]

    model = Simplify_Net(args.nb_classes)
    checkpoint = torch.load(args.weights, map_location='cpu')
    msg = model.load_state_dict(checkpoint, strict=True)
    print(msg)
    model.eval()

    torch.onnx.export(model, x, args.weights.replace('pth', 'onnx'), export_params=True, verbose=True,
                      input_names=input_names, output_names=out_names)
    print('please run: python -m onnxsim best_model.onnx best_sim.onnx\n')

onnx模型转换和模型转换后进行优化结果:

ONNX模型转换成功并对模型进行了优化

下图为转换前的last_model.pth和转换后的best_model.onnx模型结构对比

三、模型优化改进分割效果

可以看出上述简单的三层卷积+跳跃链接实现的人像分割效果很一般,因此尝试通过各种优化技巧改进最终的人像分割效果,这里尝试首先通过优化网络模型实现改进人像分割的效果。如下图通过对网络的卷积深度增加以及加入最大值下采样池化和最大值上采样池化操作,最终模型结构如下图所示:

可以看出经过优化后的模型准确率有明细的提升!

优化后的分割结果对比如下图所示:

未完待续……

可加微信有偿指导环境配置安装和代码讲解

猜你喜欢

转载自blog.csdn.net/lzdjlu/article/details/143028810