学习笔记|Pytorch使用教程32
本学习笔记主要摘自“深度之眼”,做一个总结,方便查阅。
使用Pytorch版本为1.2
- 图像分割是什么?
- 模型是如何将图像分割的?
- 深度学习图像分割模型简介
- 训练Unet完成人像抠图
一.图像分割是什么?
图像分割:对图像中的每一个像素进行分类
1. 超像素分割:用少量超像素代替大量像素,常用于图像预处理
2. 语义分割:逐像素分类,无法区分个体
3. 实例分割:对个体目标进行分割,像素级目标检测
4. 全景分割:语义分割结合实例分割
二.模型是如何将图像分割的?
import os
import time
import torch.nn as nn
import torch
import numpy as np
import torchvision.transforms as transforms
from PIL import Image
from matplotlib import pyplot as plt
# Directory of this script; the demo images are looked up relative to it.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if __name__ == "__main__":
    # Semantic segmentation of one image with torchvision's pretrained
    # DeepLabV3-ResNet101 loaded through torch.hub (21 classes, see `classes` below).
    path_img = os.path.join(BASE_DIR, "demo_img1.png")
    # path_img = os.path.join(BASE_DIR, "demo_img2.png")
    # path_img = os.path.join(BASE_DIR, "demo_img3.png")

    # config: ToTensor + per-channel mean/std normalisation
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # 1. load data & model
    input_image = Image.open(path_img).convert("RGB")
    model = torch.hub.load('pytorch/vision', 'deeplabv3_resnet101', pretrained=True)
    model.eval()

    # 2. preprocess: PIL image -> CHW tensor -> BCHW tensor with batch size 1
    input_tensor = preprocess(input_image)
    input_bchw = input_tensor.unsqueeze(0)

    # 3. to device
    if torch.cuda.is_available():
        input_bchw = input_bchw.to(device)
        model.to(device)

    # 4. forward
    with torch.no_grad():
        print("input img tensor shape:{}".format(input_bchw.shape))
        tic = time.time()
        output_4d = model(input_bchw)['out']
        # CUDA kernels are launched asynchronously: wait for them to finish,
        # otherwise the measured time only covers the kernel launches.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        print("pass: {:.3f}s use: {}".format(time.time() - tic, device))
        output = output_4d[0]
        print("output img tensor shape:{}".format(output.shape))
        # per-pixel prediction = index of the channel (class) with the highest score
        output_predictions = output.argmax(0)

    # 5. visualization: build one deterministic colour per class index (0..20)
    palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])
    colors = torch.arange(21)[:, None] * palette
    colors = (colors % 255).numpy().astype("uint8")

    # plot the semantic segmentation predictions of 21 classes in each color
    r = Image.fromarray(output_predictions.byte().cpu().numpy()).resize(input_image.size)
    r.putpalette(colors)
    plt.subplot(121).imshow(r)
    plt.subplot(122).imshow(input_image)
    plt.show()

    # appendix: class names, indexed like the output channels
    classes = ['__background__',
               'aeroplane', 'bicycle', 'bird', 'boat',
               'bottle', 'bus', 'car', 'cat', 'chair',
               'cow', 'diningtable', 'dog', 'horse',
               'motorbike', 'person', 'pottedplant',
               'sheep', 'sofa', 'train', 'tvmonitor']
输出:
input img tensor shape:torch.Size([1, 3, 433, 649])
pass: 21.773s use: cpu
output img tensor shape:torch.Size([21, 433, 649])
其中21表示模型可以分割21个类别(20个目标类别 + 1个背景类),输出的每个通道对应一个类别的得分。
查看第二张图片:path_img = os.path.join(BASE_DIR, "demo_img2.png")
输出:
input img tensor shape:torch.Size([1, 3, 433, 649])
pass: 20.287s use: cpu
output img tensor shape:torch.Size([21, 433, 649])
查看第三张图片:path_img = os.path.join(BASE_DIR, "demo_img3.png")
输出:
input img tensor shape:torch.Size([1, 3, 730, 574])
pass: 24.351s use: cpu
output img tensor shape:torch.Size([21, 730, 574])
三.深度学习图像分割模型简介
模型如何完成图像分割?
- 答:图像分割由模型与人类配合完成
- 模型:将数据映射到特征
- 人类:定义特征的物理意义,解决实际问题
PyTorch-Hub——PyTorch模型库,有大量模型供开发者调用
1.torch.hub.load('pytorch/vision', 'deeplabv3_resnet101', pretrained=True)
model = torch.hub.load(github, model, *args, **kwargs)
功能:加载模型
主要参数:
- github: str, 项目名, 格式为 repo_owner/repo_name[:tag_name], 例如 pytorch/vision
- model: str, 模型名
2.torch.hub.list(github, force_reload=False)
功能:列出github仓库(hubconf.py)中可加载的模型名
3.torch.hub.help(github, model, force_reload=False)
功能:显示指定模型的说明文档(docstring)
图像分割的思考
Ps:蓝色为小猫,绿色为小狗
深度学习中的图像分割模型
《Fully Convolutional Networks for Semantic Segmentation 》
最主要贡献:
- 利用全卷积完成pixelwise prediction
《U-Net: Convolutional Networks for Biomedical Image Segmentation》
最主要贡献:
- 奠定Unet系列分割模型的基本结构——编码器与解码器的特征融合
- https://github.com/shawnbit/unet-family
《Semantic Image Segmentation with Deep Convolutional Nets and Fully Connected CRFs》(DeepLabv1)
DeepLab系列——V1
主要特点:
- 孔洞卷积:借助孔洞卷积,增大感受野
- CRF:采用CRF进行mask后处理
《DeepLab: Semantic Image Segmentation with Deep Convolutional Nets, Atrous Convolution, and Fully Connected CRFs》(DeepLabv2)
DeepLab系列——V2
主要特点:
- ASPP(Atrous spatial pyramid pooling ):解决多尺度问题
《Rethinking Atrous Convolution for Semantic Image Segmentation》(DeepLabv3)
DeepLab系列——V3
主要特点:
- 1.孔洞卷积的串行
- 2.ASPP的并行
《Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation》(DeepLabv3+)
DeepLab系列——V3+
主要特点:
- deeplabv3基础上加上Encoder-Decoder思想
《Deep Semantic Segmentation of Natural and Medical Images: A Review》2019
图像分割资源:
https://github.com/shawnbit/unet-family
https://github.com/yassouali/pytorch_segmentation
四.训练Unet完成人像抠图
- 数据来源:https://github.com/PetroWu/AutoPortraitMatting
训练代码:
# -*- coding: utf-8 -*-
"""
# @file name : unet_portrait_matting.py
# @author : TingsongYu https://github.com/TingsongYu
# @date : 2019-11-25
# @brief : train unet
"""
import os
import time
import torch.nn as nn
import torch
import numpy as np
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import DataLoader
from matplotlib import pyplot as plt
import torch.optim as optim
import torchvision.models as models
#from tools.common_tools import set_seed
from tools.my_dataset import PortraitDataset
from tools.unet import UNet
import random
# Folder holding this script; data paths below are built relative to it.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
device = torch.device("cpu" if not torch.cuda.is_available() else "cuda")


def set_seed(seed=1):
    """Seed Python, NumPy and torch (CPU + current CUDA device) RNGs with ``seed``."""
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed)
    for seeder in seeders:
        seeder(seed)


set_seed()  # seed all random number generators
def compute_dice(y_pred, y_true):
    """Compute the Dice coefficient 2*|P∩T| / (|P| + |T|) of two masks.

    :param y_pred: array-like (ndarray or CPU tensor) with values in [0, 1];
                   values are rounded to {0, 1} before comparison
    :param y_true: array-like of the same shape, values in [0, 1]
    :return: float in [0, 1]; 1.0 when both masks are empty (they agree
             perfectly) instead of the NaN a 0/0 division would produce
    """
    y_pred = np.round(np.asarray(y_pred)).astype(int)
    y_true = np.round(np.asarray(y_true)).astype(int)
    denominator = np.sum(y_pred) + np.sum(y_true)
    if denominator == 0:
        # Neither mask has a foreground pixel: prediction equals ground truth.
        return 1.0
    return np.sum(y_pred[y_true == 1]) * 2.0 / denominator
def _plot_curve(train_values, valid_values, iters_per_epoch, val_interval, max_epoch, ylabel):
    """Plot per-iteration training values and per-validation values on one iteration axis.

    :param train_values: list of floats, one per training iteration
    :param valid_values: list of floats, one per validation run (every ``val_interval`` epochs)
    :param iters_per_epoch: int, training iterations per epoch
    :param val_interval: int, epochs between two validation runs
    :param max_epoch: int, only used in the title
    :param ylabel: str, y-axis label
    """
    train_x = range(len(train_values))
    # The k-th validation value (k = 1, 2, ...) is recorded after k * val_interval
    # epochs (assuming training starts at epoch 0), i.e. at iteration
    # k * val_interval * iters_per_epoch on the training axis.
    valid_x = np.arange(1, len(valid_values) + 1) * iters_per_epoch * val_interval
    plt.plot(train_x, train_values, label='Train')
    plt.plot(valid_x, valid_values, label='Valid')
    plt.legend(loc='upper right')
    plt.ylabel(ylabel)
    plt.xlabel('Iteration')
    plt.title("Plot in {} epochs".format(max_epoch))
    plt.show()


if __name__ == "__main__":
    # Train a UNet for portrait matting (1-channel mask) with MSE loss,
    # validate every `val_interval` epochs, then visualise and plot curves.

    # config
    LR = 0.01
    BATCH_SIZE = 8
    max_epoch = 1  # 400
    start_epoch = 0
    lr_step = 150
    val_interval = 3
    checkpoint_interval = 20
    vis_num = 10
    mask_thres = 0.5
    train_dir = os.path.join(BASE_DIR, "..", "..", "data", "PortraitDataset", "train")
    valid_dir = os.path.join(BASE_DIR, "..", "..", "data", "PortraitDataset", "valid")

    # step 1: data
    train_set = PortraitDataset(train_dir)
    valid_set = PortraitDataset(valid_dir)
    train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
    valid_loader = DataLoader(valid_set, batch_size=1, shuffle=True, drop_last=False)

    # step 2: model
    net = UNet(in_channels=3, out_channels=1, init_features=64)  # init_features is 64 in the standard UNet
    net.to(device)

    # step 3: loss
    loss_fn = nn.MSELoss()

    # step 4: optimizer and LR schedule (x0.1 every lr_step epochs)
    optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=lr_step, gamma=0.1)

    # step 5: training loop
    train_curve = list()
    valid_curve = list()
    train_dice_curve = list()
    valid_dice_curve = list()

    for epoch in range(start_epoch, max_epoch):

        train_loss_total = 0.
        net.train()
        for iter_idx, (inputs, labels) in enumerate(train_loader):

            if torch.cuda.is_available():
                inputs, labels = inputs.to(device), labels.to(device)

            # forward
            outputs = net(inputs)

            # backward
            optimizer.zero_grad()
            loss = loss_fn(outputs, labels)
            loss.backward()
            optimizer.step()

            # print
            loss_value = loss.item()
            train_dice = compute_dice(outputs.ge(mask_thres).cpu().data.numpy(), labels.cpu())
            train_dice_curve.append(train_dice)
            train_curve.append(loss_value)
            train_loss_total += loss_value
            # Read the LR from the optimizer: scheduler.get_lr() does not return the
            # current LR on PyTorch >= 1.4, and get_last_lr() does not exist in 1.2.
            current_lr = [group['lr'] for group in optimizer.param_groups]
            print("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] running_loss: {:.4f}, mean_loss: {:.4f} "
                  "running_dice: {:.4f} lr:{}".format(epoch, max_epoch, iter_idx + 1, len(train_loader), loss_value,
                                                       train_loss_total / (iter_idx + 1), train_dice, current_lr))

        scheduler.step()

        if (epoch + 1) % checkpoint_interval == 0:
            checkpoint = {"model_state_dict": net.state_dict(),
                          "optimizer_state_dict": optimizer.state_dict(),
                          "epoch": epoch}
            path_checkpoint = "./checkpoint_{}_epoch.pkl".format(epoch)
            torch.save(checkpoint, path_checkpoint)

        # validate the model
        if (epoch + 1) % val_interval == 0:

            net.eval()
            valid_loss_total = 0.
            valid_dice_total = 0.

            with torch.no_grad():
                for j, (inputs, labels) in enumerate(valid_loader):
                    if torch.cuda.is_available():
                        inputs, labels = inputs.to(device), labels.to(device)

                    outputs = net(inputs)
                    loss = loss_fn(outputs, labels)

                    valid_loss_total += loss.item()

                    valid_dice = compute_dice(outputs.ge(mask_thres).cpu().data, labels.cpu())
                    valid_dice_total += valid_dice

                valid_loss_mean = valid_loss_total / len(valid_loader)
                valid_dice_mean = valid_dice_total / len(valid_loader)
                valid_curve.append(valid_loss_mean)
                valid_dice_curve.append(valid_dice_mean)

                print("Valid:\t Epoch[{:0>3}/{:0>3}] mean_loss: {:.4f} dice_mean: {:.4f}".format(
                    epoch, max_epoch, valid_loss_mean, valid_dice_mean))

    # visualise predictions on the first `vis_num` validation images
    # eval mode is required here: if no validation ran, the net is still in train
    # mode and BatchNorm would use (and update) batch statistics.
    net.eval()
    with torch.no_grad():
        for idx, (inputs, labels) in enumerate(valid_loader):
            if idx >= vis_num:
                break
            if torch.cuda.is_available():
                inputs, labels = inputs.to(device), labels.to(device)
            outputs = net(inputs)

            mask_pred = outputs.ge(mask_thres).cpu().data.numpy().astype("uint8")

            img_hwc = inputs.cpu().data.numpy()[0, :, :, :].transpose((1, 2, 0)).astype("uint8")
            plt.subplot(121).imshow(img_hwc)
            mask_pred_gray = mask_pred.squeeze() * 255
            plt.subplot(122).imshow(mask_pred_gray, cmap="gray")
            plt.show()
            plt.pause(0.5)
            plt.close()

    # loss and dice curves
    _plot_curve(train_curve, valid_curve, len(train_loader), val_interval, max_epoch, 'loss value')
    _plot_curve(train_dice_curve, valid_dice_curve, len(train_loader), val_interval, max_epoch, 'dice value')

    torch.cuda.empty_cache()
训练一个epoch(max_epoch = 1),输出:
Training:Epoch[000/001] Iteration[001/212] running_loss: 0.2455, mean_loss: 0.2455 running_dice: 0.6275 lr:[0.01]
Training:Epoch[000/001] Iteration[002/212] running_loss: 0.2436, mean_loss: 0.2445 running_dice: 0.6337 lr:[0.01]
......
Training:Epoch[000/001] Iteration[210/212] running_loss: 0.0816, mean_loss: 0.1595 running_dice: 0.9295 lr:[0.01]
Training:Epoch[000/001] Iteration[211/212] running_loss: 0.1406, mean_loss: 0.1594 running_dice: 0.8416 lr:[0.01]
Training:Epoch[000/001] Iteration[212/212] running_loss: 0.1624, mean_loss: 0.1594 running_dice: 0.8296 lr:[0.01]
查看Unet结构。虽然简单,但很经典。
from collections import OrderedDict
import torch
import torch.nn as nn
class UNet(nn.Module):
    """U-Net with four encoder/decoder stages and skip connections.

    Each encoder stage is two (3x3 conv -> BatchNorm -> ReLU) layers followed by
    a 2x2 max-pool. Each decoder stage upsamples with a 2x2 transposed
    convolution, concatenates the matching encoder output along the channel
    axis, and applies two (3x3 conv -> BatchNorm -> ReLU) layers. A final
    1x1 convolution plus sigmoid yields values in (0, 1).

    :param in_channels: channels of the input
    :param out_channels: channels of the output map
    :param init_features: channels of the first stage; doubled at every
                          downsampling step (the bottleneck has 16x this)
    """

    def __init__(self, in_channels=3, out_channels=1, init_features=32):
        super(UNet, self).__init__()

        features = init_features
        self.encoder1 = UNet._block(in_channels, features, name="enc1")
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder2 = UNet._block(features, features * 2, name="enc2")
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder3 = UNet._block(features * 2, features * 4, name="enc3")
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder4 = UNet._block(features * 4, features * 8, name="enc4")
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.bottleneck = UNet._block(features * 8, features * 16, name="bottleneck")

        self.upconv4 = nn.ConvTranspose2d(
            features * 16, features * 8, kernel_size=2, stride=2
        )
        self.decoder4 = UNet._block((features * 8) * 2, features * 8, name="dec4")
        self.upconv3 = nn.ConvTranspose2d(
            features * 8, features * 4, kernel_size=2, stride=2
        )
        self.decoder3 = UNet._block((features * 4) * 2, features * 4, name="dec3")
        self.upconv2 = nn.ConvTranspose2d(
            features * 4, features * 2, kernel_size=2, stride=2
        )
        self.decoder2 = UNet._block((features * 2) * 2, features * 2, name="dec2")
        self.upconv1 = nn.ConvTranspose2d(
            features * 2, features, kernel_size=2, stride=2
        )
        self.decoder1 = UNet._block(features * 2, features, name="dec1")

        self.conv = nn.Conv2d(
            in_channels=features, out_channels=out_channels, kernel_size=1
        )

    def forward(self, x):
        """Map an (N, in_channels, H, W) batch to (N, out_channels, H, W) values in (0, 1).

        H and W need not be multiples of 16. When a pooling step floored an odd
        size, the upsampled decoder map is zero-padded to its skip connection's size.
        """
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(self.pool1(enc1))
        enc3 = self.encoder3(self.pool2(enc2))
        enc4 = self.encoder4(self.pool3(enc3))

        bottleneck = self.bottleneck(self.pool4(enc4))

        dec4 = self.decoder4(UNet._merge(self.upconv4(bottleneck), enc4))
        dec3 = self.decoder3(UNet._merge(self.upconv3(dec4), enc3))
        dec2 = self.decoder2(UNet._merge(self.upconv2(dec3), enc2))
        dec1 = self.decoder1(UNet._merge(self.upconv1(dec2), enc1))
        return torch.sigmoid(self.conv(dec1))

    @staticmethod
    def _merge(upsampled, skip):
        """Concatenate (upsampled, skip) on the channel axis, padding upsampled to skip's H/W."""
        diff_h = skip.size(2) - upsampled.size(2)
        diff_w = skip.size(3) - upsampled.size(3)
        if diff_h or diff_w:
            # pad spec for the last two dims: (left, right, top, bottom)
            upsampled = nn.functional.pad(
                upsampled,
                [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2],
            )
        return torch.cat((upsampled, skip), dim=1)

    @staticmethod
    def _block(in_channels, features, name):
        """Return two (3x3 conv -> BatchNorm -> ReLU) layers; submodule names are prefixed with ``name``."""
        return nn.Sequential(
            OrderedDict(
                [
                    (
                        name + "conv1",
                        nn.Conv2d(
                            in_channels=in_channels,
                            out_channels=features,
                            kernel_size=3,
                            padding=1,
                            bias=False,
                        ),
                    ),
                    (name + "norm1", nn.BatchNorm2d(num_features=features)),
                    (name + "relu1", nn.ReLU(inplace=True)),
                    (
                        name + "conv2",
                        nn.Conv2d(
                            in_channels=features,
                            out_channels=features,
                            kernel_size=3,
                            padding=1,
                            bias=False,
                        ),
                    ),
                    (name + "norm2", nn.BatchNorm2d(num_features=features)),
                    (name + "relu2", nn.ReLU(inplace=True)),
                ]
            )
        )
现在使用训练了400个epoch的权重进行测试:
(注意这里使用的是init_features=32,而上面训练代码中设置的是64;加载的权重必须是用相同init_features训练得到的,否则load_state_dict会因形状不匹配而报错)
import os
import time
import random
import torch.nn as nn
import torch
import numpy as np
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import DataLoader
from matplotlib import pyplot as plt
import torch.optim as optim
import torchvision.models as models
#from tools.common_tools import set_seed
from tools.my_dataset import PortraitDataset
from tools.unet import UNet
# Folder holding this script; data paths below are built relative to it.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
device = torch.device("cpu" if not torch.cuda.is_available() else "cuda")


def set_seed(seed=1):
    """Seed Python, NumPy and torch (CPU + current CUDA device) RNGs with ``seed``."""
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed)
    for seeder in seeders:
        seeder(seed)


set_seed()  # seed all random number generators
def compute_dice(y_pred, y_true):
    """Compute the Dice coefficient 2*|P∩T| / (|P| + |T|) of two masks.

    :param y_pred: array-like (ndarray or CPU tensor) with values in [0, 1];
                   values are rounded to {0, 1} before comparison
    :param y_true: array-like of the same shape, values in [0, 1]
    :return: float in [0, 1]; 1.0 when both masks are empty (they agree
             perfectly) instead of the NaN a 0/0 division would produce
    """
    y_pred = np.round(np.asarray(y_pred)).astype(int)
    y_true = np.round(np.asarray(y_true)).astype(int)
    denominator = np.sum(y_pred) + np.sum(y_true)
    if denominator == 0:
        # Neither mask has a foreground pixel: prediction equals ground truth.
        return 1.0
    return np.sum(y_pred[y_true == 1]) * 2.0 / denominator
def get_img_name(img_dir, format="jpg"):
    """Return the names of files in ``img_dir`` ending with ``format``.

    Files ending in ``matte.png`` are excluded. The result is sorted so
    that it does not depend on the file system's listing order; with a fixed
    random seed, a later ``random.shuffle`` of it is then reproducible.

    :param img_dir: str, directory to scan (not recursive)
    :param format: str, required file-name suffix, e.g. "png"
    :return: list of str, sorted file names (not full paths)
    :raises ValueError: if no matching file is found
    """
    img_names = sorted(
        name for name in os.listdir(img_dir)
        if name.endswith(format) and not name.endswith("matte.png")
    )
    if not img_names:
        raise ValueError("{}下找不到{}格式数据".format(img_dir, format))
    return img_names
def get_model(m_path):
    """Build a UNet (init_features=32) and load the weights stored in ``m_path``.

    The checkpoint must be a dict holding the weights under
    ``'model_state_dict'``. A leading ``'module.'`` on parameter names
    (e.g. added by ``nn.DataParallel``) is stripped so the keys match a bare UNet.
    NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    """
    from collections import OrderedDict

    model = UNet(in_channels=3, out_channels=1, init_features=32)
    # map_location="cpu" lets a checkpoint saved on GPU load on a CPU-only machine
    state = torch.load(m_path, map_location="cpu")
    prefix = 'module.'
    cleaned = OrderedDict(
        (name[len(prefix):] if name.startswith(prefix) else name, tensor)
        for name, tensor in state['model_state_dict'].items()
    )
    model.load_state_dict(cleaned)
    return model
if __name__ == "__main__":
    # Run the trained UNet on `num_infer` random validation images, show the
    # predicted masks and report per-image and average inference time.
    img_dir = os.path.join(BASE_DIR, "..", "..", "data", "PortraitDataset", "valid")
    model_path = "checkpoint_399_epoch.pkl"
    time_total = 0
    num_infer = 5
    mask_thres = .5

    # 1. data
    img_names = get_img_name(img_dir, format="png")
    random.shuffle(img_names)
    num_img = len(img_names)

    # 2. model
    unet = get_model(model_path)
    unet.to(device)
    unet.eval()

    num_run = 0
    for idx, img_name in enumerate(img_names):
        # stop after exactly `num_infer` images
        if idx >= num_infer:
            break

        path_img = os.path.join(img_dir, img_name)

        # step 1/4 : path --> img_chw (close the file handle once decoded)
        with Image.open(path_img) as img:
            img_hwc = img.convert('RGB').resize((224, 224))
        img_arr = np.array(img_hwc)
        img_chw = img_arr.transpose((2, 0, 1))

        # step 2/4 : img --> float tensor with batch dimension (raw 0-255 values)
        img_tensor = torch.tensor(img_chw).to(torch.float)
        img_tensor.unsqueeze_(0)
        img_tensor = img_tensor.to(device)

        # step 3/4 : tensor --> mask values; no_grad avoids building an autograd graph
        with torch.no_grad():
            time_tic = time.time()
            outputs = unet(img_tensor)
            # wait for asynchronous CUDA kernels so the timing is meaningful
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            time_toc = time.time()

        # step 4/4 : visualization
        mask_pred = outputs.ge(mask_thres).cpu().numpy().astype("uint8")
        img_hwc = img_tensor.cpu().numpy()[0, :, :, :].transpose((1, 2, 0)).astype("uint8")
        plt.subplot(121).imshow(img_hwc)
        mask_pred_gray = mask_pred.squeeze() * 255
        plt.subplot(122).imshow(mask_pred_gray, cmap="gray")
        plt.show()
        plt.close()

        time_s = time_toc - time_tic
        time_total += time_s
        num_run += 1

        print('{:d}/{:d}: {} {:.3f}s '.format(idx + 1, num_img, img_name, time_s))

    if num_run:
        print('average inference time: {:.3f}s'.format(time_total / num_run))
输出: