datumaro.plugins.sam_transforms.automatic_mask_gen

Automatic mask generation using Segment Anything Model

Classes

SAMAutomaticMaskGeneration(extractor[, ...])

Produce instance segmentation masks automatically using Segment Anything Model (SAM).

class datumaro.plugins.sam_transforms.automatic_mask_gen.SAMAutomaticMaskGeneration(extractor: IDataset, inference_server_type: InferenceServerType = InferenceServerType.ovms, host: str = 'localhost', port: int = 9000, timeout: float = 10.0, tls_config: TLSConfig | None = None, protocol_type: ProtocolType = ProtocolType.grpc, num_workers: int = 0, points_per_side: int = 32, points_per_batch: int = 128, mask_threshold: float = 0.0, pred_iou_thresh: float = 0.88, stability_score_thresh: float = 0.95, stability_score_offset: float = 1.0, box_nms_thresh: float = 0.7, min_mask_region_area: int = 0)

Bases: ModelTransform, CliPlugin

Produce instance segmentation masks automatically using Segment Anything Model (SAM).

This transform can produce instance segmentation mask annotations for each given image. It samples single-point input prompts on a uniform 2D grid over the image. For each prompt, SAM can predict multiple masks. After obtaining the mask candidates, it post-processes them using the given parameters to improve quality and remove duplicates.

It uses the Segment Anything Model deployed in an OpenVINO™ Model Server or NVIDIA Triton™ Inference Server instance. For instructions on launching the server instance, see the guide in the openvinotoolkit/datumaro repository.
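The uniform prompt grid can be sketched as follows. This is a minimal illustration that assumes the cell-centered layout used by the original Segment Anything reference code (points at the centers of an n × n grid of cells, scaled to pixel coordinates); datumaro's exact point placement may differ, and `build_point_grid` is a hypothetical helper, not part of the datumaro API.

```python
from typing import List, Tuple


def build_point_grid(points_per_side: int, width: int, height: int) -> List[Tuple[float, float]]:
    """Return points_per_side**2 (x, y) prompts at the centers of a uniform grid."""
    if points_per_side <= 0:
        raise ValueError("points_per_side must be positive")
    # Cell-centered normalized coordinates: 1/(2n), 3/(2n), ..., (2n-1)/(2n)
    side = [(2 * i + 1) / (2 * points_per_side) for i in range(points_per_side)]
    # Row-major order: y varies slowest, x varies fastest; scale to pixels
    return [(x * width, y * height) for y in side for x in side]


points = build_point_grid(points_per_side=32, width=640, height=480)
print(len(points))  # 1024
print(points[0])    # (10.0, 7.5)
```

With the default points_per_side = 32, every image receives 1024 single-point prompts, which is why points_per_batch matters for memory use.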

Parameters:
  • extractor – Dataset to transform

  • inference_server_type – Inference server type: InferenceServerType.ovms or InferenceServerType.triton

  • host – Host address of the server instance

  • port – Port number of the server instance

  • timeout – Timeout limit during communication between the client and the server instance

  • tls_config – Configuration required if the server instance is in the secure mode

  • protocol_type – Communication protocol type with the server instance

  • num_workers – The number of worker threads to use for parallel inference. Set to 0 for single-process mode. Default is 0.

  • points_per_side (int) – The number of points to be sampled along one side of the image. The total number of points is points_per_side**2 on a uniform 2D grid.

  • points_per_batch (int) – Sets the number of points run simultaneously by the model. Higher numbers may be faster but use more GPU memory.

  • mask_threshold (float) – The cutoff used to binarize the model’s mask predictions.

  • pred_iou_thresh (float) – A filtering threshold in [0,1], using the model’s predicted mask quality.

  • stability_score_thresh (float) – A filtering threshold in [0,1], using the stability of the mask under changes to the cutoff used to binarize the model’s mask predictions.

  • stability_score_offset (float) – The amount to shift the cutoff when calculating the stability score.

  • box_nms_thresh (float) – The box IoU cutoff used by non-maximum suppression to filter duplicate masks.

  • min_mask_region_area (int) – If greater than 0, postprocessing removes any binary mask whose number of foreground (1) pixels is less than min_mask_region_area.
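The stability score used by stability_score_thresh and stability_score_offset can be sketched as the IoU between a mask binarized at `mask_threshold + offset` and the same mask binarized at `mask_threshold - offset`. This follows the definition used in the original Segment Anything code; whether datumaro computes it identically is an assumption, and the function name and zero-union guard are illustrative.

```python
import numpy as np


def stability_score(mask_logits: np.ndarray, mask_threshold: float, offset: float) -> np.ndarray:
    """Per-mask IoU between masks binarized at threshold + offset and threshold - offset.

    mask_logits: raw (pre-binarization) predictions of shape (N, H, W).
    Returns an array of shape (N,).
    """
    strict = mask_logits > (mask_threshold + offset)  # higher cutoff -> smaller mask
    loose = mask_logits > (mask_threshold - offset)   # lower cutoff -> larger mask
    # strict is a subset of loose (for offset >= 0), so intersection = |strict|, union = |loose|
    inter = strict.sum(axis=(-2, -1)).astype(np.float64)
    union = loose.sum(axis=(-2, -1)).astype(np.float64)
    # Guard against empty masks: define the score as 0 when the union is empty
    return np.divide(inter, union, out=np.zeros_like(inter), where=union > 0)


logits = np.array([[[3.0, 3.0], [0.5, -3.0]]])  # one 2x2 mask of logits
scores = stability_score(logits, mask_threshold=0.0, offset=1.0)
print(scores)  # [0.66666667]
keep = scores >= 0.95  # stability_score_thresh filter
print(keep)    # [False]
```

A mask whose logits sit close to the cutoff changes a lot when the cutoff moves, so it gets a low score and is filtered out.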

property points_per_side: int
class datumaro.plugins.sam_transforms.automatic_mask_gen.AMGMasks(*, id: int = 0, attributes: Dict[str, Any] = _Nothing.NOTHING, group: int = 0, object_id: int = -1, masks: ndarray, iou_preds: ndarray)

Bases: Annotation

Intermediate annotation class for SAM decoder outputs.

masks

Array of masks corresponding to the points.

Type:

numpy.ndarray

iou_preds

Array of Intersection over Union (IoU) prediction scores corresponding to the points.

Type:

numpy.ndarray

Method generated by attrs for class AMGMasks.

masks: ndarray
iou_preds: ndarray
classmethod cat(masks: List[AMGMasks]) → AMGMasks

Concatenate a list of AMGMasks into a single AMGMasks object.

Parameters:

masks – List of AMGMasks to concatenate.

Returns:

A new AMGMasks containing the concatenated masks and IoU prediction scores.
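A minimal sketch of what this concatenation involves, assuming the first axis of both arrays indexes the prompt points. `ToyAMGMasks` is a hypothetical stand-in dataclass so the example runs without datumaro; it is not the real AMGMasks.

```python
from dataclasses import dataclass
from typing import List

import numpy as np


@dataclass
class ToyAMGMasks:
    """Hypothetical stand-in mirroring the masks/iou_preds fields of AMGMasks."""
    masks: np.ndarray      # shape (N, H, W), one mask per point (assumed layout)
    iou_preds: np.ndarray  # shape (N,), one predicted IoU per mask

    @classmethod
    def cat(cls, items: List["ToyAMGMasks"]) -> "ToyAMGMasks":
        if not items:
            raise ValueError("cannot concatenate an empty list")
        # Concatenate along the per-point axis so masks and scores stay aligned
        return cls(
            masks=np.concatenate([m.masks for m in items], axis=0),
            iou_preds=np.concatenate([m.iou_preds for m in items], axis=0),
        )


batch1 = ToyAMGMasks(np.zeros((2, 4, 4), dtype=bool), np.array([0.9, 0.8]))
batch2 = ToyAMGMasks(np.ones((3, 4, 4), dtype=bool), np.array([0.7, 0.95, 0.6]))
merged = ToyAMGMasks.cat([batch1, batch2])
print(merged.masks.shape, merged.iou_preds.shape)  # (5, 4, 4) (5,)
```

Concatenating per-batch outputs first lets the postprocessing step (in particular NMS) compare candidates from all batches at once.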

postprocess(mask_threshold: float, pred_iou_thresh: float, stability_score_offset: float, stability_score_thresh: float, box_nms_thresh: float, min_mask_region_area: int) → List[Mask]

Postprocesses the masks with the given parameters.

Parameters:
  • mask_threshold (float) – The cutoff used to binarize the model’s mask predictions.

  • pred_iou_thresh (float) – A filtering threshold in [0,1], using the model’s predicted mask quality.

  • stability_score_thresh (float) – A filtering threshold in [0,1], using the stability of the mask under changes to the cutoff used to binarize the model’s mask predictions.

  • stability_score_offset (float) – The amount to shift the cutoff when calculating the stability score.

  • box_nms_thresh (float) – The box IoU cutoff used by non-maximum suppression to filter duplicate masks.

  • min_mask_region_area (int) – If greater than 0, postprocessing removes any binary mask whose number of foreground (1) pixels is less than min_mask_region_area.

Returns:

List of Mask annotations representing the postprocessed masks.
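The filtering steps named by these parameters can be sketched as a single function. This is an illustrative approximation, not datumaro's implementation: it assumes raw mask logits of shape (N, H, W) with one predicted IoU per mask, follows this page's description of min_mask_region_area (drop the whole mask), and uses a simple greedy box NMS in place of whatever NMS routine the library actually uses. `postprocess_sketch`, `mask_to_box`, and `box_iou` are hypothetical names.

```python
import numpy as np


def mask_to_box(mask: np.ndarray) -> np.ndarray:
    """Tight (x1, y1, x2, y2) box around the foreground pixels of a binary mask."""
    ys, xs = np.nonzero(mask)
    return np.array([xs.min(), ys.min(), xs.max() + 1, ys.max() + 1], dtype=np.float64)


def box_iou(a: np.ndarray, b: np.ndarray) -> float:
    ix = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    iy = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = ix * iy
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    return inter / (area_a + area_b - inter)


def postprocess_sketch(logits, iou_preds, mask_threshold=0.0, pred_iou_thresh=0.88,
                       stability_score_offset=1.0, stability_score_thresh=0.95,
                       box_nms_thresh=0.7, min_mask_region_area=0):
    """Return indices of kept masks, highest predicted IoU first."""
    # Stability score: IoU of the masks binarized at shifted cutoffs
    strict = (logits > mask_threshold + stability_score_offset).sum(axis=(1, 2))
    loose = (logits > mask_threshold - stability_score_offset).sum(axis=(1, 2))
    stability = np.where(loose > 0, strict / np.maximum(loose, 1), 0.0)

    binary = logits > mask_threshold
    areas = binary.sum(axis=(1, 2))

    # Score-based filters; empty masks are dropped so every kept mask has a box
    keep = (iou_preds > pred_iou_thresh) & (stability >= stability_score_thresh) & (areas > 0)
    if min_mask_region_area > 0:
        keep &= areas >= min_mask_region_area

    # Greedy box NMS: visit candidates by descending predicted IoU
    candidates = [int(i) for i in np.argsort(-iou_preds) if keep[i]]
    kept, kept_boxes = [], []
    for i in candidates:
        box = mask_to_box(binary[i])
        if all(box_iou(box, other) <= box_nms_thresh for other in kept_boxes):
            kept.append(i)
            kept_boxes.append(box)
    return kept


logits = np.full((3, 8, 8), -4.0)
logits[0, 1:5, 1:5] = 4.0  # 4x4 square
logits[1, 1:5, 1:6] = 4.0  # nearly the same region: a duplicate
logits[2, 6:8, 6:8] = 4.0  # small, separate square
iou_preds = np.array([0.95, 0.90, 0.92])
print(postprocess_sketch(logits, iou_preds))  # [0, 2]
```

Mask 1 is dropped because its box overlaps the box of mask 0 with IoU 0.8, above box_nms_thresh = 0.7, and mask 0 has the higher predicted IoU.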

class datumaro.plugins.sam_transforms.automatic_mask_gen.AMGPoints(*, id: int = 0, attributes: Dict[str, Any] = _Nothing.NOTHING, group: int = 0, object_id: int = -1, points: ndarray)

Bases: Annotation

Intermediate annotation class for SAM decoder inputs.

points

Array of points (x, y) for the SAM prompt.

Type:

numpy.ndarray

Method generated by attrs for class AMGPoints.

points: ndarray
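How an array of points might be turned into SAM decoder prompts can be sketched as follows, assuming a decoder that, like the ONNX export of the original SAM decoder, takes `point_coords` of shape (B, N, 2) and `point_labels` of shape (B, N) with label 1 marking a foreground point. Each point becomes its own single-point prompt, and prompts are grouped into chunks of points_per_batch. The input names and label convention are assumptions; the actual layout depends on the model interpreter used with the server, and `iter_prompt_batches` is a hypothetical helper.

```python
from typing import Iterator, Tuple

import numpy as np


def iter_prompt_batches(points: np.ndarray,
                        points_per_batch: int) -> Iterator[Tuple[np.ndarray, np.ndarray]]:
    """Yield (point_coords, point_labels) for batches of single-point prompts.

    points: array of shape (P, 2) holding (x, y) pixel coordinates.
    """
    if points.ndim != 2 or points.shape[1] != 2:
        raise ValueError("points must have shape (P, 2)")
    for start in range(0, len(points), points_per_batch):
        chunk = points[start:start + points_per_batch].astype(np.float32)
        coords = chunk[:, None, :]                           # (b, 1, 2): one point per prompt
        labels = np.ones((len(chunk), 1), dtype=np.float32)  # 1 = foreground (assumed convention)
        yield coords, labels


points = np.random.default_rng(0).uniform(0, 512, size=(1024, 2))  # e.g. 32 x 32 prompts
batches = list(iter_prompt_batches(points, points_per_batch=128))
print(len(batches), batches[0][0].shape, batches[0][1].shape)  # 8 (128, 1, 2) (128, 1)
```

With the defaults (points_per_side = 32, points_per_batch = 128), this yields 1024 / 128 = 8 decoder calls per image.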
class datumaro.plugins.sam_transforms.automatic_mask_gen.CliPlugin

Bases: object

NAME = 'cli_plugin'
classmethod build_cmdline_parser(**kwargs)
classmethod parse_cmdline(args=None)
class datumaro.plugins.sam_transforms.automatic_mask_gen.DatasetItem(id: str, *, subset: str | None = None, media: str | MediaElement | None = None, annotations: List[Annotation] | None = None, attributes: Dict[str, Any] | None = None)

Bases: object

id: str
subset: str
media: MediaElement | None
annotations: Annotations
attributes: Dict[str, Any]
wrap(**kwargs)
media_as(t: Type[T]) → T
class datumaro.plugins.sam_transforms.automatic_mask_gen.IDataset

Bases: object

subsets() → Dict[str, IDataset]

Enumerates subsets in the dataset. Each subset can be a dataset itself.

get_subset(name) → IDataset
infos() → Dict[str, Any]

Returns the meta-info of the dataset.

categories() → Dict[AnnotationType, Categories]

Returns meta-info about the dataset labels.

get(id: str, subset: str | None = None) → DatasetItem | None

Provides random access to dataset items.

media_type() → Type[MediaElement]

Returns the media type of the dataset items.

All items are expected to share the same media type. The value is expected to be constant and known immediately after object construction (i.e., obtaining it does not require iterating over the dataset).

ann_types() → List[AnnotationType]

Returns the annotation types available in the dataset.

property is_stream: bool

Boolean indicating whether the dataset is a stream.

If the dataset is a stream, its items are generated on demand by its iterator.
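The IDataset contract can be illustrated with a small in-memory class. `ToyDataset` and `ToyItem` are hypothetical and do not subclass datumaro's IDataset; they only show what subsets(), random access through get(), and is_stream mean for a dataset that holds all of its items in memory. The "default" subset name is an assumption of this sketch.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple


@dataclass
class ToyItem:
    id: str
    subset: str = "default"
    annotations: List[object] = field(default_factory=list)


class ToyDataset:
    """Hypothetical in-memory dataset illustrating subsets(), get() and is_stream."""

    def __init__(self, items: List[ToyItem]):
        # Index items by (id, subset) so get() is O(1) random access
        self._items: Dict[Tuple[str, str], ToyItem] = {(it.id, it.subset): it for it in items}

    def __iter__(self):
        return iter(self._items.values())

    def __len__(self) -> int:
        return len(self._items)

    def subsets(self) -> Dict[str, "ToyDataset"]:
        # Each subset is itself a dataset, as the interface describes
        names = {subset for _, subset in self._items}
        return {n: ToyDataset([it for it in self if it.subset == n]) for n in sorted(names)}

    def get(self, id: str, subset: Optional[str] = None) -> Optional[ToyItem]:
        return self._items.get((id, subset or "default"))

    @property
    def is_stream(self) -> bool:
        # All items are held in memory, so nothing is generated on demand
        return False


ds = ToyDataset([ToyItem("a", "train"), ToyItem("b", "train"), ToyItem("c", "val")])
print(sorted(ds.subsets()))  # ['train', 'val']
print(ds.get("c", "val").id)  # c
print(ds.get("c"))            # None
```

A streaming dataset would instead produce items lazily from its iterator, which is why random access through get() is cheap only for non-stream datasets like this one.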

class datumaro.plugins.sam_transforms.automatic_mask_gen.InferenceServerType(value)

Bases: IntEnum

Types of the dedicated inference server.

ovms = 0
triton = 1
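A transform that accepts inference_server_type has to choose a launcher. The following sketch mirrors the enum locally so it runs without datumaro and maps each member to the name of one of the two launcher classes documented on this page; the mapping and the `select_launcher` helper are assumptions, not datumaro code.

```python
from enum import IntEnum


class InferenceServerType(IntEnum):
    """Local mirror of the documented enum so this sketch runs without datumaro."""
    ovms = 0
    triton = 1


# Hypothetical mapping from server type to launcher class name
LAUNCHER_FOR_SERVER = {
    InferenceServerType.ovms: "OVMSLauncher",
    InferenceServerType.triton: "TritonLauncher",
}


def select_launcher(server_type) -> str:
    # Accept an enum member, its integer value, or its name
    if isinstance(server_type, str):
        server_type = InferenceServerType[server_type]
    return LAUNCHER_FOR_SERVER[InferenceServerType(server_type)]


print(select_launcher(InferenceServerType.triton))  # TritonLauncher
print(select_launcher(0))                           # OVMSLauncher
print(select_launcher("triton"))                    # TritonLauncher
```

Because the enum is an IntEnum, members compare equal to their integer values, so configuration files can store either the name or the number.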
class datumaro.plugins.sam_transforms.automatic_mask_gen.ModelTransform(extractor: IDataset, launcher: Launcher, batch_size: int = 1, append_annotation: bool = False, num_workers: int = 0)

Bases: Transform

A transformation class for applying a model’s inference to dataset items.

This class takes a dataset, a launcher, and other optional parameters, and transforms each dataset item using the model outputs produced by the launcher. It can process items in parallel when num_workers is greater than 0, making it suitable for parallelized inference tasks.

Parameters:
  • extractor – The dataset extractor to obtain items from.

  • launcher – The launcher responsible for model inference.

  • batch_size – The batch size for processing items. Default is 1.

  • append_annotation – Whether to append inference annotations to existing annotations. Default is False.

  • num_workers – The number of worker threads to use for parallel inference. Set to 0 for single-process mode. Default is 0.
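The batching and append_annotation behavior can be sketched with a thread pool. This is a hypothetical, framework-free illustration: items are plain dictionaries, `infer` stands in for the launcher, and ThreadPoolExecutor is an assumed choice of worker pool, so it should not be read as datumaro's implementation.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable, Dict, List, Sequence

Item = Dict[str, Any]  # hypothetical item: {"id": str, "annotations": list}


def run_model_transform(items: Sequence[Item],
                        infer: Callable[[List[Item]], List[list]],
                        batch_size: int = 1,
                        append_annotation: bool = False,
                        num_workers: int = 0) -> List[Item]:
    # Split the items into consecutive batches of at most batch_size
    batches = [list(items[i:i + batch_size]) for i in range(0, len(items), batch_size)]

    def process(batch: List[Item]) -> List[Item]:
        outputs = infer(batch)  # one list of predicted annotations per item
        result = []
        for item, preds in zip(batch, outputs):
            # Either keep existing annotations and add predictions, or replace them
            anns = item["annotations"] + preds if append_annotation else preds
            result.append({**item, "annotations": anns})
        return result

    if num_workers > 0:
        with ThreadPoolExecutor(max_workers=num_workers) as pool:
            processed = list(pool.map(process, batches))  # map preserves batch order
    else:
        processed = [process(b) for b in batches]
    return [item for batch in processed for item in batch]


def fake_infer(batch: List[Item]) -> List[list]:
    return [[f"pred-{item['id']}"] for item in batch]


data = [{"id": str(i), "annotations": ["gt"]} for i in range(5)]
out = run_model_transform(data, fake_infer, batch_size=2, append_annotation=True, num_workers=2)
print(out[0]["annotations"])  # ['gt', 'pred-0']
print(len(out))               # 5
```

Batching reduces the number of round trips to the inference server, while the worker pool overlaps those round trips; the two settings are independent.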

get_subset(name)
infos()

Returns the meta-info of the dataset.

categories()

Returns meta-info about the dataset labels.

transform_item(item)
class datumaro.plugins.sam_transforms.automatic_mask_gen.OVMSLauncher(model_name: str, model_interpreter_path: str, model_version: int = 0, host: str = 'localhost', port: int = 9000, timeout: float = 10.0, tls_config: TLSConfig | None = None, protocol_type: ProtocolType = ProtocolType.grpc)

Bases: LauncherForDedicatedInferenceServer[Union[GrpcClient, HttpClient]]

Inference launcher for OVMS (OpenVINO™ Model Server) (openvinotoolkit/model_server)

Parameters:
  • model_name – Name of the model. It must match the name of the model loaded in the server instance.

  • model_interpreter_path – Path to the Python source file that implements a model interpreter. The model interpreter implements pre-processing of the model input and post-processing of the model output.

  • model_version – Version of the model loaded in the server instance

  • host – Host address of the server instance

  • port – Port number of the server instance

  • timeout – Timeout limit during communication between the client and the server instance

  • tls_config – Configuration required if the server instance is in the secure mode

  • protocol_type – Communication protocol type with the server instance

infer(inputs: ndarray | Dict[str, ndarray]) → List[Dict[str, ndarray] | List[Dict[str, ndarray]]]
class datumaro.plugins.sam_transforms.automatic_mask_gen.ProtocolType(value)

Bases: IntEnum

Protocol type for communication with the dedicated inference server.

grpc = 0
http = 1
class datumaro.plugins.sam_transforms.automatic_mask_gen.TLSConfig(client_key_path: str, client_cert_path: str, server_cert_path: str)

Bases: object

TLS configuration dataclass

Parameters:
  • client_key_path – Path to client key file

  • client_cert_path – Path to client certificate file

  • server_cert_path – Path to server certificate file

client_key_path: str
client_cert_path: str
server_cert_path: str
as_dict() → Dict[str, str]
as_grpc_creds() → ChannelCredentials
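The TLSConfig helpers can be sketched with a stand-in dataclass. `as_dict` maps naturally onto `dataclasses.asdict`. For gRPC, the standard grpcio call for mutual TLS is `grpc.ssl_channel_credentials(root_certificates, private_key, certificate_chain)`, which takes PEM contents as bytes; that as_grpc_creds works this way is an assumption. `TLSConfigSketch` and `grpc_credential_bytes` are hypothetical names, and the sketch stops short of importing grpc so it runs without grpcio.

```python
import os
import tempfile
from dataclasses import asdict, dataclass
from typing import Dict


@dataclass(frozen=True)
class TLSConfigSketch:
    """Hypothetical stand-in with the same three fields as TLSConfig."""
    client_key_path: str
    client_cert_path: str
    server_cert_path: str

    def as_dict(self) -> Dict[str, str]:
        return asdict(self)

    def grpc_credential_bytes(self) -> Dict[str, bytes]:
        """Read the PEM files into keyword arguments for grpc.ssl_channel_credentials."""
        def read(path: str) -> bytes:
            with open(path, "rb") as f:
                return f.read()
        return {
            "root_certificates": read(self.server_cert_path),  # certificate used to verify the server
            "private_key": read(self.client_key_path),         # client identity for mutual TLS
            "certificate_chain": read(self.client_cert_path),
        }


with tempfile.TemporaryDirectory() as d:
    # Write placeholder files; real deployments would use PEM-encoded keys/certificates
    paths = {}
    for name in ("client.key", "client.crt", "server.crt"):
        paths[name] = os.path.join(d, name)
        with open(paths[name], "wb") as f:
            f.write(name.encode())
    cfg = TLSConfigSketch(paths["client.key"], paths["client.crt"], paths["server.crt"])
    creds = cfg.grpc_credential_bytes()
    # With grpcio installed: grpc.ssl_channel_credentials(**creds)

print(sorted(cfg.as_dict()))       # ['client_cert_path', 'client_key_path', 'server_cert_path']
print(creds["root_certificates"])  # b'server.crt'
```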
class datumaro.plugins.sam_transforms.automatic_mask_gen.TritonLauncher(model_name: str, model_interpreter_path: str, model_version: int = 0, host: str = 'localhost', port: int = 9000, timeout: float = 10.0, tls_config: TLSConfig | None = None, protocol_type: ProtocolType = ProtocolType.grpc)

Bases: LauncherForDedicatedInferenceServer[Union[InferenceServerClient, InferenceServerClient]]

Inference launcher for Triton Inference Server (triton-inference-server)

Parameters:
  • model_name – Name of the model. It must match the name of the model loaded in the server instance.

  • model_interpreter_path – Path to the Python source file that implements a model interpreter. The model interpreter implements pre-processing of the model input and post-processing of the model output.

  • model_version – Version of the model loaded in the server instance

  • host – Host address of the server instance

  • port – Port number of the server instance

  • timeout – Timeout limit during communication between the client and the server instance

  • tls_config – Configuration required if the server instance is in the secure mode

  • protocol_type – Communication protocol type with the server instance

infer(inputs: ndarray | Dict[str, ndarray]) → List[Dict[str, ndarray] | List[Dict[str, ndarray]]]