Source code for otx.core.model.detection

# Copyright (C) 2023-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
"""Class definition for detection model entity used in OTX."""

from __future__ import annotations

import logging as log
import types
from abc import abstractmethod
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Callable, Iterator, Literal

import torch
from model_api.tilers import DetectionTiler
from torchmetrics import Metric, MetricCollection
from torchvision import tv_tensors

from otx.algo.utils.mmengine_utils import InstanceData, load_checkpoint
from otx.core.config.data import TileConfig
from otx.core.data.entity.base import ImageInfo, OTXBatchLossEntity
from otx.core.data.entity.detection import DetBatchDataEntity, DetBatchPredEntity
from otx.core.data.entity.tile import OTXTileBatchDataEntity
from otx.core.data.entity.utils import stack_batch
from otx.core.metrics import MetricCallable, MetricInput
from otx.core.metrics.fmeasure import FMeasure, MeanAveragePrecisionFMeasureCallable
from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable, OTXModel, OVModel
from otx.core.schedulers import LRSchedulerListCallable
from otx.core.types.export import TaskLevelExportParameters
from otx.core.types.label import LabelInfoTypes
from otx.core.utils.tile_merge import DetectionTileMerge

if TYPE_CHECKING:
    from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable
    from model_api.adapters import OpenvinoAdapter
    from model_api.models.utils import DetectionResult
    from torch import nn

    from otx.algo.detection.detectors import SingleStageDetector


class OTXDetectionModel(OTXModel[DetBatchDataEntity, DetBatchPredEntity]):
    """Base class for the detection models used in OTX."""

    input_size: tuple[int, int]

    def __init__(self, model_name: str, *args, **kwargs) -> None:
        self.model_name = model_name
        super().__init__(*args, **kwargs)

    def test_step(self, batch: DetBatchDataEntity, batch_idx: int) -> None:
        """Perform a single test step on a batch of data from the test set.

        :param batch: A batch of data (a tuple) containing the input tensor of images and target labels.
        :param batch_idx: The index of the current batch.
        """
        preds = self._filter_outputs_by_threshold(self.forward(inputs=batch))  # type: ignore[arg-type]

        if isinstance(preds, OTXBatchLossEntity):
            raise TypeError(preds)

        metric_inputs = self._convert_pred_entity_to_compute_metric(preds, batch)

        if isinstance(metric_inputs, dict):
            self.metric.update(**metric_inputs)
            return

        if isinstance(metric_inputs, list) and all(isinstance(inp, dict) for inp in metric_inputs):
            for inp in metric_inputs:
                self.metric.update(**inp)
            return

        raise TypeError(metric_inputs)

    def predict_step(
        self,
        batch: DetBatchDataEntity,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> DetBatchPredEntity:
        """Step function called during PyTorch Lightning Trainer's predict."""
        if self.explain_mode:
            return self._filter_outputs_by_threshold(self.forward_explain(inputs=batch))

        outputs = self._filter_outputs_by_threshold(self.forward(inputs=batch))  # type: ignore[arg-type]

        if isinstance(outputs, OTXBatchLossEntity):
            raise TypeError(outputs)

        return outputs

    def _filter_outputs_by_threshold(self, outputs: DetBatchPredEntity) -> DetBatchPredEntity:
        scores = []
        bboxes = []
        labels = []
        for score, bbox, label in zip(outputs.scores, outputs.bboxes, outputs.labels):
            filtered_idx = torch.where(score > self.best_confidence_threshold)
            scores.append(score[filtered_idx])
            bboxes.append(tv_tensors.wrap(bbox[filtered_idx], like=bbox))
            labels.append(label[filtered_idx])

        outputs.scores = scores
        outputs.bboxes = bboxes
        outputs.labels = labels
        return outputs

    @abstractmethod
    def _build_model(self, num_classes: int) -> nn.Module:
        raise NotImplementedError

    def _create_model(self) -> nn.Module:
        detector = self._build_model(num_classes=self.label_info.num_classes)
        if hasattr(detector, "init_weights"):
            detector.init_weights()
        self.classification_layers = self.get_classification_layers(prefix="model.")
        if self.load_from is not None:
            load_checkpoint(detector, self.load_from, map_location="cpu")
        return detector

    def _customize_inputs(
        self,
        entity: DetBatchDataEntity,
        pad_size_divisor: int = 32,
        pad_value: int = 0,
    ) -> dict[str, Any]:
        if isinstance(entity.images, list):
            entity.images, entity.imgs_info = stack_batch(
                entity.images,
                entity.imgs_info,
                pad_size_divisor=pad_size_divisor,
                pad_value=pad_value,
            )
        inputs: dict[str, Any] = {}
        inputs["entity"] = entity
        inputs["mode"] = "loss" if self.training else "predict"
        return inputs

    def _customize_outputs(  # type: ignore[override]
        self,
        outputs: list[InstanceData] | dict | None,
        inputs: DetBatchDataEntity,
    ) -> DetBatchPredEntity | OTXBatchLossEntity | None:
        if self.training:
            if outputs is None:
                return outputs
            if not isinstance(outputs, dict):
                raise TypeError(outputs)
            losses = OTXBatchLossEntity()
            for k, v in outputs.items():
                if isinstance(v, list):
                    losses[k] = sum(v)
                elif isinstance(v, torch.Tensor):
                    losses[k] = v
                else:
                    msg = f"Loss output should be a list or torch.Tensor, but got {type(v)}"
                    raise TypeError(msg)
            return losses

        scores = []
        bboxes = []
        labels = []
        predictions = outputs["predictions"] if isinstance(outputs, dict) else outputs
        for img_info, prediction in zip(inputs.imgs_info, predictions):
            if not isinstance(prediction, InstanceData):
                raise TypeError(prediction)

            scores.append(prediction.scores)  # type: ignore[attr-defined]
            bboxes.append(
                tv_tensors.BoundingBoxes(
                    prediction.bboxes,  # type: ignore[attr-defined]
                    format="XYXY",
                    canvas_size=img_info.ori_shape,
                ),
            )
            labels.append(prediction.labels)  # type: ignore[attr-defined]

        if self.explain_mode:
            if not isinstance(outputs, dict):
                msg = f"Model output should be a dict, but got {type(outputs)}."
                raise ValueError(msg)
            if "feature_vector" not in outputs:
                msg = "No feature vector in the model output."
                raise ValueError(msg)
            if "saliency_map" not in outputs:
                msg = "No saliency maps in the model output."
                raise ValueError(msg)

            saliency_map = outputs["saliency_map"].detach().cpu().numpy()
            feature_vector = outputs["feature_vector"].detach().cpu().numpy()

            return DetBatchPredEntity(
                batch_size=len(predictions),
                images=inputs.images,
                imgs_info=inputs.imgs_info,
                scores=scores,
                bboxes=bboxes,
                labels=labels,
                saliency_map=saliency_map,
                feature_vector=feature_vector,
            )

        return DetBatchPredEntity(
            batch_size=len(predictions),
            images=inputs.images,
            imgs_info=inputs.imgs_info,
            scores=scores,
            bboxes=bboxes,
            labels=labels,
        )

    def get_classification_layers(self, prefix: str = "model.") -> dict[str, dict[str, int]]:
        """Get final classification layer information for incremental learning case."""
        sample_model_dict = self._build_model(num_classes=5).state_dict()
        incremental_model_dict = self._build_model(num_classes=6).state_dict()

        classification_layers = {}
        for key in sample_model_dict:
            if sample_model_dict[key].shape != incremental_model_dict[key].shape:
                sample_model_dim = sample_model_dict[key].shape[0]
                incremental_model_dim = incremental_model_dict[key].shape[0]
                stride = incremental_model_dim - sample_model_dim
                num_extra_classes = 6 * sample_model_dim - 5 * incremental_model_dim
                classification_layers[prefix + key] = {"stride": stride, "num_extra_classes": num_extra_classes}
        return classification_layers

    def forward_tiles(self, inputs: OTXTileBatchDataEntity[DetBatchDataEntity]) -> DetBatchPredEntity:
        """Unpack detection tiles.

        Args:
            inputs (OTXTileBatchDataEntity): Tile batch data entity.

        Returns:
            DetBatchPredEntity: Merged detection prediction.
        """
        tile_preds: list[DetBatchPredEntity] = []
        tile_attrs: list[list[dict[str, int | str]]] = []
        merger = DetectionTileMerge(
            inputs.imgs_info,
            self.num_classes,
            self.tile_config,
            self.explain_mode,
        )
        for batch_tile_attrs, batch_tile_input in inputs.unbind():
            output = self.forward_explain(batch_tile_input) if self.explain_mode else self.forward(batch_tile_input)
            if isinstance(output, OTXBatchLossEntity):
                msg = "Loss output is not supported for tile merging"
                raise TypeError(msg)
            tile_preds.append(output)
            tile_attrs.append(batch_tile_attrs)
        pred_entities = merger.merge(tile_preds, tile_attrs)

        pred_entity = DetBatchPredEntity(
            batch_size=inputs.batch_size,
            images=[pred_entity.image for pred_entity in pred_entities],
            imgs_info=[pred_entity.img_info for pred_entity in pred_entities],
            scores=[pred_entity.score for pred_entity in pred_entities],
            bboxes=[pred_entity.bboxes for pred_entity in pred_entities],
            labels=[pred_entity.labels for pred_entity in pred_entities],
        )
        if self.explain_mode:
            pred_entity.saliency_map = [pred_entity.saliency_map for pred_entity in pred_entities]
            pred_entity.feature_vector = [pred_entity.feature_vector for pred_entity in pred_entities]

        return pred_entity

    def forward_for_tracing(self, inputs: torch.Tensor) -> list[InstanceData]:
        """Forward function for export."""
        shape = (int(inputs.shape[2]), int(inputs.shape[3]))
        meta_info = {
            "pad_shape": shape,
            "batch_input_shape": shape,
            "img_shape": shape,
            "scale_factor": (1.0, 1.0),
        }
        meta_info_list = [meta_info] * len(inputs)
        return self.model.export(inputs, meta_info_list, explain_mode=self.explain_mode)

    @property
    def _export_parameters(self) -> TaskLevelExportParameters:
        """Defines parameters required to export a particular model implementation."""
        return super()._export_parameters.wrap(
            model_type="ssd",
            task_type="detection",
            confidence_threshold=self.hparams.get("best_confidence_threshold", None),
            iou_threshold=0.5,
            tile_config=self.tile_config if self.tile_config.enable_tiler else None,
        )

    def _convert_pred_entity_to_compute_metric(
        self,
        preds: DetBatchPredEntity,
        inputs: DetBatchDataEntity,
    ) -> MetricInput:
        return {
            "preds": [
                {
                    "boxes": bboxes.data,
                    "scores": scores.type(torch.float32),
                    "labels": labels,
                }
                for bboxes, scores, labels in zip(
                    preds.bboxes,
                    preds.scores,
                    preds.labels,
                )
            ],
            "target": [
                {
                    "boxes": bboxes.data,
                    "labels": labels,
                }
                for bboxes, labels in zip(inputs.bboxes, inputs.labels)
            ],
        }

    def on_load_checkpoint(self, ckpt: dict[str, Any]) -> None:
        """Load state_dict from checkpoint.

        For detection, it is necessary to update the confidence threshold information
        when the metric is FMeasure.
        """
        if best_confidence_threshold := ckpt.get("confidence_threshold", None) or (
            (hyper_parameters := ckpt.get("hyper_parameters", None))
            and (best_confidence_threshold := hyper_parameters.get("best_confidence_threshold", None))
        ):
            self.hparams["best_confidence_threshold"] = best_confidence_threshold
        super().on_load_checkpoint(ckpt)

    def _log_metrics(self, meter: Metric, key: Literal["val", "test"], **compute_kwargs) -> None:
        if key == "val":
            retval = super()._log_metrics(meter, key)

            # NOTE: Validation metric logging can update `best_confidence_threshold`
            if (
                isinstance(meter, MetricCollection)
                and (fmeasure := getattr(meter, "FMeasure", None))
                and (best_confidence_threshold := getattr(fmeasure, "best_confidence_threshold", None))
            ) or (
                isinstance(meter, FMeasure)
                and (best_confidence_threshold := getattr(meter, "best_confidence_threshold", None))
            ):
                self.hparams["best_confidence_threshold"] = best_confidence_threshold

            return retval

        if key == "test":
            # NOTE: Test metric logging should use `best_confidence_threshold` found previously.
            best_confidence_threshold = self.hparams.get("best_confidence_threshold", None)
            compute_kwargs = (
                {"best_confidence_threshold": best_confidence_threshold} if best_confidence_threshold else {}
            )

            return super()._log_metrics(meter, key, **compute_kwargs)

        raise ValueError(key)

    @property
    def best_confidence_threshold(self) -> float:
        """Best confidence threshold to filter outputs."""
        if not hasattr(self, "_best_confidence_threshold"):
            self._best_confidence_threshold = self.hparams.get("best_confidence_threshold", None)
            if self._best_confidence_threshold is None:
                log.warning("There is no predefined best_confidence_threshold, 0.5 will be used as default.")
                self._best_confidence_threshold = 0.5
        return self._best_confidence_threshold

    def get_dummy_input(self, batch_size: int = 1) -> DetBatchDataEntity:
        """Returns a dummy input for detection model."""
        if self.input_size is None:
            msg = f"Input size attribute is not set for {self.__class__}"
            raise ValueError(msg)

        images = [torch.rand(3, *self.input_size) for _ in range(batch_size)]
        infos = []
        for i, img in enumerate(images):
            infos.append(
                ImageInfo(
                    img_idx=i,
                    img_shape=img.shape,
                    ori_shape=img.shape,
                ),
            )
        return DetBatchDataEntity(batch_size, images, infos, bboxes=[], labels=[])
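

# Illustrative sketch (not part of the upstream OTX module): the only abstract
# hook in OTXDetectionModel is `_build_model`, so a minimal concrete subclass
# needs just that one method. `_ExampleToyDetector` below is a hypothetical
# stand-in for a real detector such as SingleStageDetector.
class _ExampleToyDetector(torch.nn.Module):
    """Hypothetical toy detector used only to illustrate the subclassing contract."""

    def __init__(self, num_classes: int) -> None:
        super().__init__()
        # A 1x1 convolution acts as a stand-in classification head whose output
        # channel count depends on `num_classes`; this is exactly the dependency
        # that `get_classification_layers` probes by diffing 5- vs 6-class models.
        self.cls_head = torch.nn.Conv2d(3, num_classes, kernel_size=1)

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        return self.cls_head(images)


class _ExampleToyDetectionModel(OTXDetectionModel):
    """Minimal concrete OTXDetectionModel: only `_build_model` must be defined."""

    def _build_model(self, num_classes: int) -> torch.nn.Module:
        return _ExampleToyDetector(num_classes=num_classes)

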
class ExplainableOTXDetModel(OTXDetectionModel):
    """OTX detection model which can attach an XAI (Explainable AI) branch."""

    def __init__(
        self,
        model_name: str,
        label_info: LabelInfoTypes,
        input_size: tuple[int, int],
        optimizer: OptimizerCallable = DefaultOptimizerCallable,
        scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
        metric: MetricCallable = MeanAveragePrecisionFMeasureCallable,
        torch_compile: bool = False,
        tile_config: TileConfig = TileConfig(enable_tiler=False),
    ) -> None:
        from otx.algo.explain.explain_algo import feature_vector_fn

        super().__init__(
            model_name=model_name,
            label_info=label_info,
            input_size=input_size,
            optimizer=optimizer,
            scheduler=scheduler,
            metric=metric,
            torch_compile=torch_compile,
            tile_config=tile_config,
        )
        self.model.feature_vector_fn = feature_vector_fn
        self.model.explain_fn = self.get_explain_fn()

    def forward_explain(self, inputs: DetBatchDataEntity) -> DetBatchPredEntity:
        """Model forward function."""
        from otx.algo.explain.explain_algo import feature_vector_fn

        if isinstance(inputs, OTXTileBatchDataEntity):
            return self.forward_tiles(inputs)

        self.model.feature_vector_fn = feature_vector_fn
        self.model.explain_fn = self.get_explain_fn()

        # If customize_inputs is overridden
        outputs = (
            self._forward_explain_detection(self.model, **self._customize_inputs(inputs))
            if self._customize_inputs != ExplainableOTXDetModel._customize_inputs
            else self._forward_explain_detection(self.model, inputs)
        )

        return (
            self._customize_outputs(outputs, inputs)
            if self._customize_outputs != ExplainableOTXDetModel._customize_outputs
            else outputs["predictions"]
        )

    @staticmethod
    def _forward_explain_detection(
        self: SingleStageDetector,
        entity: DetBatchDataEntity,
        mode: str = "tensor",
    ) -> dict[str, torch.Tensor]:
        """Forward function of the BaseDetector instance, which is located in ExplainableOTXDetModel().model."""
        backbone_feat = self.extract_feat(entity.images)
        bbox_head_feat = self.bbox_head.forward(backbone_feat)

        feature_vector = self.feature_vector_fn(backbone_feat)
        # Process the first output from the bbox detection head: classification scores
        saliency_map = self.explain_fn(bbox_head_feat[0])

        if mode == "predict":
            predictions = self.bbox_head.predict(backbone_feat, entity)
        elif mode == "tensor":
            predictions = bbox_head_feat
        else:
            msg = f'Invalid mode "{mode}".'
            raise RuntimeError(msg)

        return {
            "predictions": predictions,
            "feature_vector": feature_vector,
            "saliency_map": saliency_map,
        }

    def get_explain_fn(self) -> Callable:
        """Returns explain function."""
        from otx.algo.detection.heads.ssd_head import SSDHeadModule
        from otx.algo.explain.explain_algo import DetClassProbabilityMap

        # SSD-like heads also have background class
        background_class = hasattr(self.model, "bbox_head") and isinstance(
            self.model.bbox_head,
            SSDHeadModule,  # TODO (sungchul): revert module's name?
        )
        tiling_mode = self.tile_config.enable_tiler if hasattr(self, "tile_config") else False
        explainer = DetClassProbabilityMap(
            num_classes=self.num_classes + background_class,
            num_anchors=self.get_num_anchors(),
            use_cls_softmax=not tiling_mode,
        )
        return explainer.func

    @contextmanager
    def export_model_forward_context(self) -> Iterator[None]:
        """A context manager for managing the model's forward function during model exportation.

        It temporarily modifies the model's forward function to generate output sinks
        for explain results during the model graph tracing.
        """
        try:
            self._reset_model_forward()
            yield
        finally:
            self._restore_model_forward()

    def _reset_model_forward(self) -> None:
        if not self.explain_mode:
            return

        self.model.explain_fn = self.get_explain_fn()
        forward_with_explain = self._forward_explain_detection

        self.original_model_forward = self.model.forward

        func_type = types.MethodType
        # Patch class method
        model_class = type(self.model)
        model_class.forward = func_type(forward_with_explain, self.model)

    def _restore_model_forward(self) -> None:
        if not self.explain_mode:
            return

        if not self.original_model_forward:
            msg = "Original model forward was not saved."
            raise RuntimeError(msg)

        func_type = types.MethodType
        self.model.forward = func_type(self.original_model_forward, self.model)
        self.original_model_forward = None

    def get_num_anchors(self) -> list[int]:
        """Gets the anchor configuration from model."""
        if hasattr(self.model, "bbox_head") and (
            anchor_generator := getattr(self.model.bbox_head, "prior_generator", None)
        ):
            return (
                anchor_generator.num_base_anchors
                if hasattr(anchor_generator, "num_base_anchors")
                else anchor_generator.num_base_priors
            )
        return [1] * 10
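

# Illustrative sketch (an assumption, not the actual otx.algo.explain code): the
# `feature_vector_fn` attached in ExplainableOTXDetModel typically reduces the
# backbone feature pyramid to a single vector, e.g. by global-average-pooling
# each level and concatenating the results. The helper below only demonstrates
# that shape contract; the real implementation may differ.
def _example_feature_vector(backbone_feats: tuple[torch.Tensor, ...]) -> torch.Tensor:
    """Pool each (N, C_i, H_i, W_i) feature map to (N, C_i) and concatenate to (N, sum(C_i))."""
    pooled = [torch.nn.functional.adaptive_avg_pool2d(feat, 1).flatten(1) for feat in backbone_feats]
    return torch.cat(pooled, dim=1)

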
class OVDetectionModel(OVModel[DetBatchDataEntity, DetBatchPredEntity]):
    """Object detection model compatible for OpenVINO IR inference.

    It can consume OpenVINO IR model path or model name from Intel OMZ repository
    and create the OTX detection model compatible for OTX testing pipeline.
    """

    def __init__(
        self,
        model_name: str,
        model_type: str = "SSD",
        async_inference: bool = True,
        max_num_requests: int | None = None,
        use_throughput_mode: bool = True,
        model_api_configuration: dict[str, Any] | None = None,
        metric: MetricCallable = MeanAveragePrecisionFMeasureCallable,
        **kwargs,
    ) -> None:
        super().__init__(
            model_name=model_name,
            model_type=model_type,
            async_inference=async_inference,
            max_num_requests=max_num_requests,
            use_throughput_mode=use_throughput_mode,
            model_api_configuration=model_api_configuration,
            metric=metric,
        )

    def _setup_tiler(self) -> None:
        """Setup tiler for tile task."""
        execution_mode = "async" if self.async_inference else "sync"
        # Note: Disable async_inference as tiling has its own sync/async implementation
        self.async_inference = False
        self.model = DetectionTiler(self.model, execution_mode=execution_mode)
        log.info(
            f"Enable tiler with tile size: {self.model.tile_size} and overlap: {self.model.tiles_overlap}",
        )

    def _get_hparams_from_adapter(self, model_adapter: OpenvinoAdapter) -> None:
        """Reads model configuration from ModelAPI OpenVINO adapter.

        Args:
            model_adapter (OpenvinoAdapter): target adapter to read the config
        """
        if model_adapter.model.has_rt_info(["model_info", "confidence_threshold"]):
            best_confidence_threshold = model_adapter.model.get_rt_info(["model_info", "confidence_threshold"]).value
            self.hparams["best_confidence_threshold"] = float(best_confidence_threshold)
        else:
            msg = (
                "Cannot get best_confidence_threshold from OpenVINO IR's rt_info. "
                "Please check whether this model is trained by OTX or not. "
                "Without this information, it can produce a wrong F1 metric score. "
                "At this time, it will be set as the default value = None."
            )
            log.warning(msg)
            self.hparams["best_confidence_threshold"] = None

    def _customize_outputs(
        self,
        outputs: list[DetectionResult],
        inputs: DetBatchDataEntity,
    ) -> DetBatchPredEntity | OTXBatchLossEntity:
        # add label index
        bboxes = []
        scores = []
        labels = []

        # some OMZ models require shifting labels
        first_label = (
            self.model.model.get_label_name(0)
            if isinstance(self.model, DetectionTiler)
            else self.model.get_label_name(0)
        )
        label_shift = 1 if first_label == "background" else 0
        if label_shift:
            log.warning(f"label_shift: {label_shift}")

        for i, output in enumerate(outputs):
            output_objects = output.objects
            if len(output_objects):
                bbox = [[output.xmin, output.ymin, output.xmax, output.ymax] for output in output_objects]
            else:
                bbox = torch.empty(size=(0, 0))
            bboxes.append(
                tv_tensors.BoundingBoxes(
                    bbox,
                    format="XYXY",
                    canvas_size=inputs.imgs_info[i].img_shape,
                    device=self.device,
                ),
            )
            scores.append(torch.tensor([output.score for output in output_objects], device=self.device))
            labels.append(torch.tensor([output.id - label_shift for output in output_objects], device=self.device))

        if outputs and outputs[0].saliency_map.size > 1:
            # Squeeze dim 4D => 3D, (1, num_classes, H, W) => (num_classes, H, W)
            predicted_s_maps = [out.saliency_map[0] for out in outputs]

            # Squeeze dim 2D => 1D, (1, internal_dim) => (internal_dim,)
            predicted_f_vectors = [out.feature_vector[0] for out in outputs]
            return DetBatchPredEntity(
                batch_size=len(outputs),
                images=inputs.images,
                imgs_info=inputs.imgs_info,
                scores=scores,
                bboxes=bboxes,
                labels=labels,
                saliency_map=predicted_s_maps,
                feature_vector=predicted_f_vectors,
            )

        return DetBatchPredEntity(
            batch_size=len(outputs),
            images=inputs.images,
            imgs_info=inputs.imgs_info,
            scores=scores,
            bboxes=bboxes,
            labels=labels,
        )

    def _convert_pred_entity_to_compute_metric(
        self,
        preds: DetBatchPredEntity,
        inputs: DetBatchDataEntity,
    ) -> MetricInput:
        return {
            "preds": [
                {
                    "boxes": bboxes.data,
                    "scores": scores,
                    "labels": labels,
                }
                for bboxes, scores, labels in zip(
                    preds.bboxes,
                    preds.scores,
                    preds.labels,
                )
            ],
            "target": [
                {
                    "boxes": bboxes.data,
                    "labels": labels,
                }
                for bboxes, labels in zip(inputs.bboxes, inputs.labels)
            ],
        }

    def _log_metrics(self, meter: Metric, key: Literal["val", "test"], **compute_kwargs) -> None:
        best_confidence_threshold = self.hparams.get("best_confidence_threshold", None)
        compute_kwargs = {"best_confidence_threshold": best_confidence_threshold}
        return super()._log_metrics(meter, key, **compute_kwargs)
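

# Illustrative usage sketch (assumption: "exported_model.xml" is a hypothetical
# path to a detection IR exported by OTX; only arguments from the __init__
# signature above are used). OVDetectionModel reads best_confidence_threshold
# from the IR's rt_info, so an IR exported outside OTX falls back to None and
# logs a warning (see _get_hparams_from_adapter).
def _example_load_ov_detection_model() -> OVDetectionModel:
    """Wrap an OTX-exported OpenVINO IR for the OTX testing pipeline."""
    return OVDetectionModel(
        model_name="exported_model.xml",  # hypothetical IR path
        model_type="SSD",
        async_inference=True,
        use_throughput_mode=True,
    )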