3、Segment Anything

GitHub 仓库：https://github.com/facebookresearch/segment-anything

创建anaconda环境

conda create -n ASM python=3.8

安装依赖包

# pytorch>=1.7 and torchvision>=0.8
conda install pytorch==1.11.0 torchvision==0.12.0 torchaudio==0.11.0 cudatoolkit=11.3 -c pytorch

pip install git+https://github.com/facebookresearch/segment-anything.git
pip install opencv-python pycocotools matplotlib onnxruntime onnx

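预训练权重（checkpoint 必须与 model_type 对应）：
default 或 vit_h：ViT-H SAM model（sam_vit_h_4b8939.pth）
vit_l：ViT-L SAM model（sam_vit_l_0b3195.pth）
vit_b：ViT-B SAM model（sam_vit_b_01ec64.pth）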

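示例（example）
官方仓库 notebooks 目录下有详细示例：predictor_example.ipynb（点/框提示分割）
automatic_mask_generator_example.ipynb（Automatically generating object masks，自动生成全图掩码）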

代码使用

工具方法

读取图片

def read_image(path="./data/000.png"):
    """Load an image file and return it in RGB channel order.

    OpenCV decodes images as BGR, so the result is converted to RGB, which is
    the order ``SamPredictor.set_image`` expects by default and the order
    matplotlib displays correctly.

    Args:
        path: path of the image file to read.

    Returns:
        The decoded image as a numpy array in RGB order.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``path`` (missing file,
            unreadable file, or unsupported format).
    """
    image = cv2.imread(path)
    # cv2.imread does not raise on failure -- it returns None, which would
    # otherwise surface as a confusing assertion error inside cvtColor.
    if image is None:
        raise FileNotFoundError(f"could not read image: {path}")
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

展示标记框

def show_box(box, ax):
    """Outline an ``[x0, y0, x1, y1]`` box in green on the given axes.

    Only the first four entries of ``box`` are read. The rectangle's face is
    fully transparent so the image underneath stays visible.
    """
    left, top, right, bottom = box[0], box[1], box[2], box[3]
    outline = plt.Rectangle(
        (left, top),
        right - left,
        bottom - top,
        edgecolor='green',
        facecolor=(0, 0, 0, 0),
        lw=2,
    )
    ax.add_patch(outline)

展示标记点

def show_points(coords, labels, ax, marker_size=375):
    """Plot prompt points as stars: foreground in green, background in red.

    Args:
        coords: ``N x 2`` array-like of ``(x, y)`` pixel coordinates.
        labels: length-``N`` array-like; ``1`` marks a foreground point and
            ``0`` a background point. Points with any other label are not drawn.
        ax: matplotlib axes to draw on.
        marker_size: marker area, passed to ``scatter`` as ``s``.
    """
    # np.asarray lets callers pass plain Python lists: for a list,
    # ``labels == 1`` is a single bool instead of an element-wise mask.
    coords = np.asarray(coords)
    labels = np.asarray(labels)
    # Foreground is drawn first, then background (same order as before).
    for label_value, color in ((1, 'green'), (0, 'red')):
        selected = coords[labels == label_value]
        ax.scatter(selected[:, 0], selected[:, 1], color=color, marker='*',
                   s=marker_size, edgecolor='white', linewidth=1.25)

展示掩膜

def show_mask(mask, ax, random_color=False):
    """Overlay a binary mask on ``ax`` as a semi-transparent color layer.

    The last two dimensions of ``mask`` are taken as height and width, so a
    leading singleton dimension (e.g. ``1 x H x W``) is accepted.

    Args:
        mask: mask array whose last two dimensions are ``(H, W)``.
        ax: matplotlib axes to draw on.
        random_color: if True, use a random RGB color; otherwise use a fixed
            blue. Alpha is 0.6 in both cases.
    """
    if not random_color:
        rgba = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    else:
        rgba = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    height, width = mask.shape[-2:]
    # Broadcast (H, W, 1) against (1, 1, 4) to get an RGBA image that is
    # colored where the mask is set and transparent elsewhere.
    overlay = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)
    ax.imshow(overlay)

简单使用

main方法

if __name__ == '__main__':
    # NOTE: init, read_image and the predict_* helpers must already be
    # defined when this guard runs, so place it at the end of the script.
    # Build the SAM model and wrap it in a prompt-based predictor.
    sam = init()
    predictor = SamPredictor(sam)
    # Load the demo image (RGB).
    image = read_image("./data/000.png")
    # Attach the image to the predictor; required before calling predict.
    predictor.set_image(image)
    # Run one demo. predict_dir / predict_box_point / predict_boxs take the
    # same (image, predictor) arguments and can be swapped in here.
    predict_box(image, predictor)

加载模型

def init(model_type="vit_h", sam_checkpoint="/devdata/chengan/SAM_checkpoint/sam_vit_h_4b8939.pth", device="cuda"):
    """Build a SAM model from a checkpoint and place it on ``device``.

    Args:
        model_type: key into ``sam_model_registry`` ("default"/"vit_h",
            "vit_l" or "vit_b"); it must match the checkpoint file.
        sam_checkpoint: path to the ``.pth`` weights. The default is a path
            on the original author's machine -- point it at your own download.
        device: torch device string, e.g. "cuda" or "cpu".

    Returns:
        The SAM model, already moved to ``device``.
    """
    build_model = sam_model_registry[model_type]
    model = build_model(checkpoint=sam_checkpoint)
    # nn.Module.to moves the parameters in place and returns the module itself.
    return model.to(device=device)

点提示分割（point prompt；SAM 输出的是不带类别的掩码，并非语义分割）

def sample_use(image, predictor, input_points=None, input_labels=None):
    """Segment from point prompts and show every candidate mask SAM returns.

    NOTE: a second function named ``sample_use`` (automatic mask generation)
    appears later in this article. If both are pasted into one module, the
    later definition replaces this one.

    Args:
        image: RGB image that was passed to ``predictor.set_image``.
        predictor: a ``SamPredictor`` already bound to ``image``.
        input_points: ``N x 2`` array of ``(x, y)`` pixel prompts. Defaults
            to the single point ``(300, 300)``.
        input_labels: length-``N`` array; 1 = foreground point,
            0 = background point. Defaults to all foreground.
    """
    if input_points is None:
        input_points = np.array([[300, 300]])
    input_points = np.asarray(input_points)
    if input_labels is None:
        input_labels = np.ones(len(input_points), dtype=int)
    input_labels = np.asarray(input_labels)

    # multimask_output=True asks for several candidate masks for an ambiguous
    # prompt, each with a score. The low-resolution logits are not used here.
    masks, scores, _low_res_logits = predictor.predict(
        point_coords=input_points,
        point_labels=input_labels,
        multimask_output=True
    )

    for i, (mask, score) in enumerate(zip(masks, scores)):
        plt.figure(figsize=(10, 10))
        plt.imshow(image)
        show_mask(mask, plt.gca())
        show_points(input_points, input_labels, plt.gca())
        # Kept on one line: a single-quoted f-string may not span lines
        # before Python 3.12.
        plt.title(f"Mask {i + 1}, Score: {score:.3f}", fontsize=18)
        plt.axis('off')
        plt.show()
    print("\nmask shape", masks.shape)

点提示迭代分割（将上一轮得分最高掩码的 logits 作为 mask_input 再次预测）

def predict_dir(image, predictor):
    """Two-pass point segmentation that refines SAM's best first guess.

    The first pass requests several candidate masks for one foreground point.
    The second pass repeats the same prompt, feeds back the logits of the
    highest-scoring candidate as ``mask_input``, and requests a single mask,
    which is then displayed.
    """
    prompt_xy = np.array([[300, 300]])
    prompt_label = np.array([1])  # 1 = foreground point, 0 = background point

    # First pass: multiple candidates with scores.
    _, first_scores, first_logits = predictor.predict(
        point_coords=prompt_xy,
        point_labels=prompt_label,
        multimask_output=True,
    )
    best_logits = first_logits[np.argmax(first_scores)]

    # Second pass: same prompt plus the best candidate's logits (with a
    # leading batch axis added).
    refined_masks, _, _ = predictor.predict(
        point_coords=prompt_xy,
        point_labels=prompt_label,
        mask_input=best_logits[np.newaxis, ...],
        multimask_output=False,
    )

    plt.figure(figsize=(10, 10))
    plt.imshow(image)
    show_mask(refined_masks, plt.gca())
    show_points(prompt_xy, prompt_label, plt.gca())
    plt.axis('off')
    plt.show()

框（box）提示分割

def predict_box(image, predictor):
    """Segment the object inside one box prompt and display the result."""
    box_xyxy = np.array([425, 600, 600, 700])  # x0, y0, x1, y1 in pixels
    predicted, _, _ = predictor.predict(
        point_coords=None,
        point_labels=None,
        box=box_xyxy[np.newaxis, :],
        multimask_output=False,
    )

    plt.figure(figsize=(10, 10))
    plt.imshow(image)
    axes = plt.gca()
    show_mask(predicted[0], axes)
    show_box(box_xyxy, axes)
    plt.axis('off')
    plt.show()

点 + 框联合提示分割

def predict_box_point(image, predictor):
    """Combine a box prompt with one background point and display the mask.

    NOTE(review): the point (575, 750) has y = 750, below the box's bottom
    edge y1 = 700, so it lies outside the box. Confirm these coordinates are
    intended for the demo image.
    """
    box_xyxy = np.array([425, 600, 700, 700])  # x0, y0, x1, y1 in pixels
    neg_point = np.array([[575, 750]])
    neg_label = np.array([0])  # 0 = background point

    predicted, _, _ = predictor.predict(
        point_coords=neg_point,
        point_labels=neg_label,
        box=box_xyxy,
        multimask_output=False,
    )

    plt.figure(figsize=(10, 10))
    plt.imshow(image)
    axes = plt.gca()
    show_mask(predicted[0], axes)
    show_box(box_xyxy, axes)
    show_points(neg_point, neg_label, axes)
    plt.axis('off')
    plt.show()

多个框（一次前向推理处理多个 box 提示）

def predict_boxs(image, predictor):
    """Segment several box prompts in one batched call and overlay them all."""
    boxes = torch.tensor(
        [
            [75, 275, 725, 750],
            [425, 600, 700, 775],
            [375, 550, 650, 700],
            [240, 675, 400, 750],
        ],
        device=predictor.device,
    )
    # predict_torch takes boxes already mapped into the model's input frame.
    boxes_in_model_frame = predictor.transform.apply_boxes_torch(boxes, image.shape[:2])
    masks, _, _ = predictor.predict_torch(
        point_coords=None,
        point_labels=None,
        boxes=boxes_in_model_frame,
        multimask_output=False,
    )
    # (batch_size) x (num_predicted_masks_per_input) x H x W
    print(masks.shape)

    plt.figure(figsize=(10, 10))
    plt.imshow(image)
    axes = plt.gca()
    for single_mask in masks:
        show_mask(single_mask.cpu().numpy(), axes, random_color=True)
    for single_box in boxes:
        show_box(single_box.cpu().numpy(), axes)
    plt.axis('off')
    plt.show()

多张图片批量推理（batch all images）

def predict_batch(images, sam, boxes_per_image=None):
    """Run SAM on several images in one forward pass and plot each result.

    Args:
        images: sequence of RGB numpy images (e.g. from ``read_image``).
        sam: a SAM model (as returned by ``init``).
        boxes_per_image: optional list with one entry per image, each a list
            of ``[x0, y0, x1, y1]`` boxes in that image's original pixel
            coordinates. Defaults to the demo boxes for the first two images.
            Images beyond ``len(boxes_per_image)`` are ignored.

    Raises:
        ValueError: if there is no (image, boxes) pair to process.

    Each entry of ``batched_input`` holds:
        image: the resized image as a CHW torch tensor.
        boxes: box prompts mapped into the resized frame.
        original_size: ``(H, W)`` of the image before resizing.
    Each output dict holds ``masks`` (binary masks at the original image
    size), ``iou_predictions`` (predicted mask quality) and
    ``low_res_logits`` (can be fed back as mask input in a later call).
    """
    if boxes_per_image is None:
        boxes_per_image = [
            [[75, 275, 725, 750],
             [425, 600, 700, 775],
             [375, 550, 650, 800],
             [240, 675, 400, 750]],
            [[450, 170, 520, 350],
             [350, 190, 450, 350],
             [500, 170, 580, 350],
             [580, 170, 640, 350]],
        ]
    images = list(images)

    resize_transform = ResizeLongestSide(sam.image_encoder.img_size)

    def _prepare_image(image):
        # Resize so the longest side matches the encoder input, then HWC -> CHW.
        resized = resize_transform.apply_image(image)
        resized = torch.as_tensor(resized, device=sam.device)
        return resized.permute(2, 0, 1).contiguous()

    batched_input = []
    box_tensors = []
    for image, boxes in zip(images, boxes_per_image):
        boxes = torch.as_tensor(boxes, device=sam.device)
        box_tensors.append(boxes)
        # Take (H, W) from the untouched numpy image. The original code
        # overwrote the image with its resized CHW tensor first, so
        # shape[:2] became (3, H') and both original_size and the box
        # mapping were wrong.
        original_size = image.shape[:2]
        batched_input.append({
            'image': _prepare_image(image),
            'boxes': resize_transform.apply_boxes_torch(boxes, original_size),
            'original_size': original_size,
        })
    if not batched_input:
        raise ValueError("predict_batch needs at least one image with boxes")

    batched_output = sam(batched_input, multimask_output=False)
    print(batched_output[0].keys())

    fig, axes = plt.subplots(1, len(batched_input), figsize=(20, 20), squeeze=False)
    # Draw onto the original numpy images, not the resized tensors.
    for ax, image, boxes, output in zip(axes[0], images, box_tensors, batched_output):
        ax.imshow(image)
        for mask in output['masks']:
            show_mask(mask.cpu().numpy(), ax, random_color=True)
        for box in boxes:
            show_box(box.cpu().numpy(), ax)
        ax.axis('off')

    plt.tight_layout()
    plt.show()

自动掩码生成（全图分割，SamAutomaticMaskGenerator）

自动分割结果可视化

def show_anns(anns):
    """Overlay automatic-mask-generator annotations on the current axes.

    Each annotation is painted with a random RGB color at alpha 0.35.
    Annotations are painted largest ``area`` first, so smaller masks are
    drawn last and stay visible over larger ones. An empty list draws nothing.
    """
    if len(anns) == 0:
        return
    by_area = sorted(anns, key=lambda a: a['area'], reverse=True)
    axes = plt.gca()
    axes.set_autoscale_on(False)

    first_seg = by_area[0]['segmentation']
    overlay = np.ones((first_seg.shape[0], first_seg.shape[1], 4))
    overlay[:, :, 3] = 0  # start fully transparent
    for ann in by_area:
        overlay[ann['segmentation']] = np.concatenate([np.random.random(3), [0.35]])
    axes.imshow(overlay)

默认方法

def sample_use(image, sam):
    """Run SamAutomaticMaskGenerator with default settings and plot every mask.

    NOTE: an earlier function in this article is also named ``sample_use``
    (point prompts). If both are pasted into one module, this later
    definition replaces that one.

    Args:
        image: RGB image (e.g. from ``read_image``).
        sam: a SAM model (as returned by ``init``).
    """
    mask_generator = SamAutomaticMaskGenerator(sam)
    masks = mask_generator.generate(image)
    plt.figure(figsize=(20, 20))
    plt.imshow(image)
    show_anns(masks)
    plt.axis('off')
    plt.show()
    print(len(masks))
    # If no mask survives the generator's filters the list is empty and
    # masks[0] would raise IndexError.
    if masks:
        print(masks[0].keys())

调整输入参数

def improved_use(image, sam, **generator_overrides):
    """Run automatic mask generation with tuned settings and plot the masks.

    Args:
        image: RGB image (e.g. from ``read_image``).
        sam: a SAM model (as returned by ``init``).
        **generator_overrides: optional ``SamAutomaticMaskGenerator`` keyword
            arguments that replace the tuned defaults below, e.g.
            ``points_per_side=16`` for a faster, coarser run.
    """
    settings = {
        # Number of prompt points sampled along each side of the image.
        'points_per_side': 32,
        # Drop masks whose predicted quality is below this threshold.
        'pred_iou_thresh': 0.86,
        # Drop masks that change a lot when the binarization cutoff moves.
        'stability_score_thresh': 0.92,
        # Also run prediction on one layer of image crops.
        'crop_n_layers': 1,
        # Use fewer grid points per side in the crop layers.
        'crop_n_points_downscale_factor': 2,
        'min_mask_region_area': 100,  # Requires open-cv to run post-processing
    }
    settings.update(generator_overrides)
    mask_generator = SamAutomaticMaskGenerator(model=sam, **settings)
    masks = mask_generator.generate(image)
    plt.figure(figsize=(20, 20))
    plt.imshow(image)
    show_anns(masks)
    plt.axis('off')
    plt.show()
    print(len(masks))
    # Stricter thresholds can filter out every mask; masks[0] would raise
    # IndexError on an empty list.
    if masks:
        print(masks[0].keys())

（原文此处为运行结果图片，转载时图片已缺失）

猜你喜欢

转载自blog.csdn.net/weixin_50973728/article/details/134654537