Preface
SAM: Segment Anything
SAM is a general-purpose model for image segmentation. Unlike earlier segmentation models that only handle one specific kind of image, SAM is designed to work with images of any type. Compared with previous segmentation models, SAM accepts a variety of input prompts to specify what should be segmented, can be flexibly integrated into other systems such as virtual/augmented reality, and produces reasonably good segmentations even for scenes it has never seen before or that are relatively ambiguous.
GitHub repository:
https://github.com/facebookresearch/segment-anything
1. Download the code
git clone https://github.com/facebookresearch/segment-anything
Download the model checkpoints linked in the repository README (e.g. sam_vit_b_01ec64.pth for ViT-B and sam_vit_h_4b8939.pth for ViT-H) and place them under ./checkpoints/.
2. Test code
/home/diyun/work/python_project/segment-anything-main/predict.py
import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
import sys
sys.path.append("..")
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
# Overlay every mask on the current axes as a random translucent color
def show_anns(anns):
    if len(anns) == 0:
        return
    sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
    ax = plt.gca()
    ax.set_autoscale_on(False)
    img = np.ones((sorted_anns[0]['segmentation'].shape[0], sorted_anns[0]['segmentation'].shape[1], 4))
    img[:, :, 3] = 0
    for ann in sorted_anns:
        m = ann['segmentation']
        color_mask = np.concatenate([np.random.random(3), [0.35]])
        img[m] = color_mask
    ax.imshow(img)
# Load the ViT-B checkpoint downloaded earlier and move the model to the GPU
sam_checkpoint = "./checkpoints/sam_vit_b_01ec64.pth"
model_type = "vit_b"
device = "cuda"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

mask_generator = SamAutomaticMaskGenerator(sam)

# Read the test image and convert from OpenCV's BGR order to RGB
img_path = './input/1fd65f577cd8a38830e86cd855c3be1.jpg'
image = cv2.imread(img_path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

# Generate masks for the entire image automatically
masks = mask_generator.generate(image)
'''
Mask generation returns a list over masks, where each mask is a dictionary containing various data about the mask. These keys are:
* `segmentation` : the mask
* `area` : the area of the mask in pixels
* `bbox` : the boundary box of the mask in XYWH format
* `predicted_iou` : the model's own prediction for the quality of the mask
* `point_coords` : the sampled input point that generated this mask
* `stability_score` : an additional measure of mask quality
* `crop_box` : the crop of the image used to generate this mask in XYWH format
'''
print(len(masks))
print(masks[0].keys())
plt.figure(figsize=(20,20))
plt.imshow(image)
show_anns(masks)
plt.axis('off')
plt.show()
Result:
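Since each entry in masks is just a dictionary with the keys listed in the docstring above, the results are easy to post-process. Continuing from the script, here is a minimal sketch that prints each mask's area and bounding box and writes the mask out as a black-and-white PNG (the ./output_masks directory is my own choice):

import os

out_dir = './output_masks'           # hypothetical output directory
os.makedirs(out_dir, exist_ok=True)

for i, ann in enumerate(masks):
    m = ann['segmentation']          # boolean HxW array
    x, y, w, h = ann['bbox']         # bounding box in XYWH format
    print(f"mask {i}: area={ann['area']} px, bbox=({x}, {y}, {w}, {h})")
    # Save the mask as a binary (0/255) image
    cv2.imwrite(os.path.join(out_dir, f'mask_{i}.png'), m.astype(np.uint8) * 255)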
Possible error: CUDA out of memory
File "/home/diyun/anaconda3/envs/pytorch_gpu/lib/python3.8/site-packages/torch/nn/functional.py", line 3950, in interpolate
return torch._C._nn.upsample_bilinear2d(input, output_size, align_corners, scale_factors)
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 5.93 GiB (GPU 0; 7.75 GiB total capacity; 1.17 GiB already allocated; 4.84 GiB free; 1.89 GiB reserved i
Fix: in segment_anything/automatic_mask_generator.py, reduce the value of points_per_batch, e.g. change
points_per_batch: int = 64,
to
points_per_batch: int = 4,
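Alternatively, since points_per_batch is an argument of the SamAutomaticMaskGenerator constructor, you can pass a smaller value when building the generator instead of editing the library source (a smaller batch uses less GPU memory at the cost of speed). A minimal sketch:

# Sample fewer points per batch to reduce peak GPU memory usage
mask_generator = SamAutomaticMaskGenerator(sam, points_per_batch=4)
masks = mask_generator.generate(image)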
Method 3: using a bounding box + a point to specify the object to segment
import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
import sys
sys.path.append("..")
from segment_anything import sam_model_registry, SamPredictor
def show_mask(mask, ax, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)

def show_points(coords, labels, ax, marker_size=375):
    pos_points = coords[labels==1]
    neg_points = coords[labels==0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)

def show_box(box, ax):
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))
sam_checkpoint = "./checkpoints/sam_vit_h_4b8939.pth"
model_type = "vit_h"
device = "cuda"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
predictor = SamPredictor(sam)
img_path = './input/1fd65f577cd8a38830e86cd855c3be1.jpg'
image = cv2.imread(img_path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# Preprocess the input image (compute the image embedding once)
predictor.set_image(image)

# Prompts: a box plus a point (label 1 = foreground point, 0 = background point)
input_box = np.array([35, 355, 1051, 1329])
input_point = np.array([[575, 750]])
input_label = np.array([0])

masks, _, _ = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    box=input_box,
    multimask_output=False,
)
plt.figure(figsize=(10, 10))
plt.imshow(image)
show_mask(masks[0], plt.gca())
show_box(input_box, plt.gca())
show_points(input_point, input_label, plt.gca())
plt.axis('off')
plt.show()
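When a prompt is ambiguous (for example a single point), the predictor can also return several candidate masks along with its own quality estimates. A minimal sketch, continuing from the predictor set up above, that requests multiple masks for a foreground point and keeps the highest-scoring one:

# Request several candidate masks and keep the one SAM scores highest
masks, scores, logits = predictor.predict(
    point_coords=input_point,
    point_labels=np.array([1]),   # 1 = treat the point as foreground
    multimask_output=True,        # return multiple candidate masks
)
best = int(np.argmax(scores))
print(f"kept mask {best} with predicted quality {scores[best]:.3f}")

plt.figure(figsize=(10, 10))
plt.imshow(image)
show_mask(masks[best], plt.gca())
show_points(input_point, np.array([1]), plt.gca())
plt.axis('off')
plt.show()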
Using SAM for automatic annotation
First,
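The original post stops here. Purely as an illustration of one possible direction (the JSON layout and file name below are my own assumptions, not part of the post), the masks produced by SamAutomaticMaskGenerator can be dumped as simple bounding-box annotations:

import json

# Hypothetical sketch: convert automatic masks into simple XYWH box annotations
annotations = [
    {'id': i, 'bbox': [int(v) for v in ann['bbox']], 'area': int(ann['area'])}
    for i, ann in enumerate(masks)
]
with open('annotations.json', 'w') as f:
    json.dump({'image': img_path, 'annotations': annotations}, f, indent=2)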