otx.core.model.instance_segmentation

Class definition for instance segmentation model entity used in OTX.

Classes

ExplainableOTXInstanceSegModel(label_info, ...)

OTX Instance Segmentation model which can attach an XAI (Explainable AI) branch.

OTXInstanceSegModel(label_info, input_size, ...)

Base class for the Instance Segmentation models used in OTX.

OVInstanceSegmentationModel(model_name, ...)

Instance segmentation model compatible with OpenVINO IR inference.

class otx.core.model.instance_segmentation.ExplainableOTXInstanceSegModel(label_info: LabelInfoTypes, model_name: str, input_size: tuple[int, int] = (1024, 1024), optimizer: OptimizerCallable = <function _default_optimizer_callable>, scheduler: LRSchedulerCallable | LRSchedulerListCallable = <function _default_scheduler_callable>, metric: MetricCallable = <function _rle_mean_ap_f_measure_callable>, torch_compile: bool = False, tile_config: TileConfig = TileConfig(enable_tiler=False, enable_adaptive_tiling=True, tile_size=(400, 400), overlap=0.2, iou_threshold=0.45, max_num_instances=1500, object_tile_ratio=0.03, sampling_ratio=1.0, with_full_img=False))

Bases: OTXInstanceSegModel

OTX Instance Segmentation model which can attach an XAI (Explainable AI) branch.

Parameters:
  • label_info (LabelInfoTypes) – label information

  • model_name (str) – model name/version

  • input_size (tuple[int, int]) – model input size

  • optimizer (OptimizerCallable, optional) – callable that builds the optimizer

  • scheduler (LRSchedulerCallable | LRSchedulerListCallable, optional) – callable that builds the learning-rate scheduler(s)

  • metric (MetricCallable, optional) – callable that builds the evaluation metric

  • torch_compile (bool, optional) – whether to compile the model with torch.compile

  • tile_config (TileConfig, optional) – tiling configuration
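The tile_config defaults (tile_size=(400, 400), overlap=0.2) determine how an image is split into tiles when tiling is enabled. The sketch below shows one common way such a grid can be computed, assuming tiles are placed with a stride of tile_size * (1 - overlap) and the last tile is clamped to the image border. The helper tile_origins is hypothetical, and OTX's actual tiler may place tiles differently.

```python
import math


def tile_origins(image_hw, tile_size=(400, 400), overlap=0.2):
    """Hypothetical helper: (y, x) top-left corners of tiles covering an image.

    Assumes a regular grid with stride = tile_size * (1 - overlap); this is an
    illustration, not OTX's tiling implementation.
    """
    img_h, img_w = image_hw
    tile_h, tile_w = tile_size
    stride_h = max(1, round(tile_h * (1 - overlap)))
    stride_w = max(1, round(tile_w * (1 - overlap)))

    def starts(length, tile, stride):
        if length <= tile:
            return [0]  # a single tile already covers this dimension
        # Enough steps to reach the far edge; clamp the last tile to the border.
        n = math.ceil((length - tile) / stride) + 1
        return [min(i * stride, length - tile) for i in range(n)]

    return [
        (y, x)
        for y in starts(img_h, tile_h, stride_h)
        for x in starts(img_w, tile_w, stride_w)
    ]


# A 1024x1024 image with 400x400 tiles and 20% overlap gives a 3x3 grid.
print(tile_origins((1024, 1024)))
```

With these defaults the stride is 320 pixels, so a 1024x1024 image needs three tiles per axis, the last one shifted back to start at 624 so it stays inside the image.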

export_model_forward_context() Iterator[None]

A context manager that manages the model’s forward function during model export.

It temporarily modifies the model’s forward function so that output sinks for the explain results are generated while the model graph is traced.
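The pattern described above (swap the forward function for the duration of tracing, then restore it) can be sketched with contextlib.contextmanager. ToyModel and export_forward_context are hypothetical stand-ins, and routing forward to forward_explain is an assumption made for illustration; OTX's actual context manager may patch the model differently.

```python
from contextlib import contextmanager


class ToyModel:
    """Hypothetical stand-in for a model with a swappable forward."""

    def forward(self, x):
        return {"preds": x}

    def forward_explain(self, x):
        # Extra "sink" output that a tracer would record alongside the preds.
        return {"preds": x, "saliency_map": [v * 0 for v in x]}


@contextmanager
def export_forward_context(model):
    """Temporarily route model.forward to model.forward_explain."""
    original = model.forward
    model.forward = model.forward_explain
    try:
        yield model
    finally:
        # Restore the original forward even if tracing raises.
        model.forward = original


model = ToyModel()
with export_forward_context(model):
    print(model.forward([1, 2]))  # includes "saliency_map"
print(model.forward([1, 2]))      # original outputs only
```

The try/finally guarantees the model is left in its normal state after export, which matters because the same model object is typically used for training or prediction afterwards.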

forward_explain(inputs: InstanceSegBatchDataEntity) InstanceSegBatchPredEntity

Model forward function used for explanation (XAI).

get_explain_fn() Callable

Returns the explain function.

get_results_from_head(x: tuple[Tensor], entity: InstanceSegBatchDataEntity) tuple[Tensor, Tensor, Tensor] | list[InstanceData] | list[dict[str, Tensor]]

Get the results from the head of the instance segmentation model.

Parameters:
  • x (tuple[Tensor]) – The features from the backbone and neck.

  • entity (InstanceSegBatchDataEntity) – The input batch data entity.

Returns:

The predicted results from the head of the model: a tuple in the export case, a list in the predict case.

Return type:

tuple[Tensor, Tensor, Tensor] | list[InstanceData] | list[dict[str, Tensor]]
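The two return conventions can be illustrated with plain Python containers. The helper package_head_results and its dict keys are hypothetical; it only mirrors the shape of the result (one flat tuple of three batched sequences for export, one result per image for prediction) and does not reproduce how OTX's heads compute boxes, labels or masks.

```python
def package_head_results(per_image, export_mode):
    """Hypothetical helper mirroring the two return conventions.

    per_image: list with one dict per image, each holding "boxes",
    "labels" and "masks" entries.
    """
    if export_mode:
        # Export: a fixed-arity tuple of batched outputs, convenient for
        # graph tracing.
        return (
            [r["boxes"] for r in per_image],
            [r["labels"] for r in per_image],
            [r["masks"] for r in per_image],
        )
    # Predict: one result object per image.
    return [dict(r) for r in per_image]


batch = [{"boxes": [[0, 0, 4, 4]], "labels": [1], "masks": ["m0"]}]
print(package_head_results(batch, export_mode=True))
print(package_head_results(batch, export_mode=False))
```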

class otx.core.model.instance_segmentation.OTXInstanceSegModel(label_info: LabelInfoTypes, input_size: tuple[int, int] = (1024, 1024), model_name: str = 'inst_segm_model', optimizer: OptimizerCallable = <function _default_optimizer_callable>, scheduler: LRSchedulerCallable | LRSchedulerListCallable = <function _default_scheduler_callable>, metric: MetricCallable = <function _rle_mean_ap_f_measure_callable>, torch_compile: bool = False, tile_config: TileConfig = TileConfig(enable_tiler=False, enable_adaptive_tiling=True, tile_size=(400, 400), overlap=0.2, iou_threshold=0.45, max_num_instances=1500, object_tile_ratio=0.03, sampling_ratio=1.0, with_full_img=False))

Bases: OTXModel[InstanceSegBatchDataEntity, InstanceSegBatchPredEntity]

Base class for the Instance Segmentation models used in OTX.

Parameters:
  • label_info (LabelInfoTypes) – label information

  • input_size (tuple[int, int]) – model input size

  • model_name (str) – model name/version

  • optimizer (OptimizerCallable, optional) – callable that builds the optimizer

  • scheduler (LRSchedulerCallable | LRSchedulerListCallable, optional) – callable that builds the learning-rate scheduler(s)

  • metric (MetricCallable, optional) – callable that builds the evaluation metric

  • torch_compile (bool, optional) – whether to compile the model with torch.compile

  • tile_config (TileConfig, optional) – tiling configuration

forward_for_tracing(inputs: Tensor) tuple[Tensor, ...]

Forward function for export.

forward_tiles(inputs: OTXTileBatchDataEntity[InstanceSegBatchDataEntity]) InstanceSegBatchPredEntity

Unpack instance segmentation tiles and merge the per-tile predictions.

Parameters:

inputs (OTXTileBatchDataEntity[InstanceSegBatchDataEntity]) – Tile batch data entity.

Returns:

Merged instance segmentation prediction.

Return type:

InstanceSegBatchPredEntity
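Merging tile predictions generally means shifting each tile's detections back into full-image coordinates and then suppressing duplicates found in overlapping regions. The sketch below does this with greedy class-wise non-maximum suppression, using the TileConfig defaults iou_threshold=0.45 and max_num_instances=1500. It is an assumption that the merge works this way; the helpers iou and merge_tile_predictions are hypothetical, and they handle boxes only (a real instance segmentation merger would also paste masks into the full image).

```python
def iou(a, b):
    """Intersection-over-union of two (x1, y1, x2, y2) boxes."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def merge_tile_predictions(tile_preds, iou_threshold=0.45, max_num_instances=1500):
    """Hypothetical merge of per-tile detections into full-image detections.

    tile_preds: list of ((y0, x0), [(box, score, label), ...]) where each box
    is in tile-local (x1, y1, x2, y2) coordinates.
    """
    candidates = []
    for (y0, x0), dets in tile_preds:
        for (x1, y1, x2, y2), score, label in dets:
            # Shift tile-local coordinates back into the full image.
            candidates.append(((x1 + x0, y1 + y0, x2 + x0, y2 + y0), score, label))

    # Greedy class-wise NMS: keep the best box, drop same-class overlaps.
    candidates.sort(key=lambda d: d[1], reverse=True)
    kept = []
    for box, score, label in candidates:
        if all(k_label != label or iou(box, k_box) <= iou_threshold
               for k_box, _, k_label in kept):
            kept.append((box, score, label))
            if len(kept) >= max_num_instances:
                break
    return kept


# The same object seen by two overlapping tiles is kept only once.
tiles = [
    ((0, 0), [((330, 10, 380, 60), 0.9, 0), ((0, 0, 20, 20), 0.7, 1)]),
    ((0, 320), [((10, 10, 60, 60), 0.8, 0)]),
]
print(merge_tile_predictions(tiles))
```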

get_classification_layers(prefix: str = '') dict[str, dict[str, int]]

Return classification layer names by comparing two models built with different numbers of classes.

Parameters:

prefix (str) – Prefix of the model parameter names. Normally it is “model.”, since OTXModel stores its nn.Module as self.model.

Returns:

A dictionary containing the name and information of each classification layer. Stride is the dimension per class: normally 1, but it can be 4 if the layer is related to bbox regression for object detection. Extra classes are default classes that do not come from the data; normally these are background classes.

Return type:

dict[str, dict[str, int]]
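The comparison can be sketched with parameter shapes alone: parameters whose shapes differ between the two models are classification layers, the stride is the growth of the output dimension per added class, and the extra classes are whatever remains after dividing by the stride and subtracting the real classes. The function find_classification_layers, the parameter names, the shapes, and the result keys "stride" and "num_extra_classes" are illustrative assumptions, not OTX's implementation.

```python
def find_classification_layers(shapes_a, num_classes_a, shapes_b, num_classes_b):
    """Hypothetical sketch: infer classification layers from two models.

    shapes_a / shapes_b map parameter names to shapes for the same
    architecture built with num_classes_a and num_classes_b classes.
    """
    delta_classes = num_classes_b - num_classes_a
    layers = {}
    for name, shape_a in shapes_a.items():
        shape_b = shapes_b[name]
        if shape_a == shape_b:
            continue  # unaffected by the number of classes
        # The output dimension (dim 0) grows by `stride` per added class.
        stride = (shape_b[0] - shape_a[0]) // delta_classes
        extra_classes = shape_a[0] // stride - num_classes_a
        layers[name] = {"stride": stride, "num_extra_classes": extra_classes}
    return layers


# Illustrative shapes: a classifier with one background slot and a box
# regressor with 4 outputs per class.
shapes_3 = {
    "model.roi_head.bbox_head.fc_cls.weight": (4, 1024),
    "model.roi_head.bbox_head.fc_reg.weight": (12, 1024),
    "model.backbone.conv1.weight": (64, 3, 7, 7),
}
shapes_4 = {
    "model.roi_head.bbox_head.fc_cls.weight": (5, 1024),
    "model.roi_head.bbox_head.fc_reg.weight": (16, 1024),
    "model.backbone.conv1.weight": (64, 3, 7, 7),
}
print(find_classification_layers(shapes_3, 3, shapes_4, 4))
```

Here the classifier gets stride 1 with one extra (background) class, while the box regressor gets stride 4 with no extra classes, matching the two cases described in the return value.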

get_dummy_input(batch_size: int = 1) InstanceSegBatchDataEntity

Returns a dummy input for instance segmentation model.

on_load_checkpoint(ckpt: dict[str, Any]) None

Load state_dict from checkpoint.

For detection, the confidence threshold information needs to be updated when the metric is FMeasure.
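The FMeasure-specific step can be pictured as copying a threshold saved at validation time from the checkpoint back into the model's settings. In the sketch below, the function update_threshold_from_checkpoint and the key "best_confidence_threshold" are hypothetical; only "hyper_parameters" follows the usual PyTorch Lightning checkpoint layout. Where OTX actually stores this value may differ.

```python
def update_threshold_from_checkpoint(ckpt, hparams, metric_name):
    """Hypothetical sketch of the FMeasure-specific checkpoint step.

    Copies a stored confidence threshold into the model's hyper-parameters
    so that prediction reuses the threshold chosen during validation.
    The key "best_confidence_threshold" is an assumption.
    """
    if metric_name != "FMeasure":
        return hparams  # other metrics do not carry a tuned threshold
    stored = ckpt.get("hyper_parameters", {}).get("best_confidence_threshold")
    if stored is not None:
        hparams = {**hparams, "best_confidence_threshold": stored}
    return hparams


ckpt = {"hyper_parameters": {"best_confidence_threshold": 0.35}}
print(update_threshold_from_checkpoint(ckpt, {}, "FMeasure"))
```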

class otx.core.model.instance_segmentation.OVInstanceSegmentationModel(model_name: str, model_type: str = 'MaskRCNN', async_inference: bool = True, max_num_requests: int | None = None, use_throughput_mode: bool = True, model_api_configuration: dict[str, Any] | None = None, metric: MetricCallable = <function _rle_mean_ap_f_measure_callable>, **kwargs)

Bases: OVModel[InstanceSegBatchDataEntity, InstanceSegBatchPredEntity]

Instance segmentation model compatible with OpenVINO IR inference.

It can consume an OpenVINO IR model path or a model name from the Intel Open Model Zoo (OMZ) repository and create an OTX instance segmentation model compatible with the OTX testing pipeline.