Object masks from prompts with SAM¶
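The Segment Anything Model (SAM) predicts object masks from prompts that indicate the desired object. The model first converts the image into an image embedding, which allows high-quality masks to be produced efficiently from a prompt.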
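The SamPredictor class provides a simple interface for prompting the model. The user first sets an image with the set_image method, which computes the necessary image embedding. Prompts can then be passed to the predict method to efficiently predict masks from them. The model accepts point prompts, box prompts, and masks from a previous prediction iteration as input.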
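The prompt arguments used throughout this notebook have fixed shapes: points are an (N, 2) array of (x, y) pixel coordinates with one 0/1 label each, a box is four xyxy values, and a mask prompt is a (1, 256, 256) array of low-resolution logits. The sketch below checks those shapes before calling predict, so mistakes surface with a readable message. check_prompt_shapes is a hypothetical helper, not part of segment_anything, and it only checks the shapes as used in this notebook.

```python
import numpy as np

def check_prompt_shapes(point_coords=None, point_labels=None, box=None, mask_input=None):
    """Hypothetical helper: validate prompt shapes as they are used in this notebook.

    point_coords: (N, 2) array of (x, y) pixel coordinates
    point_labels: (N,) array, 1 = foreground, 0 = background
    box:          4 values in xyxy format, shape (4,) or (1, 4)
    mask_input:   (1, 256, 256) low-resolution logits from a previous call
    """
    if (point_coords is None) != (point_labels is None):
        raise ValueError("point_coords and point_labels must be given together")
    if point_coords is not None:
        point_coords = np.asarray(point_coords)
        point_labels = np.asarray(point_labels)
        if point_coords.ndim != 2 or point_coords.shape[1] != 2:
            raise ValueError(f"point_coords must be (N, 2), got {point_coords.shape}")
        if point_labels.shape != (point_coords.shape[0],):
            raise ValueError("point_labels needs exactly one label per point")
        if not np.isin(point_labels, [0, 1]).all():
            raise ValueError("point_labels must contain only 0 or 1")
    if box is not None and np.asarray(box).reshape(-1).shape != (4,):
        raise ValueError("box must contain exactly 4 values (x0, y0, x1, y1)")
    if mask_input is not None and np.asarray(mask_input).shape != (1, 256, 256):
        raise ValueError("mask_input must have shape (1, 256, 256)")
    return True
```

For example, check_prompt_shapes(input_point, input_label) can be run right before predictor.predict(...) in the cells below.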
from IPython.display import display, HTML
display(HTML(
"""
<a target="_blank" href="https://colab.research.google.com/github/EanYang7/segment-anything/blob/main/docs/notebooks/predictor_example.ipynb">
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>
"""
))
Environment Set-up¶
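If running locally with Jupyter, first install segment_anything in your environment using the installation instructions in the repository. If running on Google Colab, set using_colab=True below and run the cell. In Colab, be sure to select "GPU" under "Edit" -> "Notebook Settings" -> "Hardware accelerator".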
using_colab = False
if using_colab:
    import torch
    import torchvision
    print("PyTorch version:", torch.__version__)
    print("Torchvision version:", torchvision.__version__)
    print("CUDA is available:", torch.cuda.is_available())
    import sys
    !{sys.executable} -m pip install opencv-python matplotlib
    !{sys.executable} -m pip install 'git+https://github.com/facebookresearch/segment-anything.git'
    !mkdir images
    !wget -P images https://raw.githubusercontent.com/facebookresearch/segment-anything/main/notebooks/images/truck.jpg
    !wget -P images https://raw.githubusercontent.com/facebookresearch/segment-anything/main/notebooks/images/groceries.jpg
    !wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
Set-up¶
Necessary imports and helper functions for displaying points, boxes, and masks.
import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
def show_mask(mask, ax, random_color=False):
    # Draw the mask as a semi-transparent RGBA overlay (alpha 0.6)
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)

def show_points(coords, labels, ax, marker_size=375):
    # Foreground points (label 1) as green stars, background points (label 0) as red stars
    pos_points = coords[labels==1]
    neg_points = coords[labels==0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)

def show_box(box, ax):
    # box is in xyxy format: (x0, y0, x1, y1)
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))
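show_mask only draws on a Matplotlib axis. To save a result as an image file, the same overlay can be burned into the pixel array with plain alpha blending: each masked pixel becomes (1 - alpha) * pixel + alpha * color. blend_mask is a hypothetical helper. It reuses show_mask's default blue and 0.6 opacity, and assumes an RGB uint8 image such as the one loaded below.

```python
import numpy as np

def blend_mask(image, mask, color=(30, 144, 255), alpha=0.6):
    """Hypothetical helper: return a copy of an RGB uint8 image with the mask blended in."""
    out = image.astype(np.float32)
    m = np.asarray(mask).reshape(image.shape[:2]).astype(bool)  # accepts (H, W) or (1, H, W)
    color = np.asarray(color, dtype=np.float32)
    # Alpha-blend only the masked pixels; unmasked pixels are left unchanged
    out[m] = (1.0 - alpha) * out[m] + alpha * color
    return np.rint(out).clip(0, 255).astype(np.uint8)
```

After a prediction, something like cv2.imwrite("truck_mask.png", cv2.cvtColor(blend_mask(image, masks[0]), cv2.COLOR_RGB2BGR)) would write the overlay to disk. The conversion back to BGR is needed because OpenCV writes BGR.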
Example image¶
image = cv2.imread('images/truck.jpg')
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
plt.figure(figsize=(10,10))
plt.imshow(image)
plt.axis('on')
plt.show()
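Selecting objects with SAM¶
First, load the SAM model and predictor. Change the path below to point to the SAM checkpoint. Running on CUDA and using the default model are recommended for best results.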
import sys
# sys.path.append("..")
from segment_anything import sam_model_registry, SamPredictor
# Adjust this path to wherever the checkpoint is stored; the Colab set-up cell
# downloads sam_vit_h_4b8939.pth into the current working directory.
sam_checkpoint = "../../checkpoints/sam_vit_h_4b8939.pth"
model_type = "vit_h"
device = "cuda"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
predictor = SamPredictor(sam)
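sam_model_registry is keyed by model type ("default", "vit_h", "vit_l", "vit_b"), and each key needs its matching checkpoint file. Below is a minimal sketch that keeps the two in sync. The file names are the ones listed in the segment-anything README at the time of writing; treat them as assumptions and check them against the files you actually download. checkpoint_path is a hypothetical helper.

```python
import os

# Assumed checkpoint file names per registry key (verify against your downloads)
CHECKPOINTS = {
    "vit_h": "sam_vit_h_4b8939.pth",  # the "default" registry entry is also ViT-H
    "vit_l": "sam_vit_l_0b3195.pth",
    "vit_b": "sam_vit_b_01ec64.pth",
}

def checkpoint_path(model_type, checkpoint_dir="."):
    """Hypothetical helper: build the checkpoint path for a sam_model_registry key."""
    if model_type == "default":
        model_type = "vit_h"
    if model_type not in CHECKPOINTS:
        raise KeyError(f"unknown model_type {model_type!r}; expected one of {sorted(CHECKPOINTS)}")
    return os.path.join(checkpoint_dir, CHECKPOINTS[model_type])
```

With it, the loading line could read sam_model_registry[model_type](checkpoint=checkpoint_path(model_type, "../../checkpoints")).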
Process the image to produce an image embedding by calling SamPredictor.set_image. SamPredictor remembers this embedding and uses it for subsequent mask predictions.
predictor.set_image(image)
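To select the truck, choose a point on it. Points are input to the model in (x, y) format and come with labels 1 (foreground point) or 0 (background point). Multiple points can be input; here we use only one. The chosen point is shown as a star on the image.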
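A common pitfall is axis order. NumPy images are indexed as [row, col], that is [y, x], and np.nonzero returns (row indices, column indices). SAM points, however, are (x, y). The sketch below derives a foreground click from the centroid of a binary region, for example one produced by some other tool, and swaps the axes explicitly. pixel_to_point and the region are hypothetical and purely illustrative.

```python
import numpy as np

def pixel_to_point(row, col):
    """Hypothetical helper: turn a NumPy (row, col) index into a (1, 2) SAM point in (x, y) order."""
    return np.array([[col, row]])

# Illustrative region on a 1200 x 1800 (H x W) canvas
region = np.zeros((1200, 1800), dtype=bool)
region[350:401, 480:521] = True          # rows 350..400, columns 480..520

rows, cols = np.nonzero(region)          # (row indices, column indices)
center = pixel_to_point(int(rows.mean()), int(cols.mean()))
center_label = np.array([1])             # 1 = foreground, 0 = background
```

Here the centroid lands on row 375, column 500, which becomes the point (500, 375), the same click used for the truck above.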
input_point = np.array([[500, 375]])
input_label = np.array([1])
plt.figure(figsize=(10,10))
plt.imshow(image)
show_points(input_point, input_label, plt.gca())
plt.axis('on')
plt.show()
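Predict with SamPredictor.predict. The model returns masks, quality predictions for those masks, and low-resolution mask logits that can be passed to the next iteration of prediction.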
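The returned masks are binary because the logits have already been thresholded. In segment_anything the cut-off is the model's mask_threshold, assumed here to be 0.0 (its value in the released Sam class). Thresholding a logit at 0 is the same as thresholding the sigmoid probability at 0.5. The sketch below reproduces that step on a small array. Applied to the low-resolution logits, it gives a 256 x 256 mask rather than a full-size one. predict(..., return_logits=True) returns the full-size values before thresholding.

```python
import numpy as np

MASK_THRESHOLD = 0.0  # assumed to match Sam.mask_threshold in segment_anything

def logits_to_mask(logits, threshold=MASK_THRESHOLD):
    """Binarize mask logits: a pixel is foreground when its logit exceeds the threshold."""
    return np.asarray(logits) > threshold

def logits_to_prob(logits):
    """Per-pixel foreground probability via the logistic sigmoid."""
    return 1.0 / (1.0 + np.exp(-np.asarray(logits)))

demo_logits = np.array([[-2.0, 0.5], [3.0, -0.1]])
binary = logits_to_mask(demo_logits)
# A logit threshold of 0 is equivalent to a probability threshold of 0.5
same = (binary == (logits_to_prob(demo_logits) > 0.5)).all()
```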
masks, scores, logits = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    multimask_output=True,
)
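With multimask_output=True (the default), SAM outputs 3 masks, and scores gives the model's own estimate of the quality of each. This setting is intended for ambiguous input prompts and helps the model distinguish different objects consistent with the prompt. When set to False, it returns a single mask. For ambiguous prompts such as a single point, it is recommended to use multimask_output=True even if only one mask is needed; the best single mask can then be chosen as the one with the highest value in scores. This often yields a better mask.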
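Picking the best of the three candidates is a single argmax over scores. The same index also selects the matching low-resolution logits, which the next section feeds back as a mask prompt. The arrays below are dummy stand-ins with the shapes predict returns for multimask_output=True. They are not real model outputs. best_mask is a hypothetical helper.

```python
import numpy as np

def best_mask(masks, scores, logits=None):
    """Hypothetical helper: return the highest-scoring mask, its score and, optionally, its logits."""
    idx = int(np.argmax(scores))
    if logits is None:
        return masks[idx], scores[idx]
    return masks[idx], scores[idx], logits[idx]

# Dummy stand-ins: 3 candidate masks, 3 scores, 3 low-resolution logit maps
demo_masks = np.zeros((3, 4, 6), dtype=bool)
demo_scores = np.array([0.79, 0.98, 0.91])
demo_logits = np.zeros((3, 256, 256), dtype=np.float32)
mask, score, low_res = best_mask(demo_masks, demo_scores, demo_logits)
```

In the notebook this would be best_mask(masks, scores, logits) right after the predict call above.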
masks.shape # (number_of_masks) x H x W
(3, 1200, 1800)
for i, (mask, score) in enumerate(zip(masks, scores)):
    plt.figure(figsize=(10,10))
    plt.imshow(image)
    show_mask(mask, plt.gca())
    show_points(input_point, input_label, plt.gca())
    plt.title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18)
    plt.axis('off')
    plt.show()
Specifying a specific object with additional points¶
The single input point is ambiguous, and the model has returned multiple objects consistent with it. To obtain a single object, multiple points can be provided. If available, a mask from a previous iteration can also be supplied to the model to aid in prediction. When specifying a single object with multiple prompts, a single mask can be requested by setting multimask_output=False.
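The pattern above can be looped: the first, ambiguous click asks for several candidates, and each later call adds one more click and feeds back the best logits from the previous round as mask_input, of shape (1, 256, 256). Below is a minimal sketch. refine_with_clicks is a hypothetical helper. It only assumes that the predictor accepts the predict(point_coords=..., point_labels=..., mask_input=..., multimask_output=...) keywords used in this notebook and returns (masks, scores, logits).

```python
import numpy as np

def refine_with_clicks(predictor, clicks, labels):
    """Hypothetical helper: add clicks one at a time, feeding back the previous best logits.

    clicks: list of (x, y) points; labels: matching list of 1 (foreground) / 0 (background).
    Returns the final best mask, shape (H, W).
    """
    if not clicks:
        raise ValueError("at least one click is required")
    mask_input = None
    best_mask = None
    for n in range(1, len(clicks) + 1):
        masks, scores, logits = predictor.predict(
            point_coords=np.asarray(clicks[:n]),
            point_labels=np.asarray(labels[:n]),
            mask_input=mask_input,
            multimask_output=(n == 1),  # several candidates only for the ambiguous first click
        )
        best = int(np.argmax(scores))
        mask_input = logits[best][None, :, :]  # (1, 256, 256) prompt for the next round
        best_mask = masks[best]
    return best_mask
```

With the predictor from this notebook, refine_with_clicks(predictor, [[500, 375], [1125, 625]], [1, 1]) follows the same two calls as the cells above.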
input_point = np.array([[500, 375], [1125, 625]])
input_label = np.array([1, 1])
mask_input = logits[np.argmax(scores), :, :] # Choose the model's best mask
masks, _, _ = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    mask_input=mask_input[None, :, :],
    multimask_output=False,
)
masks.shape
(1, 1200, 1800)
plt.figure(figsize=(10,10))
plt.imshow(image)
show_mask(masks, plt.gca())
show_points(input_point, input_label, plt.gca())
plt.axis('off')
plt.show()
To exclude the rest of the truck and specify just the window, a background point (with label 0, here shown in red) can be supplied.
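Prompts that mix foreground and background clicks are just two stacked arrays: the coordinates in (x, y) order and a parallel label array of 1s and 0s. The sketch below builds both from separate lists of positive and negative clicks. build_point_prompt is a hypothetical helper, not a segment_anything function.

```python
import numpy as np

def build_point_prompt(positive, negative=()):
    """Hypothetical helper: stack foreground and background clicks into SAM's point format.

    positive / negative: iterables of (x, y) pixel coordinates.
    Returns point_coords of shape (N, 2) and point_labels of shape (N,).
    """
    pos = np.asarray(list(positive), dtype=float).reshape(-1, 2)
    neg = np.asarray(list(negative), dtype=float).reshape(-1, 2)
    coords = np.concatenate([pos, neg], axis=0)
    labels = np.concatenate([np.ones(len(pos), dtype=int),   # 1 = foreground
                             np.zeros(len(neg), dtype=int)])  # 0 = background
    return coords, labels

coords, labels = build_point_prompt(positive=[(500, 375)], negative=[(1125, 625)])
```

This reproduces the input_point and input_label pair used in the next cell.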
input_point = np.array([[500, 375], [1125, 625]])
input_label = np.array([1, 0])
mask_input = logits[np.argmax(scores), :, :] # Choose the model's best mask
masks, _, _ = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    mask_input=mask_input[None, :, :],
    multimask_output=False,
)
plt.figure(figsize=(10, 10))
plt.imshow(image)
show_mask(masks, plt.gca())
show_points(input_point, input_label, plt.gca())
plt.axis('off')
plt.show()
Specifying a specific object with a box¶
The model can also take a box as input, provided in xyxy format, i.e. [x_min, y_min, x_max, y_max] in pixel coordinates.
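Many detectors and annotation formats (COCO, for instance) store boxes as [x_min, y_min, width, height]. Those boxes must be converted to xyxy before they are passed to SAM. Below is a small vectorized sketch that handles a single box or an (N, 4) batch. xywh_to_xyxy is a hypothetical helper.

```python
import numpy as np

def xywh_to_xyxy(boxes):
    """Hypothetical helper: [x_min, y_min, w, h] -> [x_min, y_min, x_max, y_max] for (4,) or (N, 4) input."""
    boxes = np.asarray(boxes, dtype=float)
    out = boxes.copy()
    out[..., 2:] = boxes[..., :2] + boxes[..., 2:]  # x_max = x_min + w, y_max = y_min + h
    return out

input_box_xyxy = xywh_to_xyxy([425, 600, 275, 275])
```

The result is the same box as input_box in the next cell.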
input_box = np.array([425, 600, 700, 875])
masks, _, _ = predictor.predict(
    point_coords=None,
    point_labels=None,
    box=input_box[None, :],
    multimask_output=False,
)
plt.figure(figsize=(10, 10))
plt.imshow(image)
show_mask(masks[0], plt.gca())
show_box(input_box, plt.gca())
plt.axis('off')
plt.show()
Combining points and boxes¶
Points and boxes can be combined simply by passing both types of prompts to the predictor. Here this is used to select just the truck's tire instead of the entire wheel.
input_box = np.array([425, 600, 700, 875])
input_point = np.array([[575, 750]])
input_label = np.array([0])
masks, _, _ = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    box=input_box,
    multimask_output=False,
)
plt.figure(figsize=(10, 10))
plt.imshow(image)
show_mask(masks[0], plt.gca())
show_box(input_box, plt.gca())
show_points(input_point, input_label, plt.gca())
plt.axis('off')
plt.show()
Batched prompt inputs¶
SamPredictor can take multiple input prompts for the same image using the predict_torch method. This method assumes the input prompts are already torch tensors and have already been transformed to the input frame. For example, imagine we have several box outputs from an object detector.
input_boxes = torch.tensor([
    [75, 275, 1725, 850],
    [425, 600, 700, 875],
    [1375, 550, 1650, 800],
    [1240, 675, 1400, 750],
], device=predictor.device)
Transform the boxes to the input frame, then predict masks. SamPredictor stores the necessary transform in its transform field for easy access, though it can also be instantiated directly for use in, e.g., a dataloader (see segment_anything.utils.transforms).
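"Transforming to the input frame" is a rescale so that the longest image side matches the encoder's input size (1024 for the released models). The sketch below mirrors the arithmetic of ResizeLongestSide as we understand it from segment_anything/utils/transforms.py: compute the resized shape, then scale x by new_w / old_w and y by new_h / old_h. It is an illustration only. In the notebook, predictor.transform.apply_boxes_torch does this for you.

```python
import numpy as np

def preprocess_shape(old_h, old_w, long_side=1024):
    """Size after scaling the longest side to long_side, rounded to the nearest pixel."""
    scale = long_side * 1.0 / max(old_h, old_w)
    new_h, new_w = old_h * scale, old_w * scale
    return int(new_h + 0.5), int(new_w + 0.5)

def apply_boxes(boxes, original_size, long_side=1024):
    """Rescale xyxy boxes from the original image frame to the model input frame."""
    old_h, old_w = original_size
    new_h, new_w = preprocess_shape(old_h, old_w, long_side)
    boxes = np.asarray(boxes, dtype=float).copy()
    boxes[..., [0, 2]] *= new_w / old_w  # x coordinates
    boxes[..., [1, 3]] *= new_h / old_h  # y coordinates
    return boxes

resized = preprocess_shape(1200, 1800)  # truck.jpg is 1200 x 1800 (H x W) -> (683, 1024)
box_in_input_frame = apply_boxes([[425, 600, 700, 875]], (1200, 1800))
```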
transformed_boxes = predictor.transform.apply_boxes_torch(input_boxes, image.shape[:2])
masks, _, _ = predictor.predict_torch(
    point_coords=None,
    point_labels=None,
    boxes=transformed_boxes,
    multimask_output=False,
)
masks.shape # (batch_size) x (num_predicted_masks_per_input) x H x W
torch.Size([4, 1, 1200, 1800])
plt.figure(figsize=(10, 10))
plt.imshow(image)
for mask in masks:
    show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
for box in input_boxes:
    show_box(box.cpu().numpy(), plt.gca())
plt.axis('off')
plt.show()
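The batched result holds one mask per box, with shape (batch_size, 1, H, W). For export or evaluation, it is often handier to merge these masks into a single integer label map, where pixel value k means "covered by mask k" and 0 means background. Below is a sketch on NumPy arrays. masks_to_label_map is a hypothetical helper. Where masks overlap, the later one in the batch wins, which is an arbitrary choice.

```python
import numpy as np

def masks_to_label_map(masks):
    """Hypothetical helper: merge (B, 1, H, W) or (B, H, W) boolean masks into one (H, W) int map."""
    masks = np.asarray(masks).astype(bool)
    masks = masks.reshape(-1, *masks.shape[-2:])  # (B, H, W)
    label_map = np.zeros(masks.shape[-2:], dtype=np.int32)
    for k, m in enumerate(masks, start=1):
        label_map[m] = k  # later masks overwrite earlier ones where they overlap
    return label_map
```

In the notebook this would be called as masks_to_label_map(masks.cpu().numpy()).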
End-to-end batched inference¶
If all prompts are available in advance, it is possible to run SAM directly in an end-to-end fashion. This also allows batching over images.
image1 = image # truck.jpg from above
image1_boxes = torch.tensor([
    [75, 275, 1725, 850],
    [425, 600, 700, 875],
    [1375, 550, 1650, 800],
    [1240, 675, 1400, 750],
], device=sam.device)
image2 = cv2.imread('images/groceries.jpg')
image2 = cv2.cvtColor(image2, cv2.COLOR_BGR2RGB)
image2_boxes = torch.tensor([
    [450, 170, 520, 350],
    [350, 190, 450, 350],
    [500, 170, 580, 350],
    [580, 170, 640, 350],
], device=sam.device)
Both images and prompts are input as PyTorch tensors that are already transformed to the correct frame. Inputs are packaged as a list over images, where each element is a dict with the following keys:
- image: The input image as a PyTorch tensor in CHW format.
- original_size: The size of the image before transforming for input to SAM, in (H, W) format.
- point_coords: Batched coordinates of point prompts.
- point_labels: Batched labels of point prompts.
- boxes: Batched input boxes.
- mask_inputs: Batched input masks.
If a prompt is not present, its key can be omitted.
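Since missing or misspelled keys in these dicts are easy to overlook, a quick structural check before calling the model can help. The sketch below validates one list element against the keys listed above. It requires image and original_size, expects a 3-channel CHW image, and expects point_coords and point_labels to appear together. check_batched_entry is a hypothetical helper, and it works with NumPy arrays or torch tensors because it only reads .ndim and .shape.

```python
import numpy as np

REQUIRED_KEYS = {"image", "original_size"}
OPTIONAL_KEYS = {"point_coords", "point_labels", "boxes", "mask_inputs"}

def check_batched_entry(entry):
    """Hypothetical helper: sanity-check one element of the batched_input list."""
    keys = set(entry)
    missing = REQUIRED_KEYS - keys
    if missing:
        raise KeyError(f"missing keys: {sorted(missing)}")
    unknown = keys - REQUIRED_KEYS - OPTIONAL_KEYS
    if unknown:
        raise KeyError(f"unexpected keys: {sorted(unknown)}")
    image = entry["image"]
    if image.ndim != 3 or image.shape[0] != 3:
        raise ValueError(f"image must be CHW with 3 channels, got {tuple(image.shape)}")
    if ("point_coords" in entry) != ("point_labels" in entry):
        raise KeyError("point_coords and point_labels must be given together")
    return True
```

For example, all(check_batched_entry(e) for e in batched_input) could be run right before the model call below.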
from segment_anything.utils.transforms import ResizeLongestSide
resize_transform = ResizeLongestSide(sam.image_encoder.img_size)
def prepare_image(image, transform, device):
    # Resize so the longest side matches the encoder input size, then convert HWC -> CHW
    image = transform.apply_image(image)
    image = torch.as_tensor(image, device=device)
    return image.permute(2, 0, 1).contiguous()
batched_input = [
    {
        'image': prepare_image(image1, resize_transform, sam.device),
        'boxes': resize_transform.apply_boxes_torch(image1_boxes, image1.shape[:2]),
        'original_size': image1.shape[:2]
    },
    {
        'image': prepare_image(image2, resize_transform, sam.device),
        'boxes': resize_transform.apply_boxes_torch(image2_boxes, image2.shape[:2]),
        'original_size': image2.shape[:2]
    }
]
Run the model.
batched_output = sam(batched_input, multimask_output=False)
The output is a list of results, one per input image, where each element is a dictionary with the following keys:
- masks: A batched torch tensor of predicted binary masks, at the size of the original image.
- iou_predictions: The model's prediction of the quality of each mask.
- low_res_logits: Low-resolution logits for each mask, which can be passed back to the model as mask input on a later iteration.
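iou_predictions can be used to drop low-quality masks before further processing. The sketch below keeps only masks whose predicted IoU exceeds a threshold. The default of 0.88 is borrowed from the pred_iou_thresh default of SamAutomaticMaskGenerator; treat it as a starting point, not a tuned value. keep_confident is a hypothetical helper, and it flattens a (B, C, H, W) mask batch together with its (B, C) scores.

```python
import numpy as np

def keep_confident(masks, iou_predictions, threshold=0.88):
    """Hypothetical post-processing: keep masks whose predicted IoU exceeds the threshold."""
    masks = np.asarray(masks)
    scores = np.asarray(iou_predictions).reshape(-1)  # (B * C,)
    flat = masks.reshape(-1, *masks.shape[-2:])        # (B * C, H, W), same order as scores
    keep = scores > threshold
    return flat[keep], scores[keep]
```

For the first image this would be keep_confident(batched_output[0]['masks'].cpu().numpy(), batched_output[0]['iou_predictions'].cpu().numpy()).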
batched_output[0].keys()
dict_keys(['masks', 'iou_predictions', 'low_res_logits'])
fig, ax = plt.subplots(1, 2, figsize=(20, 20))
ax[0].imshow(image1)
for mask in batched_output[0]['masks']:
    show_mask(mask.cpu().numpy(), ax[0], random_color=True)
for box in image1_boxes:
    show_box(box.cpu().numpy(), ax[0])
ax[0].axis('off')
ax[1].imshow(image2)
for mask in batched_output[1]['masks']:
    show_mask(mask.cpu().numpy(), ax[1], random_color=True)
for box in image2_boxes:
    show_box(box.cpu().numpy(), ax[1])
ax[1].axis('off')
plt.tight_layout()
plt.show()
Created: November 26, 2023