# Copyright (C) 2023 Intel Corporation
#
# SPDX-License-Identifier: MIT
"""Automatic mask generation using Segment Anything Model"""

import os.path as osp
from typing import List, Optional

import numpy as np

import datumaro.plugins.sam_transforms.interpreters.sam_decoder_for_amg as sam_decoder_for_amg
import datumaro.plugins.sam_transforms.interpreters.sam_encoder as sam_encoder_interp
from datumaro.components.cli_plugin import CliPlugin
from datumaro.components.dataset_base import DatasetItem, IDataset
from datumaro.components.transformer import ModelTransform
from datumaro.plugins.inference_server_plugin import OVMSLauncher, TritonLauncher
from datumaro.plugins.inference_server_plugin.base import (
    InferenceServerType,
    ProtocolType,
    TLSConfig,
)
from datumaro.plugins.sam_transforms.interpreters.sam_decoder_for_amg import AMGMasks, AMGPoints

__all__ = ["SAMAutomaticMaskGeneration"]


class SAMAutomaticMaskGeneration(ModelTransform, CliPlugin):
    """Produce instance segmentation masks automatically using Segment Anything Model (SAM).

    This transform can produce instance segmentation mask annotations for each given image.
    It samples single-point input prompts on a uniform 2D grid over the image.
    For each prompt, SAM can predict multiple masks. After obtaining the mask candidates,
    it post-processes them using the given parameters to improve quality and remove duplicates.

    It uses the Segment Anything Model deployed in the OpenVINO™ Model Server or
    NVIDIA Triton™ Inference Server instance. To launch the server instance,
    please see the guide in this link:
    https://github.com/openvinotoolkit/datumaro/tree/develop/docker/segment-anything/README.md

    Parameters:
        extractor: Dataset to transform
        inference_server_type: Inference server type: `InferenceServerType.ovms` or
            `InferenceServerType.triton`
        host: Host address of the server instance
        port: Port number of the server instance
        timeout: Timeout limit during communication between the client and the server instance
        tls_config: Configuration required if the server instance is in the secure mode
        protocol_type: Communication protocol type with the server instance
        num_workers: The number of worker threads to use for parallel inference.
            Set to 0 for single-process mode. Default is 0.
        points_per_side (int): The number of points to be sampled along one side of the image.
            The total number of points is points_per_side**2 on a uniform 2D grid.
        points_per_batch (int): Sets the number of points run simultaneously by the model.
            Higher numbers may be faster but use more GPU memory.
        mask_threshold (float): The threshold used to binarize the model's mask predictions.
        pred_iou_thresh (float): A filtering threshold in [0,1], using the model's
            predicted mask quality.
        stability_score_thresh (float): A filtering threshold in [0,1], using the stability
            of the mask under changes to the cutoff used to binarize the model's
            mask predictions.
        stability_score_offset (float): The amount to shift the cutoff when calculating
            the stability score.
        box_nms_thresh (float): The box IoU cutoff used by non-maximal suppression
            to filter duplicate masks.
        min_mask_region_area (int): If >0, postprocessing will remove binary masks
            with fewer than min_mask_region_area pixels set to 1.
    """

    def __init__(
        self,
        extractor: IDataset,
        inference_server_type: InferenceServerType = InferenceServerType.ovms,
        host: str = "localhost",
        port: int = 9000,
        timeout: float = 10.0,
        tls_config: Optional[TLSConfig] = None,
        protocol_type: ProtocolType = ProtocolType.grpc,
        num_workers: int = 0,
        points_per_side: int = 32,
        points_per_batch: int = 128,
        mask_threshold: float = 0.0,
        pred_iou_thresh: float = 0.88,
        stability_score_thresh: float = 0.95,
        stability_score_offset: float = 1.0,
        box_nms_thresh: float = 0.7,
        min_mask_region_area: int = 0,
    ):
        if inference_server_type == InferenceServerType.ovms:
            launcher_cls = OVMSLauncher
        elif inference_server_type == InferenceServerType.triton:
            launcher_cls = TritonLauncher
        else:
            raise ValueError(inference_server_type)

        self._sam_encoder_launcher = launcher_cls(
            model_name="sam_encoder",
            model_interpreter_path=osp.abspath(sam_encoder_interp.__file__),
            model_version=1,
            host=host,
            port=port,
            timeout=timeout,
            tls_config=tls_config,
            protocol_type=protocol_type,
        )
        self._sam_decoder_launcher = launcher_cls(
            model_name="sam_decoder",
            model_interpreter_path=osp.abspath(sam_decoder_for_amg.__file__),
            model_version=1,
            host=host,
            port=port,
            timeout=timeout,
            tls_config=tls_config,
            protocol_type=protocol_type,
        )

        self.points_per_side = points_per_side
        self.points_per_batch = points_per_batch
        self.mask_threshold = mask_threshold
        self.pred_iou_thresh = pred_iou_thresh
        self.stability_score_offset = stability_score_offset
        self.stability_score_thresh = stability_score_thresh
        self.box_nms_thresh = box_nms_thresh
        self.min_mask_region_area = min_mask_region_area

        super().__init__(
            extractor,
            launcher=self._sam_encoder_launcher,
            batch_size=1,
            append_annotation=False,
            num_workers=num_workers,
        )

    @property
    def points_per_side(self) -> int:
        return self._points_per_side

    @points_per_side.setter
    def points_per_side(self, points_per_side: int) -> None:
        # Sample normalized point coordinates at the cell centers of a uniform
        # points_per_side x points_per_side grid over the image.
        points_y = (np.arange(points_per_side) + 0.5) / points_per_side
        points_x = (np.arange(points_per_side) + 0.5) / points_per_side
        points_x = np.tile(points_x[None, :], (points_per_side, 1))
        points_y = np.tile(points_y[:, None], (1, points_per_side))
        self._points_grid = np.stack([points_x, points_y], axis=-1).reshape(-1, 2)
        self._points_per_side = points_per_side

    def _process_batch(
        self,
        batch: List[DatasetItem],
    ) -> List[DatasetItem]:
        img_embeds = self._sam_encoder_launcher.launch(
            batch=[item for item in batch if self._sam_encoder_launcher.type_check(item)]
        )

        items = []
        for item, img_embed in zip(batch, img_embeds):
            amg_masks: List[AMGMasks] = []
            # Feed the point prompts to the decoder in chunks of points_per_batch.
            for i in range(0, len(self._points_grid), self.points_per_batch):
                amg_points = [AMGPoints(points=self._points_grid[i : i + self.points_per_batch])]
                item_to_decode = item.wrap(annotations=amg_points + img_embed)

                # Nested list of masks [[mask_0, ...]]
                nested_masks: List[List[AMGMasks]] = self._sam_decoder_launcher.launch(
                    [item_to_decode],
                    stack=False,
                )
                amg_masks += nested_masks[0]

            mask_anns = AMGMasks.cat(amg_masks).postprocess(
                mask_threshold=self.mask_threshold,
                pred_iou_thresh=self.pred_iou_thresh,
                stability_score_offset=self.stability_score_offset,
                stability_score_thresh=self.stability_score_thresh,
                box_nms_thresh=self.box_nms_thresh,
                min_mask_region_area=self.min_mask_region_area,
            )

            items.append(item.wrap(annotations=mask_anns))

        return items