语义分割——SAM分割一切代码复现

前言

SAM——分割一切
SAM(Segment Anything Model)是一个处理图像分割任务的通用模型。与以往只能处理某种特定类型图片的图像分割模型不同,SAM可以处理各种类型的图像。相比于以往的图像分割模型,SAM可以接受点、框等多种输入提示(prompt)来确定图像中需要分割的内容,还可以灵活集成到虚拟现实/增强现实等其他系统中;对于一些它未见过或相对模糊的场景,也能实现较好的分割效果。

github地址:
https://github.com/facebookresearch/segment-anything

一、下载代码

git clone https://github.com/facebookresearch/segment-anything

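下载模型:从官方仓库 README 的 Model Checkpoints 部分下载权重文件,放到仓库根目录下的 ./checkpoints/ 目录中。下文代码用到 sam_vit_b_01ec64.pth(对应 model_type="vit_b")和 sam_vit_h_4b8939.pth(对应 model_type="vit_h"),权重文件必须与 model_type 对应。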

二、测试代码

测试脚本保存为仓库根目录下的 predict.py(作者的路径为 /home/diyun/work/python_project/segment-anything-main/predict.py),并在仓库根目录下运行,这样脚本中的 ./checkpoints/ 和 ./input/ 相对路径才能找到文件:


import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2

import sys
sys.path.append("..")
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor


def show_anns(anns):
    """Overlay every mask in ``anns`` on the current matplotlib axes.

    ``anns`` is the list returned by ``SamAutomaticMaskGenerator.generate``:
    each entry is a dict with at least ``'segmentation'`` (an H x W mask,
    used here as a boolean index) and ``'area'``.

    Masks are painted largest-first, so smaller masks end up on top of
    larger ones. Each mask gets a random RGB colour at 0.35 opacity;
    pixels covered by no mask stay fully transparent. An empty list
    draws nothing.
    """
    if len(anns) == 0:
        return

    # Largest area first -> small masks are drawn last and remain visible.
    by_area = sorted(anns, key=lambda entry: entry['area'], reverse=True)
    axes = plt.gca()
    # Keep the axes limits set by the underlying image.
    axes.set_autoscale_on(False)

    first_mask = by_area[0]['segmentation']
    height, width = first_mask.shape[:2]
    overlay = np.ones((height, width, 4))
    overlay[..., 3] = 0  # alpha channel: start fully transparent

    for entry in by_area:
        rgba = np.concatenate([np.random.random(3), [0.35]])
        overlay[entry['segmentation']] = rgba

    axes.imshow(overlay)

# --- Demo 1: automatically segment everything in one image ---------------
# Paths are relative to the repository root; run the script from there.
sam_checkpoint = "./checkpoints/sam_vit_b_01ec64.pth"
model_type = "vit_b"  # must match the checkpoint above

# Fall back to CPU when no CUDA device is present, instead of crashing
# inside sam.to("cuda").
device = "cuda" if torch.cuda.is_available() else "cpu"

# Number of prompt points the mask generator runs through the model at
# once (64 is the library default). Lower it, e.g. to 4, if you hit
# "CUDA out of memory"; this avoids editing
# segment_anything/automatic_mask_generator.py.
points_per_batch = 64

img_path = './input/1fd65f577cd8a38830e86cd855c3be1.jpg'
image = cv2.imread(img_path)
# cv2.imread returns None (no exception) when the file is missing or
# unreadable. Fail fast with a clear message, before spending time loading
# the model checkpoint.
if image is None:
    raise FileNotFoundError(f"Could not read image: {img_path}")
# OpenCV loads images as BGR; convert to RGB for SAM and for display.
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

mask_generator = SamAutomaticMaskGenerator(sam, points_per_batch=points_per_batch)

masks = mask_generator.generate(image)

# generate() returns a list with one dict per mask. Keys:
#   segmentation    - the mask
#   area            - the area of the mask in pixels
#   bbox            - the bounding box of the mask, XYWH format
#   predicted_iou   - the model's own prediction for the quality of the mask
#   point_coords    - the sampled input point that generated this mask
#   stability_score - an additional measure of mask quality
#   crop_box        - the image crop used to generate this mask, XYWH format

print(len(masks))
# Guard: masks[0] would raise IndexError if nothing was segmented.
if masks:
    print(masks[0].keys())

plt.figure(figsize=(20, 20))
plt.imshow(image)
show_anns(masks)
plt.axis('off')
plt.show()

效果图

(效果图1:原图片在转载时丢失)
(效果图2:原图片在转载时丢失)

可能出现的错误

CUDA out of memory.

  File "/home/diyun/anaconda3/envs/pytorch_gpu/lib/python3.8/site-packages/torch/nn/functional.py", line 3950, in interpolate
    return torch._C._nn.upsample_bilinear2d(input, output_size, align_corners, scale_factors)
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 5.93 GiB (GPU 0; 7.75 GiB total capacity; 1.17 GiB already allocated; 4.84 GiB free; 1.89 GiB reserved i

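points_per_batch 表示模型一次同时处理的采样点数量:值越大速度越快,但显存占用越高。显存不足时把它调小即可。

推荐做法是不修改库源码,直接在创建时传参:mask_generator = SamAutomaticMaskGenerator(sam, points_per_batch=4)

如果要修改源码,则在 segment_anything/automatic_mask_generator.py 文件中,把 SamAutomaticMaskGenerator.__init__ 的 points_per_batch 默认值改小一些,即把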

points_per_batch: int = 64,

改成

points_per_batch: int = 4,

玩法3——用矩形框+点共同确定要分割的物体



import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2

import sys
sys.path.append("..")
from segment_anything import sam_model_registry, SamPredictor

def show_mask(mask, ax, random_color=False):
    """Draw ``mask`` on ``ax`` as a translucent coloured overlay.

    Args:
        mask: Array whose last two dimensions are (H, W). Any leading
            dimensions must have size 1, because it is reshaped to (H, W, 1).
        ax: Target axes; only its ``imshow`` method is used.
        random_color: If True, use a random RGB colour. Otherwise use a
            fixed blue (RGB 30, 144, 255). Alpha is 0.6 in both cases.
    """
    alpha = np.array([0.6])
    if not random_color:
        rgba = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    else:
        rgba = np.concatenate([np.random.random(3), alpha], axis=0)

    rows, cols = mask.shape[-2:]
    # Broadcasting (H, W, 1) * (1, 1, 4) gives an RGBA image that is
    # transparent (all zeros) wherever the mask is 0/False.
    overlay = mask.reshape(rows, cols, 1) * rgba.reshape(1, 1, -1)
    ax.imshow(overlay)
    
def show_points(coords, labels, ax, marker_size=375):
    """Scatter prompt points on ``ax`` as star markers.

    Args:
        coords: (N, 2) array of (x, y) point coordinates.
        labels: (N,) array of point labels. Points labelled 1 are drawn
            green and points labelled 0 are drawn red. Any other label
            value is not drawn.
        ax: Target axes; only its ``scatter`` method is used.
        marker_size: Marker size forwarded to ``scatter`` as ``s``.
    """
    # Label 1 (green) is drawn first, then label 0 (red), so red stars sit on top.
    for label_value, colour in ((1, 'green'), (0, 'red')):
        selected = coords[labels == label_value]
        ax.scatter(
            selected[:, 0],
            selected[:, 1],
            color=colour,
            marker='*',
            s=marker_size,
            edgecolor='white',
            linewidth=1.25,
        )
    
def show_box(box, ax):
    """Outline the box ``(x_min, y_min, x_max, y_max)`` on ``ax`` in green.

    The box is in XYXY form: width and height are derived from the
    opposite corners. The rectangle has line width 2 and a fully
    transparent face.
    """
    x_min, y_min, x_max, y_max = box[0], box[1], box[2], box[3]
    outline = plt.Rectangle(
        (x_min, y_min),
        x_max - x_min,
        y_max - y_min,
        edgecolor='green',
        facecolor=(0, 0, 0, 0),
        lw=2,
    )
    ax.add_patch(outline)


# --- Demo 2: segment one object from a box prompt plus a point prompt ----
# Paths are relative to the repository root; run the script from there.
sam_checkpoint = "./checkpoints/sam_vit_h_4b8939.pth"
model_type = "vit_h"  # must match the checkpoint above

# Fall back to CPU when no CUDA device is present, instead of crashing
# inside sam.to("cuda").
device = "cuda" if torch.cuda.is_available() else "cpu"

img_path = './input/1fd65f577cd8a38830e86cd855c3be1.jpg'
image = cv2.imread(img_path)
# cv2.imread returns None (no exception) on a missing/unreadable file.
# Fail fast before loading the (large) vit_h checkpoint.
if image is None:
    raise FileNotFoundError(f"Could not read image: {img_path}")
# OpenCV loads images as BGR; convert to RGB for SAM and for display.
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

predictor = SamPredictor(sam)

# Preprocess the input image. set_image computes the image embedding that
# the predict() call below then reuses.
predictor.set_image(image)

# Prompts: one box in XYXY pixel coordinates (x_min, y_min, x_max, y_max)
# plus one point. Point labels: 1 = foreground, 0 = background. The label-0
# point here asks SAM to EXCLUDE the region around (575, 750) from the
# object inside the box.
input_box = np.array([35, 355, 1051, 1329])
input_point = np.array([[575, 750]])
input_label = np.array([0])

masks, _, _ = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    box=input_box,
    multimask_output=False,  # return a single mask for this prompt
)

plt.figure(figsize=(10, 10))
plt.imshow(image)
show_mask(masks[0], plt.gca())
show_box(input_box, plt.gca())
show_points(input_point, input_label, plt.gca())
plt.axis('off')
plt.show()


(矩形框+点的分割效果图:原图片在转载时丢失)

使用SAM自动标注

首先

参考:
1、SAM:Segment Anything 代码复现和测试 基本使用
2、【持续更新】Segment Anything Model (SAM)分割一切大模型相关论文和项目介绍

猜你喜欢

转载自blog.csdn.net/mao_hui_fei/article/details/137657105