Guys, I'm on fire today: two posts in one day. The inference part is written, so here it is.
First, here's the link to the whole project:
Link: https://pan.baidu.com/s/1hwWF4-rfpiUBGUAZfWIgTQ
Extraction code: qlwo
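The project's file structure is as follows:
The model folder contains a model I've already trained. DeepLabV3+ models are fairly large (this one is about 210 MB), so I'd recommend a Baidu Netdisk membership for the download. (Baidu, while we're at it, feel free to send me some ad money, or just waive my membership fee.)
DATA contains the dataset used in this project, together with my annotations. If you're interested, download it and try running it yourself.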
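Here, take a look at the results:
The training data is fairly limited and my labeling was rather rough, so the predicted contours are not especially smooth. Please bear with me.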
5. Inference code
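```python
import torch
import cv2
import numpy as np
import torchvision.transforms as transform
from torch.nn import functional as F

import config  # project-local config module; provides train_img_size

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# HWC uint8 image -> CHW float tensor in [0, 1], then normalized to [-1, 1]
tf = transform.Compose([transform.ToTensor(), transform.Normalize([0.5], [0.5])])


class Inference:
    def __init__(self, inp_model, threshold, re_size, classes):
        self.inp_model = inp_model
        self.inp_model.eval()        # inference mode for BatchNorm/Dropout
        self.size = re_size          # network input size; cv2.resize expects (width, height)
        self.threshold = threshold   # a class must score above this to beat background
        self.classes = classes
        self.num_classes = len(classes)

    def __call__(self, img):
        masks_list = []
        img_size = img.shape[:2]     # (height, width) of the original image
        img = cv2.resize(img, tuple(self.size))
        img = tf(img).unsqueeze(0).to(device)
        with torch.no_grad():
            out = self.inp_model(img)
            # Resize the per-class scores back to the original image size
            out = F.interpolate(out, size=img_size, mode="bicubic", align_corners=False)
        h, w = out.shape[2:]
        # Constant "background" channel filled with the threshold
        back_matrix = torch.full((1, h, w), self.threshold, device=device)
        pr = torch.cat([out[0], back_matrix], dim=0)
        # Per-pixel label in 0..num_classes, where num_classes means background
        # (softmax does not change the per-pixel argmax)
        pr = F.softmax(pr.permute(1, 2, 0), dim=-1).cpu().numpy().argmax(axis=-1)
        for c in range(self.num_classes):
            mask = ((pr == c) * 255).astype(np.uint8)
            if mask.any():           # only keep classes that actually appear
                masks_list.append([mask, self.classes[c]])
        return masks_list


if __name__ == '__main__':
    # The checkpoint stores the whole model object, not just a state_dict.
    # On PyTorch >= 2.6 torch.load defaults to weights_only=True, so pass
    # weights_only=False there (only for checkpoint files you trust).
    net = torch.load(r"D:\blog_project\guligedong_segmentation\model\net.pth",
                     map_location=device).to(device)
    img = cv2.imread("sample.jpg")
    if img is None:
        raise FileNotFoundError("could not read sample.jpg")
    img_ori = img.copy()
    # BGR colors keyed by class name (eyes -> blue, mouth -> red), so a
    # missing class does not shift the colors of the others
    color_map = {"eyes": [255, 0, 0], "mouth": [0, 0, 255]}
    model = Inference(inp_model=net, threshold=0.95, re_size=config.train_img_size,
                      classes=["eyes", "mouth"])
    mask_list = model(img)
    for mask, name in mask_list:
        # [-2] picks the contour list under both OpenCV 3.x and 4.x signatures;
        # pass a copy because OpenCV 3.x modifies its input
        contours = cv2.findContours(mask.copy(), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2]
        for cnt in contours:
            if cv2.contourArea(cnt) > 100:
                # Outline large regions in green
                cv2.drawContours(img, [cnt], 0, (0, 255, 0), 4)
            else:
                # Treat tiny regions as noise: erase their bounding box from the mask
                x, y, w, h = cv2.boundingRect(cnt)
                mask[y:y + h, x:x + w] = 0
        # Paint the remaining mask pixels with the class color
        img[mask > 0] = color_map[name]
    cv2.imshow('seg', img)
    cv2.imshow('ori', img_ori)
    cv2.waitKey()
    cv2.destroyAllWindows()
```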
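The key post-processing step in `Inference.__call__` appends a constant channel equal to `threshold` to the network's scores and takes the per-pixel argmax, so a pixel becomes background unless some class scores above the threshold. Below is a minimal NumPy sketch of just that decision rule. It assumes, as the `for c in range(self.num_classes)` loop implies, that the network outputs one raw score per foreground class and no background channel of its own. `label_pixels` and the toy scores are hypothetical and only for illustration.

```python
import numpy as np


def label_pixels(scores: np.ndarray, threshold: float) -> np.ndarray:
    """Hypothetical helper mirroring the thresholded-argmax step.

    scores: (C, H, W) raw per-class scores, one channel per foreground class.
    Returns an (H, W) integer map: 0..C-1 for the classes, C for background.
    """
    _, h, w = scores.shape
    # Constant background channel: a class wins only where it beats the threshold
    back = np.full((1, h, w), threshold, dtype=scores.dtype)
    stacked = np.concatenate([scores, back], axis=0)
    return stacked.argmax(axis=0)


def softmax(x: np.ndarray, axis: int = 0) -> np.ndarray:
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)


# Toy scores for 2 classes ("eyes", "mouth") on a 1x3 image
scores = np.array([[[2.0, 0.1, 0.5]],    # eyes
                   [[0.3, 0.2, 1.5]]])   # mouth
labels = label_pixels(scores, threshold=0.95)
print(labels)  # [[0 2 1]] -> eyes, background, mouth

# Softmax preserves the order of scores within each pixel, so the argmax is unchanged
back = np.full((1, 1, 3), 0.95)
same = softmax(np.concatenate([scores, back]), axis=0).argmax(axis=0)
print((same == labels).all())  # True
```

Raising `threshold` shrinks the masks, and lowering it grows them. Because the threshold is compared with the raw network outputs, what a value like 0.95 means depends on how the model was trained.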
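Well, that basically wraps up the whole project. The segmentation here is only a simple demo. You really should write the code yourself, but of course the first step is to work through this code slowly until you understand it.
That's it. Salute!!!
One more Easter egg: 咩咩狗 (Miemie the dog), come on out!!!!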