otx.algo.classification.torchvision_model

Torchvision models for OTX classification.

Classes

TVModelForHLabelCls(label_info, backbone, ...)

Torchvision model for hierarchical label classification.

TVModelForMulticlassCls(label_info, ...)

Torchvision model for multiclass classification.

TVModelForMultilabelCls(label_info, ...)

Torchvision model for multilabel classification.

class otx.algo.classification.torchvision_model.TVModelForHLabelCls(label_info: HLabelInfo, backbone: TVModelType, pretrained: bool = True, optimizer: OptimizerCallable = <function _default_optimizer_callable>, scheduler: LRSchedulerCallable | LRSchedulerListCallable = <function _default_scheduler_callable>, metric: MetricCallable = <function _mixed_hlabel_accuracy>, torch_compile: bool = False, input_size: tuple[int, int] = (224, 224))

Bases: OTXHlabelClsModel

Torchvision model for hierarchical label (h-label) classification.

Parameters:
  • label_info (HLabelInfo) – Information about the hierarchical labels.

  • backbone (TVModelType) – The type of Torchvision backbone model.

  • pretrained (bool, optional) – Whether to use pretrained weights. Defaults to True.

  • optimizer (OptimizerCallable, optional) – The optimizer callable. Defaults to _default_optimizer_callable.

  • scheduler (LRSchedulerCallable | LRSchedulerListCallable, optional) – The learning rate scheduler callable. Defaults to _default_scheduler_callable.

  • metric (MetricCallable, optional) – The metric callable. Defaults to _mixed_hlabel_accuracy.

  • torch_compile (bool, optional) – Whether to compile the model with torch.compile. Defaults to False.

  • input_size (tuple[int, int], optional) – The input size of the images. Defaults to (224, 224).
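Hierarchical label classification mixes mutually exclusive label groups (each decoded like a multiclass problem) with independent labels (each decoded like a multilabel problem). The following is a minimal plain-Python sketch of that decoding step on one flat logit vector, assuming a layout in which each exclusive group occupies a contiguous slice and the independent labels come last. The helper decode_hlabel_logits and its head_ranges argument are hypothetical illustrations, not the code TVModelForHLabelCls actually runs.

```python
import math


def decode_hlabel_logits(
    logits: list[float],
    head_ranges: list[tuple[int, int]],
    num_multilabel: int,
    threshold: float = 0.5,
) -> tuple[list[int], list[int]]:
    """Decode a flat h-label logit vector (illustrative layout, not OTX's).

    head_ranges: hypothetical [start, end) slices, one per mutually exclusive group.
    The last num_multilabel logits are treated as independent binary labels.
    """
    # Exclusive groups: pick the highest logit inside each slice.
    # Softmax is monotonic, so argmax over logits equals argmax over probabilities.
    multiclass_preds = []
    for start, end in head_ranges:
        group = logits[start:end]
        multiclass_preds.append(max(range(len(group)), key=group.__getitem__))

    # Independent labels: apply a sigmoid to each logit and compare with the threshold.
    multilabel_logits = logits[len(logits) - num_multilabel:]
    multilabel_preds = [
        int(1.0 / (1.0 + math.exp(-x)) >= threshold) for x in multilabel_logits
    ]
    return multiclass_preds, multilabel_preds


# Two exclusive groups (3 and 2 classes) followed by 2 independent labels.
print(decode_hlabel_logits([2.0, 0.5, -1.0, 0.1, 3.0, 1.2, -0.7], [(0, 3), (3, 5)], 2))
# → ([0, 1], [1, 0])
```

Each group index is relative to its own slice, which is why a per-group label list would be needed to map it back to a label name.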

backbone

The type of Torchvision backbone model.

Type:

TVModelType

pretrained

Whether to use pretrained weights.

Type:

bool

classification_layers

The classification layers for class-incremental learning.

Type:

nn.Module

forward_explain(inputs: HlabelClsBatchDataEntity) → HlabelClsBatchPredEntity

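Runs the forward pass in explain mode and returns an HlabelClsBatchPredEntity holding the predictions together with explanation outputs such as saliency maps.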

forward_for_tracing(image: Tensor) → Tensor | dict[str, Tensor]

Forward function used to trace the model during export. Returns either a single tensor or a dictionary of named tensors.
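Because forward_for_tracing may return either one tensor or a dictionary of named tensors, code that consumes traced outputs has to handle both shapes. A minimal plain-Python sketch of one way to treat them uniformly; the helper as_output_dict and the key names "logits" and "saliency_map" are hypothetical and not part of OTX, and plain lists stand in for tensors so the sketch runs without PyTorch.

```python
from collections.abc import Mapping
from typing import Any


def as_output_dict(output: Any, default_key: str = "logits") -> dict[str, Any]:
    """Normalize a traced output: mappings pass through, a single output gets wrapped.

    default_key is a hypothetical name chosen for the single-tensor case.
    """
    if isinstance(output, Mapping):
        return dict(output)  # shallow copy so callers cannot mutate the original
    return {default_key: output}


print(as_output_dict([0.1, 0.9]))
print(as_output_dict({"logits": [0.1, 0.9], "saliency_map": [[0.0]]}))
```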

class otx.algo.classification.torchvision_model.TVModelForMulticlassCls(label_info: LabelInfoTypes, backbone: TVModelType, pretrained: bool = True, optimizer: OptimizerCallable = <function _default_optimizer_callable>, scheduler: LRSchedulerCallable | LRSchedulerListCallable = <function _default_scheduler_callable>, metric: MetricCallable = <function _multi_class_cls_metric_callable>, torch_compile: bool = False, train_type: Literal[OTXTrainType.SUPERVISED, OTXTrainType.SEMI_SUPERVISED] = OTXTrainType.SUPERVISED, input_size: tuple[int, int] = (224, 224))

Bases: OTXMulticlassClsModel

Torchvision model for multiclass classification.

Parameters:
  • label_info (LabelInfoTypes) – Information about the labels.

  • backbone (TVModelType) – The type of Torchvision backbone model used for feature extraction.

  • pretrained (bool, optional) – Whether to use pretrained weights. Defaults to True.

  • optimizer (OptimizerCallable, optional) – Optimizer for model training. Defaults to _default_optimizer_callable.

  • scheduler (LRSchedulerCallable | LRSchedulerListCallable, optional) – Learning rate scheduler. Defaults to _default_scheduler_callable.

  • metric (MetricCallable, optional) – Metric for model evaluation. Defaults to _multi_class_cls_metric_callable.

  • torch_compile (bool, optional) – Whether to compile the model with torch.compile. Defaults to False.

  • train_type (Literal[OTXTrainType.SUPERVISED, OTXTrainType.SEMI_SUPERVISED], optional) – Type of training. Defaults to OTXTrainType.SUPERVISED.

  • input_size (tuple[int, int], optional) – Input size of the images. Defaults to (224, 224).
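In multiclass classification exactly one class is predicted per image: the logits are turned into probabilities with a softmax and the highest one wins. A minimal plain-Python sketch of that standard decoding rule for one sample; predict_multiclass and the label names are hypothetical and do not reproduce the internals of TVModelForMulticlassCls.

```python
import math


def predict_multiclass(logits: list[float]) -> tuple[int, float]:
    """Return (class index, softmax confidence) for a single sample's logits."""
    m = max(logits)  # subtract the max before exponentiating for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    best = max(range(len(probs)), key=probs.__getitem__)
    return best, probs[best]


label_names = ["cat", "dog", "bird"]  # hypothetical label set
idx, conf = predict_multiclass([1.0, 3.0, 0.5])
print(label_names[idx], round(conf, 3))  # dog 0.821
```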

backbone

The type of Torchvision backbone model used for feature extraction.

Type:

TVModelType

pretrained

Whether to use pretrained weights.

Type:

bool

classification_layers

Classification layers for class-incremental learning.

Type:

nn.ModuleDict

forward_explain(inputs: MulticlassClsBatchDataEntity) → MulticlassClsBatchPredEntity

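Runs the forward pass in explain mode and returns a MulticlassClsBatchPredEntity holding the predictions together with explanation outputs such as saliency maps.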

forward_for_tracing(image: Tensor) → Tensor | dict[str, Tensor]

Forward function used to trace the model during export. Returns either a single tensor or a dictionary of named tensors.

class otx.algo.classification.torchvision_model.TVModelForMultilabelCls(label_info: LabelInfoTypes, backbone: TVModelType, pretrained: bool = True, optimizer: OptimizerCallable = <function _default_optimizer_callable>, scheduler: LRSchedulerCallable | LRSchedulerListCallable = <function _default_scheduler_callable>, metric: MetricCallable = <function _multi_label_cls_metric_callable>, torch_compile: bool = False, input_size: tuple[int, int] = (224, 224))

Bases: OTXMultilabelClsModel

Torchvision model for multilabel classification.

Parameters:
  • label_info (LabelInfoTypes) – Information about the labels.

  • backbone (TVModelType) – The type of Torchvision backbone model used for feature extraction.

  • pretrained (bool, optional) – Whether to use pretrained weights. Defaults to True.

  • optimizer (OptimizerCallable, optional) – Optimizer for model training. Defaults to _default_optimizer_callable.

  • scheduler (LRSchedulerCallable | LRSchedulerListCallable, optional) – Learning rate scheduler. Defaults to _default_scheduler_callable.

  • metric (MetricCallable, optional) – Metric for model evaluation. Defaults to _multi_label_cls_metric_callable.

  • torch_compile (bool, optional) – Whether to compile the model with torch.compile. Defaults to False.

  • input_size (tuple[int, int], optional) – Input size of the images. Defaults to (224, 224).
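In multilabel classification any number of labels can be active at once, so each logit is scored independently with a sigmoid and compared with a threshold rather than competing in a softmax. A minimal plain-Python sketch of that standard rule; predict_multilabel and the 0.5 default threshold are illustrative assumptions, not the decoding code of TVModelForMultilabelCls.

```python
import math


def predict_multilabel(logits: list[float], threshold: float = 0.5) -> list[int]:
    """Return the indices of every label whose sigmoid score reaches threshold."""
    scores = [1.0 / (1.0 + math.exp(-x)) for x in logits]
    return [i for i, s in enumerate(scores) if s >= threshold]


print(predict_multilabel([2.0, -1.0, 0.3]))                 # [0, 2]
print(predict_multilabel([2.0, -1.0, 0.3], threshold=0.6))  # [0]
```

Raising the threshold trades recall for precision: the third label (sigmoid score of about 0.574) drops out at 0.6.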

backbone

The type of Torchvision backbone model used for feature extraction.

Type:

TVModelType

pretrained

Whether to use pretrained weights.

Type:

bool

input_size

Input size of the images.

Type:

tuple[int, int]

forward_explain(inputs: MultilabelClsBatchDataEntity) → MultilabelClsBatchPredEntity

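Runs the forward pass in explain mode and returns a MultilabelClsBatchPredEntity holding the predictions together with explanation outputs such as saliency maps.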

forward_for_tracing(image: Tensor) → Tensor | dict[str, Tensor]

Forward function used to trace the model during export. Returns either a single tensor or a dictionary of named tensors.