otx.core.model.detection

Class definition for detection model entity used in OTX.

Classes

ExplainableOTXDetModel(label_info, ...[, ...])

OTX detection model which can attach an XAI (Explainable AI) branch.

OTXDetectionModel(label_info, input_size, ...)

Base class for the detection models used in OTX.

OVDetectionModel(model_name, model_type, ...)

Object detection model compatible with OpenVINO IR inference.

class otx.core.model.detection.ExplainableOTXDetModel(label_info: LabelInfoTypes, input_size: tuple[int, int], optimizer: OptimizerCallable = <function _default_optimizer_callable>, scheduler: LRSchedulerCallable | LRSchedulerListCallable = <function _default_scheduler_callable>, metric: MetricCallable = <function _mean_ap_f_measure_callable>, torch_compile: bool = False, tile_config: TileConfig = TileConfig(enable_tiler=False, enable_adaptive_tiling=True, tile_size=(400, 400), overlap=0.2, iou_threshold=0.45, max_num_instances=1500, object_tile_ratio=0.03, sampling_ratio=1.0, with_full_img=False))

Bases: OTXDetectionModel

OTX detection model which can attach an XAI (Explainable AI) branch.

export_model_forward_context() Iterator[None]

A context manager that manages the model’s forward function during model export.

It temporarily modifies the model’s forward function so that it also produces output sinks for the explanation results while the model graph is traced.
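The idea of temporarily swapping the forward function can be sketched in plain Python. This is a minimal, hypothetical illustration: `ToyDetector`, `explain_head`, and `export_forward_context` are invented names, not OTX APIs, and a real implementation would act on a PyTorch module while it is being traced.

```python
from contextlib import contextmanager
from typing import Any, Iterator


class ToyDetector:
    """Hypothetical stand-in for a detection model (not an OTX class)."""

    def forward(self, x: list[float]) -> dict[str, Any]:
        # Regular forward: detection outputs only.
        return {"scores": [v * 0.5 for v in x]}

    def explain_head(self, x: list[float]) -> list[float]:
        # Hypothetical XAI branch producing one saliency value per input.
        return [abs(v) for v in x]


@contextmanager
def export_forward_context(model: ToyDetector) -> Iterator[None]:
    """Temporarily make model.forward also emit explain outputs."""
    original_forward = model.forward

    def forward_with_explain(x: list[float]) -> dict[str, Any]:
        outputs = original_forward(x)
        # Extra output "sink" so a graph tracer keeps the XAI branch.
        outputs["saliency_map"] = model.explain_head(x)
        return outputs

    model.forward = forward_with_explain  # instance attribute shadows the method
    try:
        yield
    finally:
        # Restore the original forward even if tracing raises.
        model.forward = original_forward
```

Restoring the forward function in a `finally` clause keeps the model usable for ordinary inference even if export fails partway through.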

forward_explain(inputs: DetBatchDataEntity) DetBatchPredEntity

Model forward function that also computes the explanation outputs.

get_explain_fn() Callable

Returns the explain function.

get_num_anchors() list[int]

Gets the anchor configuration (the number of anchors) from the model.
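In many anchor-based detection heads, the number of anchors generated per location at each feature level is the number of anchor scales times the number of aspect ratios. The sketch below assumes that convention. It does not reproduce how `get_num_anchors` reads the value from an OTX model, and `num_anchors_per_level` is a hypothetical helper.

```python
def num_anchors_per_level(
    scales_per_level: list[list[float]],
    ratios_per_level: list[list[float]],
) -> list[int]:
    """Anchors per location at each level = (#scales) x (#aspect ratios)."""
    if len(scales_per_level) != len(ratios_per_level):
        raise ValueError("scales and ratios must describe the same number of levels")
    return [
        len(scales) * len(ratios)
        for scales, ratios in zip(scales_per_level, ratios_per_level)
    ]
```

For instance, three octave scales combined with the ratios 0.5, 1.0 and 2.0 give 9 anchors per location, as in RetinaNet-style heads.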

class otx.core.model.detection.OTXDetectionModel(label_info: LabelInfoTypes, input_size: tuple[int, int] | None = None, optimizer: OptimizerCallable = <function _default_optimizer_callable>, scheduler: LRSchedulerCallable | LRSchedulerListCallable = <function _default_scheduler_callable>, metric: MetricCallable = <function _null_metric_callable>, torch_compile: bool = False, tile_config: TileConfig = TileConfig(enable_tiler=False, enable_adaptive_tiling=True, tile_size=(400, 400), overlap=0.2, iou_threshold=0.45, max_num_instances=1500, object_tile_ratio=0.03, sampling_ratio=1.0, with_full_img=False), train_type: Literal[OTXTrainType.SUPERVISED, OTXTrainType.SEMI_SUPERVISED] = OTXTrainType.SUPERVISED)

Bases: OTXModel[DetBatchDataEntity, DetBatchPredEntity]

Base class for the detection models used in OTX.
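When tiling is enabled through `tile_config`, large images are split into overlapping windows. The sketch below shows one way to lay out such a grid using the default `tile_size=(400, 400)` (assumed here to be height, width) and `overlap=0.2`. The stride rule and the choice to align the last tile with the image border are assumptions, not necessarily what the OTX tiler does, and `tile_grid` is a hypothetical helper.

```python
def _starts(length: int, tile: int, stride: int) -> list[int]:
    """Start offsets along one axis; the last tile is aligned with the border."""
    if length <= tile:
        return [0]
    starts = list(range(0, length - tile + 1, stride))
    if starts[-1] + tile < length:
        starts.append(length - tile)
    return starts


def tile_grid(
    image_hw: tuple[int, int],
    tile_size: tuple[int, int] = (400, 400),
    overlap: float = 0.2,
) -> list[tuple[int, int, int, int]]:
    """Return (x1, y1, x2, y2) windows that cover an image of size (H, W)."""
    img_h, img_w = image_hw
    tile_h, tile_w = tile_size
    # Neighboring tiles share `overlap` of their extent, so step by the remainder.
    stride_h = max(1, round(tile_h * (1 - overlap)))
    stride_w = max(1, round(tile_w * (1 - overlap)))
    tiles = []
    for y in _starts(img_h, tile_h, stride_h):
        for x in _starts(img_w, tile_w, stride_w):
            # Clip windows for images smaller than one tile.
            tiles.append((x, y, min(x + tile_w, img_w), min(y + tile_h, img_h)))
    return tiles
```

With these defaults the stride is 320 pixels, so a 1000 × 1000 image yields a 3 × 3 grid of tiles.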

forward_for_tracing(inputs: Tensor) list[InstanceData]

Forward function for export.

forward_tiles(inputs: OTXTileBatchDataEntity[DetBatchDataEntity]) DetBatchPredEntity

Unpack the detection tiles, run the model on them, and merge the per-tile predictions.

Parameters:

inputs (OTXTileBatchDataEntity[DetBatchDataEntity]) – Tile batch data entity.

Returns:

Merged detection prediction.

Return type:

DetBatchPredEntity
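Merging tile predictions generally means shifting each tile's boxes back into full-image coordinates and then suppressing duplicates detected in overlapping regions. The sketch below assumes class-aware greedy non-maximum suppression driven by the `iou_threshold` and `max_num_instances` fields of `TileConfig`. The plain-tuple data layout and the `merge_tile_predictions` helper are hypothetical stand-ins for the real `DetBatchPredEntity` handling.

```python
Box = tuple[float, float, float, float]
Detection = tuple[Box, float, int]  # ((x1, y1, x2, y2), score, label)


def iou(a: Box, b: Box) -> float:
    """Intersection over union of two axis-aligned boxes."""
    inter_w = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    inter_h = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = inter_w * inter_h
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def merge_tile_predictions(
    tile_preds: list[tuple[tuple[float, float], list[Detection]]],
    iou_threshold: float = 0.45,
    max_num_instances: int = 1500,
) -> list[Detection]:
    """Shift per-tile boxes to image coordinates, then run class-aware NMS."""
    detections: list[Detection] = []
    for (off_x, off_y), preds in tile_preds:
        for (x1, y1, x2, y2), score, label in preds:
            shifted = (x1 + off_x, y1 + off_y, x2 + off_x, y2 + off_y)
            detections.append((shifted, score, label))

    # Greedy NMS: keep the highest-scoring boxes and drop same-class boxes overlapping them.
    detections.sort(key=lambda det: det[1], reverse=True)
    kept: list[Detection] = []
    for box, score, label in detections:
        if all(label != k_label or iou(box, k_box) <= iou_threshold
               for k_box, _, k_label in kept):
            kept.append((box, score, label))
            if len(kept) >= max_num_instances:
                break
    return kept
```

In the usage below, the same object seen by two overlapping tiles is reported once, while a distinct object of another class survives:

```python
merged = merge_tile_predictions([
    ((0, 0), [((300, 100, 380, 180), 0.9, 0)]),
    ((320, 0), [((0, 100, 60, 180), 0.8, 0), ((10, 10, 50, 50), 0.7, 1)]),
])
```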

get_classification_layers(prefix: str = 'model.') dict[str, dict[str, int]]

Get information about the final classification layers, used for incremental learning.
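The returned mapping (`dict[str, dict[str, int]]`) is keyed by layer name and carries integer metadata for each layer. One typical use in incremental learning is to rebuild a classification layer when the label set changes: rows for previously seen classes are copied, and rows for new classes are freshly initialized. The sketch below assumes an anchor-major row layout (`row = anchor * num_classes + class_index`) and a `num_anchors` value taken from that metadata. Both the layout and the `remap_classifier_rows` helper are assumptions, not OTX's implementation.

```python
def remap_classifier_rows(
    old_rows: list[list[float]],
    old_labels: list[str],
    new_labels: list[str],
    num_anchors: int = 1,
) -> list[list[float]]:
    """Rebuild classifier rows for a new label set (anchor-major layout assumed)."""
    n_old = len(old_labels)
    if len(old_rows) != num_anchors * n_old:
        raise ValueError("expected num_anchors * len(old_labels) rows")
    dim = len(old_rows[0]) if old_rows else 0
    old_index = {name: i for i, name in enumerate(old_labels)}
    new_rows: list[list[float]] = []
    for anchor in range(num_anchors):
        for name in new_labels:
            if name in old_index:
                # Known class: copy the learned row for this anchor.
                new_rows.append(list(old_rows[anchor * n_old + old_index[name]]))
            else:
                # New class: start from zeros (a real model would use its init scheme).
                new_rows.append([0.0] * dim)
    return new_rows
```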

get_dummy_input(batch_size: int = 1) DetBatchDataEntity

Returns a dummy input for the detection model.

on_load_checkpoint(ckpt: dict[str, Any]) None

Load the state_dict from a checkpoint.

For detection, the confidence threshold information must also be updated when the metric is FMeasure.
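A minimal sketch of that behavior follows. It assumes the tuned threshold is stored in the checkpoint under a hypothetical `hyper_parameters["best_confidence_threshold"]` entry and that the metric is identified by a plain string. `ToyDetModel` is a stand-in, not the OTX class, and the real checkpoint keys may differ.

```python
from typing import Any, Optional


class ToyDetModel:
    """Hypothetical model that mirrors only the threshold-restoring idea."""

    def __init__(self, metric_name: str) -> None:
        self.metric_name = metric_name
        self.state_dict: dict[str, Any] = {}
        self.best_confidence_threshold: Optional[float] = None

    def on_load_checkpoint(self, ckpt: dict[str, Any]) -> None:
        # Load weights (here just a plain dict copy).
        self.state_dict = dict(ckpt.get("state_dict", {}))
        # Only an F-measure metric produces a tuned threshold worth restoring.
        if self.metric_name == "FMeasure":
            # Hypothetical key; real checkpoints may store this elsewhere.
            threshold = ckpt.get("hyper_parameters", {}).get("best_confidence_threshold")
            if threshold is not None:
                self.best_confidence_threshold = float(threshold)
```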

predict_step(batch: DetBatchDataEntity, batch_idx: int, dataloader_idx: int = 0) DetBatchPredEntity

Step function called during the PyTorch Lightning Trainer’s predict loop.

test_step(batch: DetBatchDataEntity, batch_idx: int) None

Perform a single test step on a batch of data from the test set.

Parameters:
  • batch – A batch of detection data (DetBatchDataEntity) containing the input images and their target annotations.

  • batch_idx – The index of the current batch.

property best_confidence_threshold: float

Best confidence threshold to filter outputs.
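An F-measure-based "best" threshold is typically found by sweeping candidate thresholds and keeping the one with the highest F1 score, F1 = 2·TP / (2·TP + FP + FN). The sketch below assumes predictions have already been matched to ground truth, so each carries a true-positive flag. The helper names are hypothetical, and OTX's `FMeasure` may choose candidates and break ties differently.

```python
Prediction = tuple[str, float]  # (prediction id, confidence score)


def best_f1_threshold(
    matched: list[tuple[float, bool]],
    num_gt: int,
    candidates: list[float],
) -> tuple[float, float]:
    """Return (threshold, F1) maximizing F1 over the candidate thresholds.

    `matched` holds (score, is_true_positive) for every prediction.
    """
    best_threshold, best_f1 = candidates[0], -1.0
    for threshold in candidates:
        flags = [is_tp for score, is_tp in matched if score >= threshold]
        tp = sum(flags)
        fp = len(flags) - tp
        fn = num_gt - tp
        denom = 2 * tp + fp + fn
        f1 = 2 * tp / denom if denom else 0.0
        # Strict ">" keeps the earliest candidate on ties.
        if f1 > best_f1:
            best_threshold, best_f1 = threshold, f1
    return best_threshold, best_f1


def filter_by_threshold(preds: list[Prediction], threshold: float) -> list[Prediction]:
    """Keep predictions whose score reaches the threshold."""
    return [p for p in preds if p[1] >= threshold]
```

A higher threshold trades recall for precision; the sweep simply picks the point where F1 balances the two best on the evaluated data.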

class otx.core.model.detection.OVDetectionModel(model_name: str, model_type: str = 'SSD', async_inference: bool = True, max_num_requests: int | None = None, use_throughput_mode: bool = True, model_api_configuration: dict[str, Any] | None = None, metric: Callable[[LabelInfo], Metric | MetricCollection] = <function _mean_ap_f_measure_callable>, **kwargs)

Bases: OVModel[DetBatchDataEntity, DetBatchPredEntity]

Object detection model compatible for OpenVINO IR inference.

It can consume either an OpenVINO IR model path or a model name from the Intel Open Model Zoo (OMZ) repository, and creates an OTX detection model compatible with the OTX testing pipeline.
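Because `model_name` may be either an IR path or an OMZ model name, the caller has to tell the two apart. A minimal sketch follows, assuming that a name ending in `.xml` denotes a local IR file (an OpenVINO IR model is an `.xml` topology file paired with a same-named `.bin` weights file) and that anything else is an OMZ name. `resolve_model_source` is hypothetical and does not describe how `OVDetectionModel` is implemented.

```python
from pathlib import Path
from typing import Optional


def resolve_model_source(model_name: str) -> tuple[str, Optional[Path]]:
    """Return ("ir", weights_path) for an IR file, else ("omz", None)."""
    path = Path(model_name)
    if path.suffix.lower() == ".xml":
        # An IR model is an .xml topology with a same-named .bin weights file.
        return "ir", path.with_suffix(".bin")
    # Anything else is treated as an Open Model Zoo model name.
    return "omz", None
```

A real implementation might also verify that both IR files exist before loading them.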