Study Notes | PyTorch Tutorial 32 (A Glance at Image Segmentation)

These study notes are mainly summarized from the "深度之眼" course, compiled here for easy reference.
PyTorch version 1.2 is used.

  • What is image segmentation?
  • How does a model segment an image?
  • A brief overview of deep learning image segmentation models
  • Training UNet for portrait matting

I. What is image segmentation?

Image segmentation: classifying every pixel of an image.
1. Superpixel segmentation: a small number of superpixels stand in for a large number of pixels; commonly used for image preprocessing.
2. Semantic segmentation: per-pixel classification; it cannot distinguish individual instances.
3. Instance segmentation: segments individual object instances; pixel-level object detection.
4. Panoptic segmentation: semantic segmentation combined with instance segmentation.


II. How does a model segment an image?


import os
import time
import torch.nn as nn
import torch
import numpy as np
import torchvision.transforms as transforms
from PIL import Image
from matplotlib import pyplot as plt

BASE_DIR = os.path.dirname(os.path.abspath(__file__))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if __name__ == "__main__":

    path_img = os.path.join(BASE_DIR, "demo_img1.png")
    # path_img = os.path.join(BASE_DIR, "demo_img2.png")
    # path_img = os.path.join(BASE_DIR, "demo_img3.png")

    # config
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # 1. load data & model
    input_image = Image.open(path_img).convert("RGB")
    model = torch.hub.load('pytorch/vision', 'deeplabv3_resnet101', pretrained=True)
    model.eval()

    # 2. preprocess
    input_tensor = preprocess(input_image)
    input_bchw = input_tensor.unsqueeze(0)

    # 3. to device
    if torch.cuda.is_available():
        input_bchw = input_bchw.to(device)
        model.to(device)

    # 4. forward
    with torch.no_grad():
        tic = time.time()
        print("input img tensor shape:{}".format(input_bchw.shape))
        output_4d = model(input_bchw)['out']
        output = output_4d[0]
        print("pass: {:.3f}s use: {}".format(time.time() - tic, device))
        print("output img tensor shape:{}".format(output.shape))
    output_predictions = output.argmax(0)

    # 5. visualization
    palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])
    colors = torch.as_tensor([i for i in range(21)])[:, None] * palette
    colors = (colors % 255).numpy().astype("uint8")

    # plot the semantic segmentation predictions of 21 classes in each color
    r = Image.fromarray(output_predictions.byte().cpu().numpy()).resize(input_image.size)
    r.putpalette(colors)
    plt.subplot(121).imshow(r)
    plt.subplot(122).imshow(input_image)
    plt.show()

    # appendix
    classes = ['__background__',
               'aeroplane', 'bicycle', 'bird', 'boat',
               'bottle', 'bus', 'car', 'cat', 'chair',
               'cow', 'diningtable', 'dog', 'horse',
               'motorbike', 'person', 'pottedplant',
               'sheep', 'sofa', 'train', 'tvmonitor']

Output:

input img tensor shape:torch.Size([1, 3, 433, 649])
pass: 21.773s use: cpu
output img tensor shape:torch.Size([21, 433, 649])

The 21 here means the model can segment 21 classes, one of which is the background class.
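
To make the mapping from score map to segmentation mask concrete, here is a minimal sketch (with a hypothetical 4x4 spatial size) of what output.argmax(0) does: for every pixel, the class with the highest score among the 21 channels is selected.

import torch

# hypothetical 21-class score map for a 4x4 image: shape [num_classes, H, W]
scores = torch.randn(21, 4, 4)

# per-pixel argmax over the class dimension -> one class id per pixel
mask = scores.argmax(dim=0)                      # shape [4, 4], values in 0..20
print(mask.shape, mask.min().item(), mask.max().item())
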
Try the second image: path_img = os.path.join(BASE_DIR, "demo_img2.png")
Output:

input img tensor shape:torch.Size([1, 3, 433, 649])
pass: 20.287s use: cpu
output img tensor shape:torch.Size([21, 433, 649])

Try the third image: path_img = os.path.join(BASE_DIR, "demo_img3.png")
Output:

input img tensor shape:torch.Size([1, 3, 730, 574])
pass: 24.351s use: cpu
output img tensor shape:torch.Size([21, 730, 574])


III. A brief overview of deep learning image segmentation models

How does a model accomplish image segmentation?

  • Answer: image segmentation is completed by the model and humans working together
  • Model: maps the data to features
  • Humans: define the physical meaning of those features and use them to solve the actual problem

PyTorch Hub is a repository of pretrained PyTorch models that developers can load directly.
1. torch.hub.load('pytorch/vision', 'deeplabv3_resnet101', pretrained=True)
model = torch.hub.load(github, model, *args, **kwargs)
Purpose: load a model
Main parameters:
  • github: str, repo name in the form <repo_owner/repo_name[:tag_name]>, e.g. pytorch/vision
  • model: str, model name

2. torch.hub.list(github, force_reload=False): list the entry points a repo exposes
3. torch.hub.help(github, model, force_reload=False): show the docstring of one entry point
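
A quick usage sketch of these two helpers (both fetch repo metadata, so a network connection is assumed):

import torch

# list all entry points exposed by the pytorch/vision hub repo
print(torch.hub.list('pytorch/vision'))

# print the docstring of one entry point
print(torch.hub.help('pytorch/vision', 'deeplabv3_resnet101'))
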

Thoughts on image segmentation
(Figures omitted: in them, blue marks the cat and green marks the dog.)

深度学习中的图像分割模型

Fully Convolutional Networks for Semantic Segmentation
Main contribution:

  • Pixel-wise prediction with fully convolutional layers (see the sketch below)
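
The core idea is to replace fully connected classification layers with convolutions, so the network outputs a score map rather than a single vector. A minimal illustration (the tiny backbone here is a stand-in, not the actual FCN architecture):

import torch
import torch.nn as nn

num_classes = 21
backbone = nn.Sequential(                      # stand-in feature extractor
    nn.Conv2d(3, 64, kernel_size=3, padding=1),
    nn.ReLU(inplace=True),
)
classifier = nn.Conv2d(64, num_classes, kernel_size=1)   # 1x1 conv in place of an FC layer

x = torch.randn(1, 3, 224, 224)
scores = classifier(backbone(x))               # [1, 21, 224, 224]: a score vector per pixel
print(scores.shape)
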

U-Net: Convolutional Networks for Biomedical Image Segmentation
Main contribution:

  • Established the basic structure of the UNet family of segmentation models: an encoder and decoder with feature fusion (a full implementation appears in Part IV)
  • https://github.com/shawnbit/unet-family

DeepLabv1: Semantic Image Segmentation with Deep Convolutional Nets and Fully Connected CRFs
The DeepLab series: V1
Main features:

  • Atrous (dilated) convolution: enlarges the receptive field (see the sketch below)
  • CRF: post-processes the predicted masks with a fully connected CRF
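
In PyTorch, an atrous convolution is simply nn.Conv2d with a dilation argument; a 3x3 kernel with dilation=2 covers a 5x5 window with the same number of parameters. A minimal sketch:

import torch
import torch.nn as nn

x = torch.randn(1, 16, 32, 32)

# standard 3x3 convolution: 3x3 receptive field
conv = nn.Conv2d(16, 16, kernel_size=3, padding=1)

# atrous 3x3 convolution with dilation=2: 5x5 receptive field, same parameter count
atrous = nn.Conv2d(16, 16, kernel_size=3, padding=2, dilation=2)

print(conv(x).shape, atrous(x).shape)   # both torch.Size([1, 16, 32, 32])
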


DeepLabv2: Semantic Image Segmentation with Deep Convolutional Nets, Atrous Convolution, and Fully Connected CRFs
The DeepLab series: V2
Main feature:

  • ASPP (Atrous Spatial Pyramid Pooling): addresses the multi-scale problem (a minimal sketch follows)
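
The idea behind ASPP is to apply several atrous convolutions with different rates to the same feature map in parallel, then fuse the results, so objects at multiple scales are captured at once. A simplified sketch of the concept (not the exact DeepLab implementation):

import torch
import torch.nn as nn

class SimpleASPP(nn.Module):
    """Parallel atrous branches with different rates, concatenated and projected."""
    def __init__(self, in_ch, out_ch, rates=(1, 6, 12, 18)):
        super().__init__()
        self.branches = nn.ModuleList([
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=r, dilation=r)
            for r in rates
        ])
        self.project = nn.Conv2d(out_ch * len(rates), out_ch, kernel_size=1)

    def forward(self, x):
        feats = [branch(x) for branch in self.branches]   # each keeps the spatial size
        return self.project(torch.cat(feats, dim=1))

aspp = SimpleASPP(256, 64)
print(aspp(torch.randn(1, 256, 28, 28)).shape)            # torch.Size([1, 64, 28, 28])
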


DeepLabv3: Rethinking Atrous Convolution for Semantic Image Segmentation
The DeepLab series: V3
Main features:

  • 1. Cascaded (serial) atrous convolutions
  • 2. Parallel ASPP

DeepLabv3+: Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation
The DeepLab series: V3+
Main feature:

  • Adds an encoder-decoder design on top of DeepLabv3

"Deep Semantic Segmentation of Natural and Medical Images: A Review" (2019)

Image segmentation resources:
https://github.com/shawnbit/unet-family
https://github.com/yassouali/pytorch_segmentation

IV. Training UNet for portrait matting


  • Data source: https://github.com/PetroWu/AutoPortraitMatting

Training code:

# -*- coding: utf-8 -*-
"""
# @file name  : unet_portrait_matting.py
# @author     : TingsongYu https://github.com/TingsongYu
# @date       : 2019-11-25
# @brief      : train unet
"""

import os
import time
import torch.nn as nn
import torch
import numpy as np
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import DataLoader
from matplotlib import pyplot as plt
import torch.optim as optim
import torchvision.models as models
#from tools.common_tools import set_seed
from tools.my_dataset import PortraitDataset
from tools.unet import UNet
import random

BASE_DIR = os.path.dirname(os.path.abspath(__file__))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def set_seed(seed=1):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

set_seed()  # set the random seed


def compute_dice(y_pred, y_true):
    """
    :param y_pred: 4-d tensor, value = [0,1]
    :param y_true: 4-d tensor, value = [0,1]
    :return:
    """
    y_pred, y_true = np.array(y_pred), np.array(y_true)
    y_pred, y_true = np.round(y_pred).astype(int), np.round(y_true).astype(int)
    return np.sum(y_pred[y_true == 1]) * 2.0 / (np.sum(y_pred) + np.sum(y_true))


if __name__ == "__main__":

    # config
    LR = 0.01
    BATCH_SIZE = 8
    max_epoch = 1   # 400
    start_epoch = 0
    lr_step = 150
    val_interval = 3
    checkpoint_interval = 20
    vis_num = 10
    mask_thres = 0.5

    train_dir = os.path.join(BASE_DIR, "..", "..", "data", "PortraitDataset", "train")
    valid_dir = os.path.join(BASE_DIR, "..", "..", "data", "PortraitDataset", "valid")

    # step 1
    train_set = PortraitDataset(train_dir)
    valid_set = PortraitDataset(valid_dir)

    train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
    valid_loader = DataLoader(valid_set, batch_size=1, shuffle=True, drop_last=False)

    # step 2
    net = UNet(in_channels=3, out_channels=1, init_features=64)   # init_features is 64 in the standard UNet
    net.to(device)

    # step 3
    loss_fn = nn.MSELoss()
    # step 4
    optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=lr_step, gamma=0.1)

    # step 5
    train_curve = list()
    valid_curve = list()
    train_dice_curve = list()
    valid_dice_curve = list()
    for epoch in range(start_epoch, max_epoch):

        train_loss_total = 0.
        train_dice_total = 0.

        net.train()
        for iter, (inputs, labels) in enumerate(train_loader):

            if torch.cuda.is_available():
                inputs, labels = inputs.to(device), labels.to(device)

            # forward
            outputs = net(inputs)

            # backward
            optimizer.zero_grad()
            loss = loss_fn(outputs, labels)
            loss.backward()

            optimizer.step()

            # print
            train_dice = compute_dice(outputs.ge(mask_thres).cpu().data.numpy(), labels.cpu())
            train_dice_curve.append(train_dice)
            train_curve.append(loss.item())

            train_loss_total += loss.item()

            print("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] running_loss: {:.4f}, mean_loss: {:.4f} "
                  "running_dice: {:.4f} lr:{}".format(epoch, max_epoch, iter + 1, len(train_loader), loss.item(),
                                    train_loss_total/(iter+1), train_dice, scheduler.get_lr()))

        scheduler.step()

        if (epoch + 1) % checkpoint_interval == 0:
            checkpoint = {"model_state_dict": net.state_dict(),
                          "optimizer_state_dict": optimizer.state_dict(),
                          "epoch": epoch}
            path_checkpoint = "./checkpoint_{}_epoch.pkl".format(epoch)
            torch.save(checkpoint, path_checkpoint)

        # validate the model
        if (epoch+1) % val_interval == 0:

            net.eval()
            valid_loss_total = 0.
            valid_dice_total = 0.

            with torch.no_grad():
                for j, (inputs, labels) in enumerate(valid_loader):
                    if torch.cuda.is_available():
                        inputs, labels = inputs.to(device), labels.to(device)

                    outputs = net(inputs)
                    loss = loss_fn(outputs, labels)

                    valid_loss_total += loss.item()

                    valid_dice = compute_dice(outputs.ge(mask_thres).cpu().data, labels.cpu())
                    valid_dice_total += valid_dice

                valid_loss_mean = valid_loss_total/len(valid_loader)
                valid_dice_mean = valid_dice_total/len(valid_loader)
                valid_curve.append(valid_loss_mean)
                valid_dice_curve.append(valid_dice_mean)

                print("Valid:\t Epoch[{:0>3}/{:0>3}] mean_loss: {:.4f} dice_mean: {:.4f}".format(
                    epoch, max_epoch, valid_loss_mean, valid_dice_mean))

    # visualization
    with torch.no_grad():
        for idx, (inputs, labels) in enumerate(valid_loader):
            if idx > vis_num:
                break
            if torch.cuda.is_available():
                inputs, labels = inputs.to(device), labels.to(device)
            outputs = net(inputs)
            pred = outputs.ge(mask_thres)
            mask_pred = pred.cpu().data.numpy().astype("uint8")

            img_hwc = inputs.cpu().data.numpy()[0, :, :, :].transpose((1, 2, 0)).astype("uint8")
            plt.subplot(121).imshow(img_hwc)
            mask_pred_gray = mask_pred.squeeze() * 255
            plt.subplot(122).imshow(mask_pred_gray, cmap="gray")
            plt.show()
            plt.pause(0.5)
            plt.close()

    # plot curve
    train_x = range(len(train_curve))
    train_y = train_curve

    train_iters = len(train_loader)
    valid_x = np.arange(1, len(valid_curve) + 1) * train_iters * val_interval  # valid_curve holds one loss per validation epoch, so map its points onto the iteration axis
    valid_y = valid_curve

    plt.plot(train_x, train_y, label='Train')
    plt.plot(valid_x, valid_y, label='Valid')

    plt.legend(loc='upper right')
    plt.ylabel('loss value')
    plt.xlabel('Iteration')
    plt.title("Plot in {} epochs".format(max_epoch))
    plt.show()

    # dice curve
    train_x = range(len(train_dice_curve))
    train_y = train_dice_curve

    train_iters = len(train_loader)
    valid_x = np.arange(1, len(valid_dice_curve) + 1) * train_iters * val_interval  # same conversion for the per-epoch dice records
    valid_y = valid_dice_curve

    plt.plot(train_x, train_y, label='Train')
    plt.plot(valid_x, valid_y, label='Valid')

    plt.legend(loc='upper right')
    plt.ylabel('dice value')
    plt.xlabel('Iteration')
    plt.title("Plot in {} epochs".format(max_epoch))
    plt.show()
    torch.cuda.empty_cache()

Testing with one epoch, the output is:

Training:Epoch[000/001] Iteration[001/212] running_loss: 0.2455, mean_loss: 0.2455 running_dice: 0.6275 lr:[0.01]
Training:Epoch[000/001] Iteration[002/212] running_loss: 0.2436, mean_loss: 0.2445 running_dice: 0.6337 lr:[0.01]
......
Training:Epoch[000/001] Iteration[210/212] running_loss: 0.0816, mean_loss: 0.1595 running_dice: 0.9295 lr:[0.01]
Training:Epoch[000/001] Iteration[211/212] running_loss: 0.1406, mean_loss: 0.1594 running_dice: 0.8416 lr:[0.01]
Training:Epoch[000/001] Iteration[212/212] running_loss: 0.1624, mean_loss: 0.1594 running_dice: 0.8296 lr:[0.01]
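
The running_dice printed above is the Dice coefficient computed by compute_dice. With A the set of predicted foreground pixels (outputs thresholded at mask_thres) and B the set of ground-truth foreground pixels,

$$\mathrm{Dice}(A, B) = \frac{2\,|A \cap B|}{|A| + |B|}$$

which equals 1 for a perfect mask and 0 when prediction and ground truth do not overlap; in the code, np.sum(y_pred[y_true == 1]) is exactly |A ∩ B|.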

Let's look at the UNet structure. It is simple, but a classic.

from collections import OrderedDict

import torch
import torch.nn as nn


class UNet(nn.Module):

    def __init__(self, in_channels=3, out_channels=1, init_features=32):
        super(UNet, self).__init__()

        features = init_features
        self.encoder1 = UNet._block(in_channels, features, name="enc1")
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder2 = UNet._block(features, features * 2, name="enc2")
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder3 = UNet._block(features * 2, features * 4, name="enc3")
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder4 = UNet._block(features * 4, features * 8, name="enc4")
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.bottleneck = UNet._block(features * 8, features * 16, name="bottleneck")

        self.upconv4 = nn.ConvTranspose2d(
            features * 16, features * 8, kernel_size=2, stride=2
        )
        self.decoder4 = UNet._block((features * 8) * 2, features * 8, name="dec4")
        self.upconv3 = nn.ConvTranspose2d(
            features * 8, features * 4, kernel_size=2, stride=2
        )
        self.decoder3 = UNet._block((features * 4) * 2, features * 4, name="dec3")
        self.upconv2 = nn.ConvTranspose2d(
            features * 4, features * 2, kernel_size=2, stride=2
        )
        self.decoder2 = UNet._block((features * 2) * 2, features * 2, name="dec2")
        self.upconv1 = nn.ConvTranspose2d(
            features * 2, features, kernel_size=2, stride=2
        )
        self.decoder1 = UNet._block(features * 2, features, name="dec1")

        self.conv = nn.Conv2d(
            in_channels=features, out_channels=out_channels, kernel_size=1
        )

    def forward(self, x):
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(self.pool1(enc1))
        enc3 = self.encoder3(self.pool2(enc2))
        enc4 = self.encoder4(self.pool3(enc3))

        bottleneck = self.bottleneck(self.pool4(enc4))

        dec4 = self.upconv4(bottleneck)
        dec4 = torch.cat((dec4, enc4), dim=1)
        dec4 = self.decoder4(dec4)

        dec3 = self.upconv3(dec4)
        dec3 = torch.cat((dec3, enc3), dim=1)
        dec3 = self.decoder3(dec3)

        dec2 = self.upconv2(dec3)
        dec2 = torch.cat((dec2, enc2), dim=1)
        dec2 = self.decoder2(dec2)

        dec1 = self.upconv1(dec2)
        dec1 = torch.cat((dec1, enc1), dim=1)
        dec1 = self.decoder1(dec1)

        return torch.sigmoid(self.conv(dec1))

    @staticmethod
    def _block(in_channels, features, name):
        return nn.Sequential(
            OrderedDict(
                [
                    (
                        name + "conv1",
                        nn.Conv2d(
                            in_channels=in_channels,
                            out_channels=features,
                            kernel_size=3,
                            padding=1,
                            bias=False,
                        ),
                    ),
                    (name + "norm1", nn.BatchNorm2d(num_features=features)),
                    (name + "relu1", nn.ReLU(inplace=True)),
                    (
                        name + "conv2",
                        nn.Conv2d(
                            in_channels=features,
                            out_channels=features,
                            kernel_size=3,
                            padding=1,
                            bias=False,
                        ),
                    ),
                    (name + "norm2", nn.BatchNorm2d(num_features=features)),
                    (name + "relu2", nn.ReLU(inplace=True)),
                ]
            )
        )
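
A quick sanity check of the structure (a usage sketch; the input height and width must be divisible by 16 because of the four 2x2 poolings):

import torch

net = UNet(in_channels=3, out_channels=1, init_features=32)
x = torch.randn(1, 3, 224, 224)   # 224 is divisible by 2**4
y = net(x)
print(y.shape)                    # torch.Size([1, 1, 224, 224]), values in (0, 1) from the sigmoid
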

Now let's test with weights trained for 400 epochs
(note that init_features=32 is used here):

import os
import time
import random
import torch.nn as nn
import torch
import numpy as np
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import DataLoader
from matplotlib import pyplot as plt
import torch.optim as optim
import torchvision.models as models
#from tools.common_tools import set_seed
from tools.my_dataset import PortraitDataset
from tools.unet import UNet

BASE_DIR = os.path.dirname(os.path.abspath(__file__))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def set_seed(seed=1):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

set_seed()  # set the random seed


def compute_dice(y_pred, y_true):
    """
    :param y_pred: 4-d tensor, value = [0,1]
    :param y_true: 4-d tensor, value = [0,1]
    :return:
    """
    y_pred, y_true = np.array(y_pred), np.array(y_true)
    y_pred, y_true = np.round(y_pred).astype(int), np.round(y_true).astype(int)
    return np.sum(y_pred[y_true == 1]) * 2.0 / (np.sum(y_pred) + np.sum(y_true))


def get_img_name(img_dir, format="jpg"):
    """
    Collect the file names with the given format under img_dir.
    :param img_dir: str
    :param format: str
    :return: list
    """
    file_names = os.listdir(img_dir)
    img_names = list(filter(lambda x: x.endswith(format), file_names))
    img_names = list(filter(lambda x: not x.endswith("matte.png"), img_names))

    if len(img_names) < 1:
        raise ValueError("no {} files found under {}".format(format, img_dir))
    return img_names


def get_model(m_path):

    unet = UNet(in_channels=3, out_channels=1, init_features=32)
    checkpoint = torch.load(m_path, map_location="cpu")

    # strip the 'module.' prefix that nn.DataParallel adds to state_dict keys
    from collections import OrderedDict
    new_state_dict = OrderedDict()
    for k, v in checkpoint['model_state_dict'].items():
        namekey = k[7:] if k.startswith('module.') else k
        new_state_dict[namekey] = v

    unet.load_state_dict(new_state_dict)

    return unet


if __name__ == "__main__":

    img_dir = os.path.join(BASE_DIR, "..", "..", "data", "PortraitDataset", "valid")
    model_path = "checkpoint_399_epoch.pkl"
    time_total = 0
    num_infer = 5
    mask_thres = .5

    # 1. data
    img_names = get_img_name(img_dir, format="png")
    random.shuffle(img_names)
    num_img = len(img_names)

    # 2. model
    unet = get_model(model_path)
    unet.to(device)
    unet.eval()

    for idx, img_name in enumerate(img_names):
        if idx > num_infer:
            break

        path_img = os.path.join(img_dir, img_name)
        # path_img = "C:\\Users\\Administrator\\Desktop\\Andrew-wu.png"
        #
        # step 1/4 : path --> img_chw
        img_hwc = Image.open(path_img).convert('RGB')
        img_hwc = img_hwc.resize((224, 224))
        img_arr = np.array(img_hwc)
        img_chw = img_arr.transpose((2, 0, 1))

        # step 2/4 : img --> tensor
        img_tensor = torch.tensor(img_chw).to(torch.float)
        img_tensor.unsqueeze_(0)
        img_tensor = img_tensor.to(device)

        # step 3/4 : tensor --> features
        time_tic = time.time()
        outputs = unet(img_tensor)
        time_toc = time.time()

        # step 4/4 : visualization
        pred = outputs.ge(mask_thres)
        mask_pred = pred.cpu().data.numpy().astype("uint8")

        img_hwc = img_tensor.cpu().data.numpy()[0, :, :, :].transpose((1, 2, 0)).astype("uint8")
        plt.subplot(121).imshow(img_hwc)
        mask_pred_gray = mask_pred.squeeze() * 255
        plt.subplot(122).imshow(mask_pred_gray, cmap="gray")
        plt.show()
        # plt.pause(0.5)
        plt.close()

        time_s = time_toc - time_tic
        time_total += time_s

        print('{:d}/{:d}: {} {:.3f}s '.format(idx + 1, num_img, img_name, time_s))

Output:
(Figures omitted: the input portraits alongside their predicted masks.)

Source: blog.csdn.net/qq_24739717/article/details/103353482