# Copyright (C) 2023-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
"""Class definition for instance segmentation model entity used in OTX."""
from __future__ import annotations
import logging as log
import types
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Callable, Iterator, Literal
import numpy as np
import torch
from model_api.tilers import InstanceSegmentationTiler
from torch import Tensor
from torchmetrics import Metric, MetricCollection
from torchvision import tv_tensors
from otx.algo.explain.explain_algo import InstSegExplainAlgo, feature_vector_fn
from otx.algo.instance_segmentation.two_stage import TwoStageDetector
from otx.algo.utils.mmengine_utils import InstanceData, load_checkpoint
from otx.core.config.data import TileConfig
from otx.core.data.entity.base import ImageInfo, OTXBatchLossEntity
from otx.core.data.entity.instance_segmentation import InstanceSegBatchDataEntity, InstanceSegBatchPredEntity
from otx.core.data.entity.tile import OTXTileBatchDataEntity
from otx.core.data.entity.utils import stack_batch
from otx.core.metrics import MetricInput
from otx.core.metrics.fmeasure import FMeasure
from otx.core.metrics.mean_ap import MaskRLEMeanAPFMeasureCallable
from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable, OTXModel, OVModel
from otx.core.schedulers import LRSchedulerListCallable
from otx.core.types.export import TaskLevelExportParameters
from otx.core.types.label import LabelInfoTypes
from otx.core.utils.mask_util import encode_rle, polygon_to_rle
from otx.core.utils.tile_merge import InstanceSegTileMerge
if TYPE_CHECKING:
from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable
from model_api.adapters import OpenvinoAdapter
from model_api.models.utils import InstanceSegmentationResult
from torch import nn
from otx.core.metrics import MetricCallable
[docs]
class OTXInstanceSegModel(OTXModel[InstanceSegBatchDataEntity, InstanceSegBatchPredEntity]):
    """Base class for the Instance Segmentation models used in OTX."""

    def __init__(
        self,
        label_info: LabelInfoTypes,
        input_size: tuple[int, int],
        optimizer: OptimizerCallable = DefaultOptimizerCallable,
        scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
        metric: MetricCallable = MaskRLEMeanAPFMeasureCallable,
        torch_compile: bool = False,
        tile_config: TileConfig = TileConfig(enable_tiler=False),
    ) -> None:
        """Pass every argument through, unchanged, to the parent constructor."""
        super().__init__(
            label_info=label_info,
            input_size=input_size,
            optimizer=optimizer,
            scheduler=scheduler,
            metric=metric,
            torch_compile=torch_compile,
            tile_config=tile_config,
        )
        # Bare annotation: narrows the attribute type, assigns nothing.
        self.input_size: tuple[int, int]

    def _build_model(self, num_classes: int) -> nn.Module:
        """Build the detector for ``num_classes`` classes; subclasses must override."""
        raise NotImplementedError

    def _create_model(self) -> nn.Module:
        """Build the detector, initialise its weights and optionally load a checkpoint.

        Also records ``self.classification_layers`` using the ``"model."`` prefix.
        """
        model = self._build_model(num_classes=self.label_info.num_classes)
        model.init_weights()
        self.classification_layers = self.get_classification_layers("model.")

        if self.load_from is not None:
            load_checkpoint(model, self.load_from, map_location="cpu")
        return model

    def _customize_inputs(self, entity: InstanceSegBatchDataEntity) -> dict[str, Any]:
        """Turn a batch entity into the keyword arguments of the detector forward.

        A list of images is first stacked into a batch (padded to a multiple of 32);
        ``entity`` is updated in place in that case.
        """
        if isinstance(entity.images, list):
            entity.images, entity.imgs_info = stack_batch(entity.images, entity.imgs_info, pad_size_divisor=32)

        return {
            "entity": entity,
            "mode": "loss" if self.training else "predict",
        }

    def _customize_outputs(
        self,
        outputs: list[InstanceData] | dict,  # TODO (sungchul): Remove `InstanceData`.
        inputs: InstanceSegBatchDataEntity,
    ) -> InstanceSegBatchPredEntity | OTXBatchLossEntity:
        """Convert raw detector outputs into an OTX loss or prediction entity.

        Args:
            outputs: Loss dict while training; otherwise a list of per-image predictions,
                or a dict holding them under ``"predictions"``.
            inputs: The batch the outputs were computed from.

        Returns:
            ``OTXBatchLossEntity`` while training, ``InstanceSegBatchPredEntity`` otherwise.

        Raises:
            TypeError: If training outputs are not a dict.
            ValueError: In explain mode, if outputs are not a dict or lack
                ``"feature_vector"`` / ``"saliency_map"``.
        """
        if self.training:
            if not isinstance(outputs, dict):
                raise TypeError(outputs)

            loss_entity = OTXBatchLossEntity()
            for name, value in outputs.items():
                if isinstance(value, Tensor):
                    loss_entity[name] = value
                elif isinstance(value, list):
                    loss_entity[name] = sum(item.mean() for item in value)
            # "acc" is dropped from the returned losses
            loss_entity.pop("acc", None)
            return loss_entity

        batch_scores: list[Tensor] = []
        batch_bboxes: list[tv_tensors.BoundingBoxes] = []
        batch_labels: list[torch.LongTensor] = []
        batch_masks: list[tv_tensors.Mask] = []

        predictions = outputs["predictions"] if isinstance(outputs, dict) else outputs
        for info, pred in zip(inputs.imgs_info, predictions):
            batch_scores.append(pred.scores)
            boxes = tv_tensors.BoundingBoxes(
                pred.bboxes,
                format="XYXY",
                canvas_size=info.ori_shape,
            )
            batch_bboxes.append(boxes)
            batch_masks.append(tv_tensors.Mask(pred.masks, dtype=torch.bool))
            batch_labels.append(pred.labels)

        # Extra entity fields that only exist in explain mode.
        explain_fields: dict[str, Any] = {}
        if self.explain_mode:
            if not isinstance(outputs, dict):
                msg = f"Model output should be a dict, but got {type(outputs)}."
                raise ValueError(msg)

            if "feature_vector" not in outputs:
                msg = "No feature vector in the model output."
                raise ValueError(msg)

            if "saliency_map" not in outputs:
                msg = "No saliency maps in the model output."
                raise ValueError(msg)

            saliency_map = outputs["saliency_map"].detach().cpu().numpy()
            feature_vector = outputs["feature_vector"].detach().cpu().numpy()
            explain_fields = {
                "saliency_map": list(saliency_map),
                "feature_vector": list(feature_vector),
            }

        return InstanceSegBatchPredEntity(
            batch_size=len(predictions),
            images=inputs.images,
            imgs_info=inputs.imgs_info,
            scores=batch_scores,
            bboxes=batch_bboxes,
            masks=batch_masks,
            polygons=[],
            labels=batch_labels,
            **explain_fields,
        )

    def get_classification_layers(self, prefix: str = "") -> dict[str, dict[str, int]]:
        """Find class-dependent layers by comparing a 5-class and a 6-class model.

        Two models are built via ``_build_model``; every state-dict entry whose shape
        differs between them is reported.

        Args:
            prefix (str): Prefix prepended to each parameter name.
                ``_create_model`` passes ``"model."``.

        Return:
            dict[str, dict[str, int]]
            Maps each prefixed parameter name to ``"stride"`` (growth of dimension 0
            when one class is added) and ``"num_extra_classes"`` (the part of
            dimension 0 that does not scale with the number of classes).
        """
        five_class_state = self._build_model(num_classes=5).state_dict()
        six_class_state = self._build_model(num_classes=6).state_dict()

        layers: dict[str, dict[str, int]] = {}
        for name, five_class_param in five_class_state.items():
            six_class_param = six_class_state[name]
            if five_class_param.shape == six_class_param.shape:
                continue

            dim_five = five_class_param.shape[0]
            dim_six = six_class_param.shape[0]
            layers[prefix + name] = {
                "stride": dim_six - dim_five,
                "num_extra_classes": 6 * dim_five - 5 * dim_six,
            }
        return layers

    def forward_tiles(self, inputs: OTXTileBatchDataEntity[InstanceSegBatchDataEntity]) -> InstanceSegBatchPredEntity:
        """Run the model on every tile batch and merge the tile predictions.

        Args:
            inputs (OTXTileBatchDataEntity[InstanceSegBatchDataEntity]): Tile batch data entity.

        Returns:
            InstanceSegBatchPredEntity: Merged instance segmentation prediction.

        Raises:
            TypeError: If a tile forward returns a loss entity.
        """
        collected_preds: list[InstanceSegBatchPredEntity] = []
        collected_attrs: list[list[dict[str, int | str]]] = []
        merger = InstanceSegTileMerge(
            inputs.imgs_info,
            self.num_classes,
            self.tile_config,
            self.explain_mode,
        )

        for attrs, tile_batch in inputs.unbind():
            if self.explain_mode:
                tile_output = self.forward_explain(tile_batch)
            else:
                tile_output = self.forward(tile_batch)

            if isinstance(tile_output, OTXBatchLossEntity):
                msg = "Loss output is not supported for tile merging"
                raise TypeError(msg)
            collected_preds.append(tile_output)
            collected_attrs.append(attrs)

        merged_list = merger.merge(collected_preds, collected_attrs)

        batch_pred = InstanceSegBatchPredEntity(
            batch_size=inputs.batch_size,
            images=[merged.image for merged in merged_list],
            imgs_info=[merged.img_info for merged in merged_list],
            scores=[merged.score for merged in merged_list],
            bboxes=[merged.bboxes for merged in merged_list],
            labels=[merged.labels for merged in merged_list],
            masks=[merged.masks for merged in merged_list],
            polygons=[merged.polygons for merged in merged_list],
        )
        if self.explain_mode:
            batch_pred.saliency_map = [merged.saliency_map for merged in merged_list]
            batch_pred.feature_vector = [merged.feature_vector for merged in merged_list]
        return batch_pred

    def forward_for_tracing(self, inputs: Tensor) -> tuple[Tensor, ...]:
        """Forward function for export.

        The meta information handed to ``self.model.export`` is derived from
        dimensions 2 and 3 of ``inputs`` (presumably height and width - confirm).
        """
        spatial_shape = (int(inputs.shape[2]), int(inputs.shape[3]))
        meta = {
            "pad_shape": spatial_shape,
            "batch_input_shape": spatial_shape,
            "img_shape": spatial_shape,
            "scale_factor": (1.0, 1.0),
        }
        # One (shared) meta dict per element of the batch.
        batch_meta = [meta] * len(inputs)
        return self.model.export(inputs, batch_meta, explain_mode=self.explain_mode)

    @property
    def _export_parameters(self) -> TaskLevelExportParameters:
        """Defines parameters required to export a particular model implementation.

        The confidence threshold is read from ``hparams`` and defaults to 0.05;
        the tile config is only exported when tiling is enabled.
        """
        exported_tile_config = self.tile_config if self.tile_config.enable_tiler else None
        return super()._export_parameters.wrap(
            model_type="MaskRCNN",
            task_type="instance_segmentation",
            confidence_threshold=self.hparams.get("best_confidence_threshold", 0.05),
            iou_threshold=0.5,
            tile_config=exported_tile_config,
        )

    def on_load_checkpoint(self, ckpt: dict[str, Any]) -> None:
        """Load state_dict from checkpoint.

        The confidence threshold is taken from ``ckpt["confidence_threshold"]`` or,
        failing that, from ``ckpt["hyper_parameters"]["best_confidence_threshold"]``,
        and stored in ``hparams`` when it is truthy.
        """
        threshold = ckpt.get("confidence_threshold", None)
        if not threshold:
            hyper_parameters = ckpt.get("hyper_parameters", None)
            threshold = hyper_parameters.get("best_confidence_threshold", None) if hyper_parameters else None

        if threshold:
            self.hparams["best_confidence_threshold"] = threshold
        super().on_load_checkpoint(ckpt)

    def _log_metrics(self, meter: Metric, key: Literal["val", "test"], **compute_kwargs) -> None:
        """Log metrics and keep ``best_confidence_threshold`` in ``hparams`` in sync.

        Raises:
            ValueError: If ``key`` is neither ``"val"`` nor ``"test"``.
        """
        if key == "val":
            result = super()._log_metrics(meter, key)

            # NOTE: Validation metric logging can update `best_confidence_threshold`
            threshold = None
            if isinstance(meter, MetricCollection):
                fmeasure = getattr(meter, "FMeasure", None)
                if fmeasure:
                    threshold = getattr(fmeasure, "best_confidence_threshold", None)
            if not threshold and isinstance(meter, FMeasure):
                threshold = getattr(meter, "best_confidence_threshold", None)

            if threshold:
                self.hparams["best_confidence_threshold"] = threshold
            return result

        if key == "test":
            # NOTE: Test metric logging should use `best_confidence_threshold` found previously.
            threshold = self.hparams.get("best_confidence_threshold", None)
            compute_kwargs = {"best_confidence_threshold": threshold} if threshold else {}
            return super()._log_metrics(meter, key, **compute_kwargs)

        raise ValueError(key)

    def _convert_pred_entity_to_compute_metric(
        self,
        preds: InstanceSegBatchPredEntity,
        inputs: InstanceSegBatchDataEntity,
    ) -> MetricInput:
        """Convert predictions and ground truth into the metric input format.

        Predicted masks are RLE-encoded. Ground-truth masks are RLE-encoded when
        present; otherwise the polygons are converted to RLE.

        Args:
            preds (InstanceSegBatchPredEntity): Current batch predictions.
            inputs (InstanceSegBatchDataEntity): Current batch ground-truth inputs.

        Returns:
            dict[str, list[dict[str, Tensor]]]: The converted predictions and ground truth.
        """
        pred_info = [
            {
                "boxes": pred_boxes.data,
                "masks": [encode_rle(mask) for mask in pred_masks.data],
                "scores": pred_scores,
                "labels": pred_labels,
            }
            for pred_boxes, pred_masks, pred_scores, pred_labels in zip(
                preds.bboxes,
                preds.masks,
                preds.scores,
                preds.labels,
            )
        ]

        target_info = []
        ground_truth = zip(
            inputs.imgs_info,
            inputs.bboxes,
            inputs.masks,
            inputs.polygons,
            inputs.labels,
        )
        for info, gt_boxes, gt_masks, gt_polygons, gt_labels in ground_truth:
            if len(gt_masks):
                gt_rles = [encode_rle(mask) for mask in gt_masks.data]
            else:
                gt_rles = polygon_to_rle(gt_polygons, *info.ori_shape)
            target_info.append(
                {
                    "boxes": gt_boxes.data,
                    "masks": gt_rles,
                    "labels": gt_labels,
                },
            )
        return {"preds": pred_info, "target": target_info}
[docs]
class ExplainableOTXInstanceSegModel(OTXInstanceSegModel):
    """OTX Instance Segmentation model which can attach a XAI (Explainable AI) branch."""

    def __init__(
        self,
        label_info: LabelInfoTypes,
        input_size: tuple[int, int],
        optimizer: OptimizerCallable = DefaultOptimizerCallable,
        scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
        metric: MetricCallable = MaskRLEMeanAPFMeasureCallable,
        torch_compile: bool = False,
        tile_config: TileConfig = TileConfig(enable_tiler=False),
    ) -> None:
        """Initialise the base model and attach the explain hooks to ``self.model``."""
        super().__init__(
            label_info=label_info,
            input_size=input_size,
            optimizer=optimizer,
            scheduler=scheduler,
            metric=metric,
            torch_compile=torch_compile,
            tile_config=tile_config,
        )
        self.model.feature_vector_fn = feature_vector_fn
        self.model.explain_fn = self.get_explain_fn()
        self.model.get_results_from_head = self.get_results_from_head
        # Class-level ``forward`` saved while the export context is active.
        self.original_model_forward: Callable | None = None

    def forward_explain(self, inputs: InstanceSegBatchDataEntity) -> InstanceSegBatchPredEntity:
        """Model forward function that also produces explain outputs.

        Tile batches are delegated to ``forward_tiles``.
        """
        if isinstance(inputs, OTXTileBatchDataEntity):
            return self.forward_tiles(inputs)

        self.model.feature_vector_fn = feature_vector_fn
        self.model.explain_fn = self.get_explain_fn()

        # NOTE(review): a bound method never compares equal to the plain function
        # stored on the class, so both "is overridden" conditions below appear to be
        # always true and the `else` arms look unreachable. Kept as-is on purpose;
        # confirm before simplifying.
        # If customize_inputs is overridden
        outputs = (
            self._forward_explain_inst_seg(self.model, **self._customize_inputs(inputs))
            if self._customize_inputs != ExplainableOTXInstanceSegModel._customize_inputs
            else self._forward_explain_inst_seg(self.model, inputs)
        )

        return (
            self._customize_outputs(outputs, inputs)
            if self._customize_outputs != ExplainableOTXInstanceSegModel._customize_outputs
            else outputs["predictions"]
        )

    @staticmethod
    @torch.no_grad()
    def _forward_explain_inst_seg(
        self: TwoStageDetector,
        entity: InstanceSegBatchDataEntity,
        mode: str = "tensor",  # noqa: ARG004
    ) -> dict[str, Tensor]:
        """Forward function of the detector stored in ``ExplainableOTXInstanceSegModel().model``.

        ``self`` is the detector, not the OTX model: this static method is called
        with the detector as first argument, or bound to it during export.

        Returns:
            dict with the keys ``"predictions"``, ``"feature_vector"`` and ``"saliency_map"``.

        Raises:
            TypeError: If the head returns neither a tuple of tensors nor a list of
                ``InstanceData``.
        """
        x = self.extract_feat(entity.images)

        feature_vector = self.feature_vector_fn(x)
        predictions = self.get_results_from_head(x, entity)

        if isinstance(predictions, tuple) and isinstance(predictions[0], Tensor):
            # Export case, consists of tensors
            # For OV task saliency map are generated on MAPI side
            saliency_map = torch.empty(1, dtype=torch.uint8)
        elif isinstance(predictions, list) and isinstance(predictions[0], InstanceData):
            # Predict case, consists of InstanceData
            saliency_map = self.explain_fn(predictions)
        else:
            msg = f"Unexpected predictions type: {type(predictions)}"
            raise TypeError(msg)

        return {
            "predictions": predictions,
            "feature_vector": feature_vector,
            "saliency_map": saliency_map,
        }

    def get_results_from_head(
        self,
        x: tuple[Tensor],
        entity: InstanceSegBatchDataEntity,
    ) -> tuple[Tensor] | list[InstanceData]:
        """Get the results from the head of the instance segmentation model.

        ``RTMDetInstTiny`` models predict with ``bbox_head`` alone; every other model
        runs ``rpn_head`` followed by ``roi_head``.

        Args:
            x (tuple[Tensor]): The features from backbone and neck.
            entity (InstanceSegBatchDataEntity): The batch being processed.

        Returns:
            tuple[Tensor] | list[InstanceData]: The predicted results from the head of the model.
                Tuple for the Export case, list for the Predict case.
        """
        # Imported here rather than at module level (presumably to avoid a circular import).
        from otx.algo.instance_segmentation.rtmdet_inst import RTMDetInstTiny

        if isinstance(self, RTMDetInstTiny):
            return self.model.bbox_head.predict(x, entity, rescale=False)
        rpn_results_list = self.model.rpn_head.predict(x, entity, rescale=False)
        return self.model.roi_head.predict(x, rpn_results_list, entity, rescale=True)

    def get_explain_fn(self) -> Callable:
        """Returns explain function."""
        explainer = InstSegExplainAlgo(num_classes=self.num_classes)
        return explainer.func

    @contextmanager
    def export_model_forward_context(self) -> Iterator[None]:
        """A context manager for managing the model's forward function during model exportation.

        It temporarily modifies the model's forward function to generate output sinks
        for explain results during the model graph tracing.
        """
        # Patch outside the ``try`` block: if patching itself fails there is nothing
        # to restore, and the original error must not be masked by the restore step.
        self._reset_model_forward()
        try:
            yield
        finally:
            self._restore_model_forward()

    def _reset_model_forward(self) -> None:
        """Replace the forward of the model's class with the explain forward (explain mode only)."""
        if not self.explain_mode:
            return

        self.model.explain_fn = self.get_explain_fn()

        # Patch class method
        model_class = type(self.model)
        # Save the class-level function (not the bound ``self.model.forward``) so that
        # the restore step can put back exactly what is being replaced here.
        self.original_model_forward = model_class.forward
        model_class.forward = types.MethodType(self._forward_explain_inst_seg, self.model)

    def _restore_model_forward(self) -> None:
        """Undo ``_reset_model_forward``.

        Raises:
            RuntimeError: If no original forward has been saved.
        """
        if not self.explain_mode:
            return

        original_forward = getattr(self, "original_model_forward", None)
        if not original_forward:
            msg = "Original model forward was not saved."
            raise RuntimeError(msg)

        # Restore on the class, where the patch was applied. Re-binding the saved
        # callable to the instance would pass the model twice on each call and would
        # leave the class itself patched.
        type(self.model).forward = original_forward
        self.original_model_forward = None
[docs]
class OVInstanceSegmentationModel(
    OVModel[InstanceSegBatchDataEntity, InstanceSegBatchPredEntity],
):
    """Instance segmentation model compatible for OpenVINO IR inference.

    It can consume OpenVINO IR model path or model name from Intel OMZ repository
    and create the OTX detection model compatible for OTX testing pipeline.
    """

    def __init__(
        self,
        model_name: str,
        model_type: str = "MaskRCNN",
        async_inference: bool = True,
        max_num_requests: int | None = None,
        use_throughput_mode: bool = True,
        model_api_configuration: dict[str, Any] | None = None,
        metric: MetricCallable = MaskRLEMeanAPFMeasureCallable,
        **kwargs,
    ) -> None:
        """Forward the named arguments to ``OVModel``; extra ``kwargs`` are accepted but unused."""
        super().__init__(
            model_name=model_name,
            model_type=model_type,
            async_inference=async_inference,
            max_num_requests=max_num_requests,
            use_throughput_mode=use_throughput_mode,
            model_api_configuration=model_api_configuration,
            metric=metric,
        )

    def _setup_tiler(self) -> None:
        """Setup tiler for tile task."""
        execution_mode = "async" if self.async_inference else "sync"
        # Note: Disable async_inference as tiling has its own sync/async implementation
        self.async_inference = False
        self.model = InstanceSegmentationTiler(self.model, execution_mode=execution_mode)
        # Single-line literal: a backslash continuation inside the string would embed
        # the indentation of the following line in the message.
        log.info(
            f"Enable tiler with tile size: {self.model.tile_size} and overlap: {self.model.tiles_overlap}",
        )

    def _get_hparams_from_adapter(self, model_adapter: OpenvinoAdapter) -> None:
        """Reads model configuration from ModelAPI OpenVINO adapter.

        Stores ``best_confidence_threshold`` in ``hparams``: the float value found in
        the model's ``rt_info``, or ``None`` (with a warning) when it is missing.

        Args:
            model_adapter (OpenvinoAdapter): target adapter to read the config
        """
        if model_adapter.model.has_rt_info(["model_info", "confidence_threshold"]):
            best_confidence_threshold = model_adapter.model.get_rt_info(["model_info", "confidence_threshold"]).value
            self.hparams["best_confidence_threshold"] = float(best_confidence_threshold)
        else:
            msg = (
                "Cannot get best_confidence_threshold from OpenVINO IR's rt_info. "
                "Please check whether this model is trained by OTX or not. "
                "Without this information, it can produce a wrong F1 metric score. "
                "At this time, it will be set as the default value = None."
            )
            log.warning(msg)
            self.hparams["best_confidence_threshold"] = None

    def _customize_outputs(
        self,
        outputs: list[InstanceSegmentationResult],
        inputs: InstanceSegBatchDataEntity,
    ) -> InstanceSegBatchPredEntity | OTXBatchLossEntity:
        """Convert ModelAPI results into an OTX batch prediction entity.

        Args:
            outputs: One ModelAPI result per image (assumed to be in the same order as
                ``inputs.imgs_info`` - confirm against the caller).
            inputs: The batch the results were computed from.

        Returns:
            The prediction entity; it carries saliency maps and feature vectors when
            the first result has a non-empty ``saliency_map``.
        """
        bboxes = []
        scores = []
        labels = []
        masks = []
        # Pair every result with the info of its own image so that each set of boxes
        # gets the canvas size of that image rather than the last one in the batch.
        for output, img_info in zip(outputs, inputs.imgs_info):
            output_objects = output.segmentedObjects
            if len(output_objects):
                bbox = [[obj.xmin, obj.ymin, obj.xmax, obj.ymax] for obj in output_objects]
            else:
                # Keep four columns so an empty result is still a valid XYXY box tensor.
                bbox = torch.empty(size=(0, 4))
            bboxes.append(
                tv_tensors.BoundingBoxes(
                    bbox,
                    format="XYXY",
                    canvas_size=img_info.img_shape,
                    device=self.device,
                ),
            )
            # NOTE: OTX 1.5 filter predictions with result_based_confidence_threshold,
            # but OTX 2.0 doesn't have it in configuration.
            _masks = [obj.mask for obj in output_objects]
            _masks = np.stack(_masks) if len(_masks) else []
            scores.append(torch.tensor([obj.score for obj in output_objects], device=self.device))
            masks.append(torch.tensor(_masks, device=self.device))
            # add label index: ModelAPI ids are shifted down by one
            labels.append(torch.tensor([obj.id - 1 for obj in output_objects], device=self.device))

        # Extra entity fields that only exist when explain outputs are present.
        explain_fields: dict[str, Any] = {}
        if outputs and outputs[0].saliency_map:
            predicted_s_maps = [
                # Only entries with more than one dimension are kept.
                np.array([s_map for s_map in out.saliency_map if s_map.ndim > 1])
                for out in outputs
            ]
            # Squeeze dim 2D => 1D, (1, internal_dim) => (internal_dim)
            predicted_f_vectors = [out.feature_vector[0] for out in outputs]
            explain_fields = {
                "saliency_map": predicted_s_maps,
                "feature_vector": predicted_f_vectors,
            }

        return InstanceSegBatchPredEntity(
            batch_size=len(outputs),
            images=inputs.images,
            imgs_info=inputs.imgs_info,
            scores=scores,
            bboxes=bboxes,
            masks=masks,
            polygons=[],
            labels=labels,
            **explain_fields,
        )

    def _convert_pred_entity_to_compute_metric(
        self,
        preds: InstanceSegBatchPredEntity,
        inputs: InstanceSegBatchDataEntity,
    ) -> MetricInput:
        """Convert the prediction entity to the format that the metric can compute.

        Predicted masks are RLE-encoded. Ground-truth masks are RLE-encoded when
        present; otherwise the polygons are converted to RLE.

        Args:
            preds (InstanceSegBatchPredEntity): Current batch predictions.
            inputs (InstanceSegBatchDataEntity): Current batch ground-truth inputs.

        Returns:
            dict[str, list[dict[str, Tensor]]]: The converted predictions and ground truth.
        """
        pred_info = []
        target_info = []

        for bboxes, masks, scores, labels in zip(
            preds.bboxes,
            preds.masks,
            preds.scores,
            preds.labels,
        ):
            pred_info.append(
                {
                    "boxes": bboxes.data,
                    "masks": [encode_rle(mask) for mask in masks.data],
                    "scores": scores,
                    "labels": labels,
                },
            )

        for imgs_info, bboxes, masks, polygons, labels in zip(
            inputs.imgs_info,
            inputs.bboxes,
            inputs.masks,
            inputs.polygons,
            inputs.labels,
        ):
            rles = (
                [encode_rle(mask) for mask in masks.data]
                if len(masks)
                else polygon_to_rle(polygons, *imgs_info.ori_shape)
            )
            target_info.append(
                {
                    "boxes": bboxes.data,
                    "masks": rles,
                    "labels": labels,
                },
            )
        return {"preds": pred_info, "target": target_info}

    def _log_metrics(self, meter: Metric, key: Literal["val", "test"], **compute_kwargs) -> None:
        """Log metrics using the ``best_confidence_threshold`` stored in ``hparams``.

        Any ``compute_kwargs`` passed by the caller are replaced.
        """
        best_confidence_threshold = self.hparams.get("best_confidence_threshold", None)
        compute_kwargs = {"best_confidence_threshold": best_confidence_threshold}
        return super()._log_metrics(meter, key, **compute_kwargs)