otx.core.model.classification

Class definition for classification model entity used in OTX.

Classes

OTXHlabelClsModel(label_info, input_size, ...)

Hierarchical-label (H-label) classification models used in OTX.

OTXMulticlassClsModel(label_info, ...)

Base class for the multi-class classification models used in OTX.

OTXMultilabelClsModel(label_info, ...)

Multi-label classification models used in OTX.

OVHlabelClassificationModel(model_name, ...)

Hierarchical classification model compatible with OpenVINO IR inference.

OVMulticlassClassificationModel(model_name, ...)

Multi-class classification model compatible with OpenVINO IR inference.

OVMultilabelClassificationModel(model_name, ...)

Multi-label classification model compatible with OpenVINO IR inference.

class otx.core.model.classification.OTXHlabelClsModel(label_info: HLabelInfo, input_size: tuple[int, int], optimizer: OptimizerCallable = <function _default_optimizer_callable>, scheduler: LRSchedulerCallable | LRSchedulerListCallable = <function _default_scheduler_callable>, metric: MetricCallable = <function _mixed_hlabel_accuracy>, torch_compile: bool = False)

Bases: OTXModel[HlabelClsBatchDataEntity, HlabelClsBatchPredEntity]

Hierarchical-label (H-label) classification models used in OTX.

forward_for_tracing(image: Tensor) → Tensor | dict[str, Tensor]

Model forward function used to trace the model during export.

get_dummy_input(batch_size: int = 1) → HlabelClsBatchDataEntity

Returns a dummy input for the classification model.
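To make the idea of H-label classification more concrete, here is a minimal sketch of how hierarchical predictions could be decoded. It assumes the model emits one flat logit vector in which each mutually exclusive (multi-class) label group occupies a contiguous slice, followed by independent multi-label logits at the end. `decode_hlabel`, its parameters, and the 0.5 threshold are hypothetical; this is not the `HLabelInfo` API or OTX's own decoding code.

```python
import math


def sigmoid(z: float) -> float:
    # Numerically stable logistic function (avoids overflow for large |z|).
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def decode_hlabel(
    logits: list[float],
    group_ranges: list[tuple[int, int]],
    num_multilabel: int,
    threshold: float = 0.5,
) -> tuple[list[int], list[int]]:
    """Return (winning index per exclusive group, indices of active multi-label classes)."""
    # Hypothetical layout: each exclusive group is a [start, end) slice of `logits`.
    group_preds = []
    for start, end in group_ranges:
        group = logits[start:end]
        # Pick the top-scoring class within this group only.
        group_preds.append(max(range(len(group)), key=group.__getitem__))

    # Assumed: the last `num_multilabel` logits are scored independently (empty when 0).
    ml_logits = logits[len(logits) - num_multilabel:]
    active = [i for i, z in enumerate(ml_logits) if sigmoid(z) > threshold]
    return group_preds, active


logits = [2.0, 0.1, -1.0, 0.3, 1.5, 3.0, -2.0]
print(decode_hlabel(logits, group_ranges=[(0, 3), (3, 5)], num_multilabel=2))
# ([0, 1], [0])
```

Each exclusive group yields exactly one class, while the multi-label tail can yield any number of classes, including none.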

class otx.core.model.classification.OTXMulticlassClsModel(label_info: LabelInfoTypes, input_size: tuple[int, int], optimizer: OptimizerCallable = <function _default_optimizer_callable>, scheduler: LRSchedulerCallable | LRSchedulerListCallable = <function _default_scheduler_callable>, metric: MetricCallable = <function _multi_class_cls_metric_callable>, torch_compile: bool = False, train_type: Literal[OTXTrainType.SUPERVISED, OTXTrainType.SEMI_SUPERVISED] = OTXTrainType.SUPERVISED)

Bases: OTXModel[MulticlassClsBatchDataEntity, MulticlassClsBatchPredEntity]

Base class for the multi-class classification models used in OTX.

forward_for_tracing(image: Tensor) → Tensor | dict[str, Tensor]

Model forward function used to trace the model during export.

get_dummy_input(batch_size: int = 1) → MulticlassClsBatchDataEntity

Returns a dummy input for the classification model.
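A dummy input of this kind is typically what gets fed through a model when it is traced for export. The sketch below shows how `batch_size` and the model's `input_size` could determine the shape of such a batch, assuming `input_size` is `(height, width)`, images are 3-channel and channel-first, and every sample gets a placeholder label of 0. `DummyBatch` and `make_dummy_batch` are hypothetical stand-ins for the real `MulticlassClsBatchDataEntity`; no tensors are created here.

```python
from dataclasses import dataclass


@dataclass
class DummyBatch:
    # Hypothetical stand-in for MulticlassClsBatchDataEntity.
    batch_size: int
    image_shapes: list[tuple[int, int, int]]  # one (C, H, W) shape per sample
    labels: list[int]


def make_dummy_batch(input_size: tuple[int, int], batch_size: int = 1) -> DummyBatch:
    height, width = input_size  # assumed (height, width) order
    # Every sample gets a 3-channel placeholder image sized from input_size.
    shapes = [(3, height, width) for _ in range(batch_size)]
    # Placeholder labels; presumably only the structure matters for tracing.
    return DummyBatch(batch_size=batch_size, image_shapes=shapes, labels=[0] * batch_size)


print(make_dummy_batch((224, 224), batch_size=2).image_shapes)
# [(3, 224, 224), (3, 224, 224)]
```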

training_step(batch: MulticlassClsBatchDataEntity, batch_idx: int) → Tensor

Performs a single training step on a batch of data.

class otx.core.model.classification.OTXMultilabelClsModel(label_info: LabelInfoTypes, input_size: tuple[int, int], optimizer: OptimizerCallable = <function _default_optimizer_callable>, scheduler: LRSchedulerCallable | LRSchedulerListCallable = <function _default_scheduler_callable>, metric: MetricCallable = <function _multi_label_cls_metric_callable>, torch_compile: bool = False)

Bases: OTXModel[MultilabelClsBatchDataEntity, MultilabelClsBatchPredEntity]

Multi-label classification models used in OTX.

forward_for_tracing(image: Tensor) → Tensor | dict[str, Tensor]

Model forward function used to trace the model during export.

get_dummy_input(batch_size: int = 1) → MultilabelClsBatchDataEntity

Returns a dummy input for the classification model.
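The practical difference between the multi-class and multi-label models lies in how scores become labels. A common scheme, assumed here rather than taken from OTX, is a softmax followed by argmax for multi-class (exactly one label per image) and an independent sigmoid with a threshold for multi-label (zero or more labels per image). The helper names and the 0.5 threshold are hypothetical.

```python
import math


def sigmoid(z: float) -> float:
    # Numerically stable logistic function.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def softmax(logits: list[float]) -> list[float]:
    # Subtract the max before exponentiating for numerical stability.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def decode_multiclass(logits: list[float]) -> int:
    # Exactly one label: the class with the highest probability.
    probs = softmax(logits)
    return max(range(len(probs)), key=probs.__getitem__)


def decode_multilabel(logits: list[float], threshold: float = 0.5) -> list[int]:
    # Zero or more labels: every class whose own probability clears the threshold.
    return [i for i, z in enumerate(logits) if sigmoid(z) > threshold]


scores = [1.2, -0.4, 0.8]
print(decode_multiclass(scores))  # 0
print(decode_multilabel(scores))  # [0, 2]
```

The same logits give one label under the multi-class rule but two under the multi-label rule, because the sigmoid scores each class independently instead of making classes compete.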

class otx.core.model.classification.OVHlabelClassificationModel(model_name: str, model_type: str = 'Classification', async_inference: bool = True, max_num_requests: int | None = None, use_throughput_mode: bool = True, model_api_configuration: dict[str, Any] | None = None, metric: MetricCallable = <function _mixed_hlabel_accuracy>, **kwargs)

Bases: OVModel[HlabelClsBatchDataEntity, HlabelClsBatchPredEntity]

Hierarchical classification model compatible with OpenVINO IR inference.

It can consume an OpenVINO IR model path or a model name from the Intel Open Model Zoo (OMZ) repository and create an OTX classification model compatible with the OTX testing pipeline.
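A minimal sketch of how the two kinds of `model_name` input could be told apart. It assumes an IR model is passed as the path to its `.xml` topology file (OpenVINO IR keeps the weights in a `.bin` file with the same stem) and treats any other string as a model name to look up, for example in Open Model Zoo. `resolve_model_source` is hypothetical and does not reproduce OTX's own resolution logic.

```python
from __future__ import annotations

from pathlib import Path


def resolve_model_source(model_name: str) -> tuple[str, Path | None]:
    """Classify `model_name` as an IR path or a model name (illustrative only)."""
    path = Path(model_name)
    if path.suffix.lower() == ".xml":
        # IR topology file; the weights are expected next to it with a .bin suffix.
        return "ir_path", path.with_suffix(".bin")
    # Anything else is treated as a name to look up, e.g. in Open Model Zoo.
    return "model_name", None


print(resolve_model_source("exported/openvino.xml"))
print(resolve_model_source("hypothetical-omz-model"))
```

A suffix check keeps the sketch free of filesystem access; a real loader would also need to confirm that both files exist.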

class otx.core.model.classification.OVMulticlassClassificationModel(model_name: str, model_type: str = 'Classification', async_inference: bool = True, max_num_requests: int | None = None, use_throughput_mode: bool = False, model_api_configuration: dict[str, Any] | None = None, metric: MetricCallable = <function _multi_class_cls_metric_callable>, **kwargs)

Bases: OVModel[MulticlassClsBatchDataEntity, MulticlassClsBatchPredEntity]

Multi-class classification model compatible with OpenVINO IR inference.

It can consume an OpenVINO IR model path or a model name from the Intel Open Model Zoo (OMZ) repository and create an OTX classification model compatible with the OTX testing pipeline.
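The `async_inference` and `max_num_requests` parameters suggest that inference requests can run concurrently with a bounded number in flight. The sketch below is only an analogy under that assumption: a thread pool stands in for OpenVINO infer requests, `None` lets the pool choose its own default size, and results are returned in input order. `run_inference` is hypothetical and is not the code path OTX or OpenVINO uses.

```python
from __future__ import annotations

from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Sequence, TypeVar

T = TypeVar("T")
R = TypeVar("R")


def run_inference(
    infer: Callable[[T], R],
    inputs: Sequence[T],
    async_inference: bool = True,
    max_num_requests: int | None = None,
) -> list[R]:
    """Run `infer` over `inputs`, optionally with bounded concurrency (analogy only)."""
    if not async_inference:
        # Synchronous mode: one request at a time, in order.
        return [infer(x) for x in inputs]
    # Asynchronous mode: at most `max_num_requests` calls in flight;
    # None lets ThreadPoolExecutor pick its default worker count.
    with ThreadPoolExecutor(max_workers=max_num_requests) as pool:
        # Executor.map yields results in submission order, not completion order.
        return list(pool.map(infer, inputs))


print(run_inference(lambda x: x * x, [1, 2, 3], max_num_requests=2))  # [1, 4, 9]
```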

class otx.core.model.classification.OVMultilabelClassificationModel(model_name: str, model_type: str = 'Classification', async_inference: bool = True, max_num_requests: int | None = None, use_throughput_mode: bool = True, model_api_configuration: dict[str, Any] | None = None, metric: MetricCallable = <function _multi_label_cls_metric_callable>, **kwargs)

Bases: OVModel[MultilabelClsBatchDataEntity, MultilabelClsBatchPredEntity]

Multi-label classification model compatible with OpenVINO IR inference.

It can consume an OpenVINO IR model path or a model name from the Intel Open Model Zoo (OMZ) repository and create an OTX classification model compatible with the OTX testing pipeline.