前言:
2018年阿里的论文《Semantatic Human Matting》给出了抠图领域的一个新方法,可惜阿里并没有公布源码,而牛人在Github上对这个论文进行了复现,我也是依赖Github上的工程进行钻研,而在调试的过程中,发现有一些地方原作者并没有检验通过就上传,导致训练过程出错,这篇博客就是讲解如何调试通过Github上的Semantic_Human_Matting
工程的训练以及测试的代码。
–-----------------------------------------------------------------------------—--------------------------------------------
申明:
- 写博客的初衷一是为了记录,二也是为后来人填坑——测试效果的好坏受算法结构、受数据集、受训练次数等因素的影响,留言板处不要因为你的结果表现不优良而无视博主无偿付出、甚至恶评相向,这样的白嫖党我劝你善良。
–-----------------------------------------------------------------------------—--------------------------------------------
一、SHM网络简单讲解
通过下面Semantic_Human_Matting
网络图开始讲解SHM的网络设计:
SHM的网络大致分为三个部分:
T-Net
网络部分:这部分的作用主要是预测生成trimap图。网络的输入是原图 + mask图;M-Net
网络部分:这部分的作用主要是预测生成alpha图。网络的输入来源于三部分:第一个是原图(上图最左边的那张),第二个是原图对应的mask图(真正输入到网络中的mask图会被拆分成前景图 + 背景图两部分,也就是上图中的 F s F_s Fs和 B s B_s Bs),第三个是trimap图(真正输入到网络中的只要trimap图的不确定区域,也就是上图中的 U s U_s Us),预测得到上图中的 α r α_r αrFusion Module
:这部分的作用主要是融合得到精准的alpha图。最后精准 α p α_p αp遮罩图的概率估计是: α p = F s + U s α r α_p = F_s + U_sα_r αp=Fs+Usαr
–-----------------------------------------------------------------------------—--------------------------------------------
–-----------------------------------------------------------------------------—--------------------------------------------
二、SHM数据集调整说明
2.1、工程下载,以及环境配置
Github上的Semantic_Human_Matting
工程链接在此处此处此处,先下载解压;
根据工程主页上的说明,需要的是python3.5/3.6,torch>=0.4.0,以及opencv-python,我配置的机器环境是ubuntu16.04 + cuda10.0 + python3.6.12 + torch0.4.1 + opencv3.4.3。Windows机器我好像配置过,好像没通过(记不大清楚了,有兴趣的去试一试)
–-----------------------------------------------------------------------------—--------------------------------------------
2.2、下载数据集
2.2.1、最头痛的就是数据集的建立,因为建立大型数据集耗时耗力。所幸工程主页里作者给出了他找到的数据集,在这里对作者及爱分割公司表示感谢,数据集的链接在此处此处此处,密码是:dzsn,下载解压。
2.2.2、解压后可以看到其下主要包含两个文件夹:
- clip_img文件夹:其下都是原图;
- matting文件夹:其下都是原图对应的mask图,但是需要处理一下;
注意:整个数据集包含3W+张图片,预处理全部文件的话很耗时,所以在调试阶段博主强烈建议用其中某一个文件夹就行了。
注意:整个数据集包含3W+张图片,预处理全部文件的话很耗时,所以在调试阶段博主强烈建议用其中某一个文件夹就行了。
注意:整个数据集包含3W+张图片,预处理全部文件的话很耗时,所以在调试阶段博主强烈建议用其中某一个文件夹就行了。
2.2.3、在工程data
目录下新建matting
、clip_img
文件夹,再将数据集matting
、clip_img
文件夹下的挑选任意一个相同文件夹对应放入工程目录中,隶属关系如下:
–-----------------------------------------------------------------------------—--------------------------------------------
2.3、matting
图生成对应的mask
图:
先在data
文件夹下新建zcm_matting_get_mask.py
文件,代码如下,然后执行这个py文件,完成后可以在data
目录下看到生成了一个新的mask
文件夹,其下存储着黑白底的mask图。
import os
import cv2
matting_path = "matting/"
mask_path = "mask/"
# test
# for mask_name in os.listdir(matting_path):
# in_image = cv2.imread(matting_path + mask_name, cv2.IMREAD_UNCHANGED)
# alpha = in_image[:,:,3]
# cv2.imwrite(mask_path + mask_name, alpha)
for name_0 in os.listdir(matting_path):
if not os.path.exists(mask_path + "/" + name_0):
os.makedirs(mask_path + "/" + name_0)
for name_1 in os.listdir(matting_path + "/" + name_0):
if not os.path.exists(mask_path + name_0 + "/" + name_1):
os.mkdir(mask_path + name_0 + "/" + name_1)
for name_2 in os.listdir(matting_path + "/" + name_0 + "/" + name_1):
pic_input_path = matting_path + "/" + name_0 + "/" + name_1 + "/" + name_2
pic_output_path = mask_path + "/" + name_0 + "/" + name_1 + "/" + name_2
print("pic_input_path=", pic_input_path)
in_image = cv2.imread(pic_input_path, cv2.IMREAD_UNCHANGED)
alpha = in_image[:, :, 3]
cv2.imwrite(pic_output_path, alpha)
–-----------------------------------------------------------------------------—--------------------------------------------
2.4、生成训练数据的TXT目录:
先在data
文件夹下新建zcm_get_train_txt.py
文件,代码如下,然后执行这个py文件,完成后可以在data
目录下看到生成了一个新的train.txt
文件,打开里面存储图片的路径。
import os
pic_path = "matting/"
with open("train.txt", "w", encoding="UTF-8") as ff:
for name_0 in os.listdir(pic_path):
for name_1 in os.listdir(pic_path + "/" + name_0):
for name_2 in os.listdir(pic_path + "/" + name_0 + "/" + name_1):
pic_input_path = name_0 + "/" + name_1 + "/" + name_2
ff.write(pic_input_path + "\n")
ff.close()
print("well done____________!")
–-----------------------------------------------------------------------------—--------------------------------------------
2.5、由mask图生成trimap图:
2.5.1:像下面一样注释掉gen_trimap.py
第36/42/48行的断言语句;
# assert(cnt1 == cnt2 + cnt3)
2.5.2:在gen_trimap.py
第四行添加语句,引入os库;
import os
2.5.3:在gen_trimap.py
第64行后,添加如下代码;
trimap_name_1 = trimap_name.split("/")[:-1]
trimap_path = "/".join(trimap_name_1)
if not os.path.exists(trimap_path):
os.makedirs(trimap_path)
2.5.4:执行sh gen_trimap.sh
脚本,就可以生成得到trimap
文件夹,及其其下的trimap
图片;
–-----------------------------------------------------------------------------—--------------------------------------------
2.6、生成alpha图:
说明:这里给出两种生成alpha
图的方法:
- 用工程自带的
knn_matting.sh
脚本生成alpha图; - 直接拷贝
mask
文件夹,将mask图作为精确的alpha
图注入训练;
第一种方法我在简单测试中使用过,该方法非常非常非常的耗时间,而且用该方法处理爱分割公司提供的数据集得到了alpha
图,将其注入训练后,对最后的预测的准确率的影响并不大;有兴趣的朋友可以对knn_matting
继续改进,将时间效率提高;
我也阐述使用第二种方法的依据:因为爱分割公司的数据集的mask图是精确的,是直接通过matting
文件夹生成的。爱分割公司在提供数据集的时候,mask
图就是他们人工扣出来的。而knn_matting.sh
脚本存在的意义,是对于正常情况下,我们如使用faster-RCNN,DeepLab
这样的分割算法得到的mask
图是不精准的,才需要使用knn_matting
算法处理边界,得到精准的alpha
图。
所以这一步,在data
文件夹下新建alpha
文件夹后,再执行下面复制语句,将mask
文件夹下所有文件复制到alpha
文件夹;
cp -r mask/* alpha/
至此,数据集准备工作全部做完。
–-----------------------------------------------------------------------------—--------------------------------------------
–-----------------------------------------------------------------------------—--------------------------------------------
三、训练细节调整说明
3.1、写入训练code:
先在Semantic_Human_Matting工程目录下,新建train_code.txt
文件,写入如下指令:
# # T-Net训练指令
python3 train.py --dataDir='./data' --saveDir='./ckpt' --trainData='human_matting_data' --trainList='./data/train.txt' --lrdecayType='keep' --nEpochs=200 --save_epoch=1 --load='human_matting' --patch_size=320 --lr=1e-5 --gpus='0,1,2,3' --nThreads=24 --train_batch=48 --train_phase='pre_train_t_net'
# # M-Net训练指令
python3 train.py --dataDir='./data' --saveDir='./ckpt' --trainData='human_matting_data' --trainList='./data/train.txt' --lrdecayType='keep' --nEpochs=400 --save_epoch=1 --load='human_matting' --patch_size=320 --lr=5e-6 --gpus='0,1,2,3' --nThreads=24 --train_batch=48 --train_phase='end_to_end'
第一段是T-Net
训练代码,第二段是M-Net
训练代码
–-----------------------------------------------------------------------------—--------------------------------------------
3.2、修改train.py
文件:
在train.py
文件第29行后添加一条语句,用来指示GPU的使用情况
parser.add_argument('--gpus', default='0,1,2,3', help='gpus number')
–-----------------------------------------------------------------------------—--------------------------------------------
3.3、修改dataset.py
文件:
3.3.1:用如下语句替换dataset.py
文件第17/18/19行
image_name = os.path.join(data_dir, 'clip_img', file_name['image'].replace("matting", "clip").replace("png", "jpg"))
trimap_name = os.path.join(data_dir, 'trimap', file_name['trimap'].replace("clip", "matting"))
alpha_name = os.path.join(data_dir, 'alpha', file_name['alpha'].replace("clip", "matting"))
3.3.2:用如下语句替换dataset.py
文件第101/102/103行:
trimap[trimap == 0] = 0
trimap[trimap >= 250] = 2
trimap[np.where(~((trimap == 0) | (trimap == 2)))] = 1
这里是整个代码中错误最隐蔽的一个,当初也是花了我很长时间才搞定。我解释一下为什么这样做:我们知道trimap
图是三色图,但是它的“三色”并不像上图中0/128/255
只有这三色,它是在[0, 255]
这个区间范围内。所以新改的代码,将这“三色”用区间区分,作为三种不同的label传入训练。
–-----------------------------------------------------------------------------—--------------------------------------------
3.4、开启T-Net
训练:
运行train_code.txt
第一行代码,开启T-Net
训练,如果你报内存不足的错误,就适当调小patch_size,nThreads,train_batch
的数值;
python3 train.py --dataDir='./data' --saveDir='./ckpt' --trainData='human_matting_data' --trainList='./data/train.txt' --lrdecayType='keep' --nEpochs=200 --save_epoch=1 --load='human_matting' --patch_size=320 --lr=1e-5 --gpus='0,1,2,3' --nThreads=24 --train_batch=48 --train_phase='pre_train_t_net'
下图是我T-Net
训练过程的loss变化,你也可以为得到更好的结果而增大nEpochs训练轮数;
–-----------------------------------------------------------------------------—--------------------------------------------
3.5、开启M-Net
训练:
运行train_code.txt
第二行代码,开启M-Net
微调训练
python3 train.py --dataDir='./data' --saveDir='./ckpt' --trainData='human_matting_data' --trainList='./data/train.txt' --lrdecayType='keep' --nEpochs=400 --save_epoch=1 --load='human_matting' --patch_size=320 --lr=5e-6 --gpus='0,1,2,3' --nThreads=24 --train_batch=48 --train_phase='end_to_end'
下图是我M-Net
训练过程的loss变化,你也可以为得到更好的结果而增大nEpochs训练轮数;
–-----------------------------------------------------------------------------—--------------------------------------------
–-----------------------------------------------------------------------------—--------------------------------------------
四、测试细节调整说明
4.1:新建test_camera_used.py
文件
写入如下代码,代码与test_camera.py
文件很相似,只是改了一部分需求,让过程更简洁;
'''
test camera
Author: Zhengwei Li
Date : 2018/12/28
'''
import time
import cv2
import torch
import argparse
import numpy as np
import os
import torch.nn.functional as F
os.environ['CUDA_VISIBLE_DEVICES'] = '0, 1, 2, 3'
parser = argparse.ArgumentParser(description='human matting')
parser.add_argument('--model', default='./ckpt/human_matting/model/model_obj.pth', help='preTrained model')
parser.add_argument('--size', type=int, default=320, help='input size')
parser.add_argument('--without_gpu', action='store_true', default=False, help='no use gpu')
args = parser.parse_args()
torch.set_grad_enabled(False)
#################################
#----------------
if args.without_gpu:
print("use CPU !")
device = torch.device('cpu')
else:
if torch.cuda.is_available():
n_gpu = torch.cuda.device_count()
print("----------------------------------------------------------")
print("| use GPU ! || Available GPU number is {} ! |".format(n_gpu))
print("----------------------------------------------------------")
device = torch.device('cuda: 0, 1, 2, 3')
#################################
#---------------
def load_model(args):
print('Loading model from {}...'.format(args.model))
if args.without_gpu:
myModel = torch.load(args.model, map_location=lambda storage, loc: storage)
else:
myModel = torch.load(args.model)
myModel.eval()
myModel.to(device)
# myModel.cuda()
return myModel
def seg_process(args, image, net):
# opencv
origin_h, origin_w, c = image.shape
image_resize = cv2.resize(image, (args.size,args.size), interpolation=cv2.INTER_CUBIC)
image_resize = (image_resize - (104., 112., 121.,)) / 255.0
tensor_4D = torch.FloatTensor(1, 3, args.size, args.size)
tensor_4D[0,:,:,:] = torch.FloatTensor(image_resize.transpose(2,0,1))
inputs = tensor_4D.to(device)
trimap, alpha = net(inputs)
trimap_np = trimap[0, 0, :, :].cpu().data.numpy()
trimap_np = cv2.resize(trimap_np, (origin_w, origin_h), interpolation=cv2.INTER_CUBIC)
mask_result = np.multiply(trimap_np[..., np.newaxis], image)
trimap_1 = mask_result.copy()
mask_result[trimap_1 < 10] = 255
mask_result[trimap_1 >= 10] = 0
cv2.imwrite("mask_result.png", mask_result)
if args.without_gpu:
alpha_np = alpha[0,0,:,:].data.numpy()
else:
alpha_np = alpha[0,0,:,:].cpu().data.numpy()
alpha_np = cv2.resize(alpha_np, (origin_w, origin_h), interpolation=cv2.INTER_CUBIC)
fg = np.multiply(alpha_np[..., np.newaxis], image)
# cv2.imwrite("fg.png", fg)
# bg = image
# bg_gray = np.multiply(1 - alpha_np[..., np.newaxis], image)
# bg_gray = cv2.cvtColor(bg_gray, cv2.COLOR_BGR2GRAY)
# # print("bg_gray=", bg_gray)
# bg[:,:,0] = bg_gray
# bg[:,:,1] = bg_gray
# bg[:,:,2] = bg_gray
#
# # fg[fg<=0] = 0
# # fg[fg>255] = 255
# # fg = fg.astype(np.uint8)
# # out = cv2.addWeighted(fg, 0.7, bg, 0.3, 0)
#
# # out = fg + bg
# # out[out<0] = 0
# # out[out>255] = 255
# # out = out.astype(np.uint8)
#
# out = fg.copy()
# out[out<10] = 0
# out[out>=10] = 255
# out = out.astype(np.uint8)
return fg, mask_result
def camera_seg(args, net):
# videoCapture = cv2.VideoCapture(0)
#
# while(1):
# # get a frame
# ret, frame = videoCapture.read()
# frame = cv2.flip(frame,1)
# frame_seg = seg_process(args, frame, net)
#
#
# # show a frame
# cv2.imshow("capture", frame_seg)
#
# if cv2.waitKey(1) & 0xFF == ord('q'):
# break
# videoCapture.release()
test_pic_path = "test_pic/"
output_path = "result/"
if not os.path.exists(output_path):
os.mkdir(output_path)
time_0 = time.time()
for name_ in os.listdir(test_pic_path):
frame = cv2.imread(test_pic_path + name_)
fg, mask_result = seg_process(args, frame, net)
print("SUCCESS_____!", test_pic_path + name_)
cv2.imwrite(output_path + name_.split(".")[0] + "_fg.jpg", fg)
cv2.imwrite(output_path + name_, mask_result)
print("time_all = ", time.time() - time_0)
def main(args):
time_1 = time.time()
myModel = load_model(args)
print("lodding_model_time = ", time.time() - time_1)
camera_seg(args, myModel)
if __name__ == "__main__":
main(args)
4.2:测试过程
在主目录下新建test_pic
文件夹,将测试所用的pic图片存入其中后,运行test_camera_used.py
文件,就能在result
文件夹下得到预测的结果图。
–-----------------------------------------------------------------------------—--------------------------------------------
–-----------------------------------------------------------------------------—--------------------------------------------
五、最后的说明:
- 爱分割公司提供的数据集中,某一个目录中有一个没用的隐藏文件,如果不删除的话,数据准备过程、训练过程会报错——但是我忘了具体在哪个文件夹…
- 我训练了一个较好的model,所用的设备是具有4个GTX2080的显卡服务器跑了将近10天,用上了爱分割公司全部数据集 + 自建的一些数据集,因为公司的保密协议,我不能公布这个model,只展示我测试的结果。左边是预测生成图,右边是原图;
- 有问题欢迎留言垂询;