otx.algo.classification.multiclass_models

Multiclass classification models package.

Classes

EfficientNetMulticlassCls(label_info, ...)

EfficientNet Model for multi-class classification task.

MobileNetV3MulticlassCls(label_info, ...)

MobileNetV3 model for multi-class classification task.

TimmModelMulticlassCls(label_info, ...)

TimmModel for multi-class classification task.

TVModelMulticlassCls(label_info, ...)

Torchvision model for multiclass classification.

VisionTransformerMulticlassCls(label_info, ...)

DeitTiny Model for multi-class classification task.

class otx.algo.classification.multiclass_models.EfficientNetMulticlassCls(label_info: LabelInfoTypes, data_input_params: DataInputParams, model_name: str = 'efficientnet_b0', optimizer: OptimizerCallable = <function _default_optimizer_callable>, scheduler: LRSchedulerCallable | LRSchedulerListCallable = <function _default_scheduler_callable>, metric: MetricCallable = <function _multi_class_cls_metric_callable>, torch_compile: bool = False)

Bases: OTXMulticlassClsModel

EfficientNet Model for multi-class classification task.

Parameters:
  • label_info (LabelInfoTypes) – Information about the labels.

  • data_input_params (DataInputParams) – Parameters for data input.

  • model_name (str, optional) – Name of the EfficientNet model variant. Defaults to “efficientnet_b0”.

  • optimizer (OptimizerCallable, optional) – Callable for the optimizer. Defaults to DefaultOptimizerCallable.

  • scheduler (LRSchedulerCallable | LRSchedulerListCallable, optional) – Callable for the learning rate scheduler. Defaults to DefaultSchedulerCallable.

  • metric (MetricCallable, optional) – Callable for the evaluation metric. Defaults to MultiClassClsMetricCallable.

  • torch_compile (bool, optional) – Flag to indicate whether to use torch.compile. Defaults to False.
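The default metric is named only as MultiClassClsMetricCallable. As a rough illustration of what a multiclass classification metric measures, here is top-1 accuracy in plain Python: a sample counts as correct when its highest-scoring class equals its label. Whether the OTX default reports exactly this quantity (rather than, say, a macro-averaged variant) is an assumption, and top1_accuracy is a hypothetical helper; nested lists stand in for logit tensors.

```python
from typing import Sequence


def top1_accuracy(logits: Sequence[Sequence[float]], targets: Sequence[int]) -> float:
    """Fraction of samples whose highest-scoring class equals the target label."""
    if len(logits) != len(targets):
        raise ValueError("logits and targets must have the same length")
    if not targets:
        return 0.0
    correct = 0
    for scores, target in zip(logits, targets):
        # argmax over the class scores of one sample
        predicted = max(range(len(scores)), key=scores.__getitem__)
        correct += int(predicted == target)
    return correct / len(targets)


# Three samples, three classes (illustrative values).
logits = [[2.0, 0.5, 0.1], [0.2, 0.3, 1.5], [0.9, 1.1, 0.0]]
targets = [0, 2, 0]
print(top1_accuracy(logits, targets))  # 2 of the 3 predictions match
```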

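Initialize the base model with the given parameters. The list below is the base-class initializer documentation: its defaults (for example model_name “OTXModel” and metric NullMetricCallable) are base-class defaults that the class signature above overrides, and tile_config does not appear in this class's signature.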

Parameters:
  • label_info (LabelInfoTypes) – Information about the labels used in the model.

  • data_input_params (DataInputParams) – Parameters of the input data such as input size, mean, and std.

  • model_name (str, optional) – Name of the model. Defaults to “OTXModel”.

  • optimizer (OptimizerCallable, optional) – Callable for the optimizer. Defaults to DefaultOptimizerCallable.

  • scheduler (LRSchedulerCallable | LRSchedulerListCallable) – Callable for the learning rate scheduler. Defaults to DefaultSchedulerCallable.

  • metric (MetricCallable, optional) – Callable for the metric. Defaults to NullMetricCallable.

  • torch_compile (bool, optional) – Flag to indicate if torch.compile should be used. Defaults to False.

  • tile_config (TileConfig, optional) – Configuration for tiling. Defaults to TileConfig(enable_tiler=False).

Returns:

None
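data_input_params is described as carrying the input size, mean, and std. A minimal sketch of the per-channel normalization such statistics usually feed, assuming the common (x - mean) / std convention on a channels-first image; the exact preprocessing OTX applies (value scale, channel order, resizing) is not specified here, and normalize is a hypothetical helper.

```python
from typing import List, Sequence

Image = List[List[List[float]]]  # channels x height x width


def normalize(image: Image, mean: Sequence[float], std: Sequence[float]) -> Image:
    """Return (pixel - mean[c]) / std[c] for every pixel of every channel c."""
    if not (len(image) == len(mean) == len(std)):
        raise ValueError("mean and std need one entry per channel")
    return [
        [[(px - m) / s for px in row] for row in channel]
        for channel, m, s in zip(image, mean, std)
    ]


# A 3-channel, 1x2 image with hypothetical statistics.
image = [[[0.5, 1.0]], [[0.0, 0.5]], [[1.0, 1.0]]]
out = normalize(image, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
print(out)  # [[[0.0, 1.0]], [[-1.0, 0.0]], [[1.0, 1.0]]]
```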

forward_for_tracing(image: Tensor) → Tensor | dict[str, Tensor]

Forward function used to trace the model during export.
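Because forward_for_tracing may return either a single tensor or a dict of named tensors, code that consumes its result during export has to handle both shapes. A sketch of one way to normalize the two cases; flatten_outputs is a hypothetical helper (not part of OTX), the fallback name "output" and the dict keys in the usage are illustrative, and plain lists stand in for tensors so the sketch runs without PyTorch.

```python
from typing import Any, Dict, List, Tuple, Union


def flatten_outputs(result: Union[Any, Dict[str, Any]]) -> Tuple[List[str], List[Any]]:
    """Turn a single output or a name -> output mapping into (names, values)."""
    if isinstance(result, dict):
        # Keep the dict's insertion order so output names stay stable.
        names = list(result)
        return names, [result[name] for name in names]
    # A single output gets one illustrative default name.
    return ["output"], [result]


print(flatten_outputs({"logits": [1, 2], "feature_vector": [3]}))
# (['logits', 'feature_vector'], [[1, 2], [3]])
print(flatten_outputs([0.1, 0.9]))  # (['output'], [[0.1, 0.9]])
```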

load_from_otx_v1_ckpt(state_dict: dict, add_prefix: str = 'model.') → dict

Load a checkpoint from a previous OTX version (1.x) and convert it to the OTX 2.0 format.
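The add_prefix argument (default 'model.') suggests that part of the conversion re-keys the old state dict so its entries line up with the OTX 2.0 module layout. A minimal sketch of only that prefixing step, assuming keys are plain strings; the real method may also rename or drop keys, which is not shown. add_key_prefix is a hypothetical name and the example keys are illustrative.

```python
from typing import Any, Dict


def add_key_prefix(state_dict: Dict[str, Any], add_prefix: str = "model.") -> Dict[str, Any]:
    """Return a new state dict with add_prefix prepended to every key that lacks it."""
    return {
        (key if key.startswith(add_prefix) else add_prefix + key): value
        for key, value in state_dict.items()
    }


old = {"backbone.conv1.weight": 1, "head.fc.bias": 2}
new = add_key_prefix(old)
print(sorted(new))  # ['model.backbone.conv1.weight', 'model.head.fc.bias']
```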

class otx.algo.classification.multiclass_models.MobileNetV3MulticlassCls(label_info: LabelInfoTypes, data_input_params: DataInputParams, model_name: str = 'mobilenetv3_large', optimizer: OptimizerCallable = <function _default_optimizer_callable>, scheduler: LRSchedulerCallable | LRSchedulerListCallable = <function _default_scheduler_callable>, metric: MetricCallable = <function _multi_class_cls_metric_callable>, torch_compile: bool = False)

Bases: OTXMulticlassClsModel

MobileNetV3 model for multi-class classification task.

Parameters:
  • label_info (LabelInfoTypes) – The label information.

  • data_input_params (DataInputParams) – The data input parameters such as input size and normalization.

  • model_name (str, optional) – The model name. Defaults to “mobilenetv3_large”.

  • optimizer (OptimizerCallable, optional) – The optimizer callable. Defaults to DefaultOptimizerCallable.

  • scheduler (LRSchedulerCallable | LRSchedulerListCallable, optional) – The learning rate scheduler callable. Defaults to DefaultSchedulerCallable.

  • metric (MetricCallable, optional) – The metric callable. Defaults to MultiClassClsMetricCallable.

  • torch_compile (bool, optional) – Whether to compile the model with torch.compile. Defaults to False.
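The optimizer argument is a callable rather than an optimizer instance, because an optimizer needs the model's parameters, which exist only after the model is built. A stdlib sketch of that factory pattern, assuming the callable is invoked with the model's parameters (as with Lightning's OptimizerCallable); ToySGD and build_optimizer are hypothetical stand-ins, and floats stand in for parameter tensors.

```python
from functools import partial
from typing import Callable, Iterable, List


class ToySGD:
    """Hypothetical stand-in for an optimizer class such as torch.optim.SGD."""

    def __init__(self, params: Iterable[float], lr: float, momentum: float = 0.0) -> None:
        self.params: List[float] = list(params)
        self.lr = lr
        self.momentum = momentum


OptimizerFactory = Callable[[Iterable[float]], ToySGD]


def build_optimizer(params: Iterable[float], factory: OptimizerFactory) -> ToySGD:
    # The model owns the parameters, so it calls the factory once they exist.
    return factory(params)


# Hyperparameters are bound up front; parameters are supplied later.
factory = partial(ToySGD, lr=0.01, momentum=0.9)
optimizer = build_optimizer([0.1, -0.2, 0.3], factory)
print(optimizer.lr, optimizer.momentum, len(optimizer.params))  # 0.01 0.9 3
```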

forward_for_tracing(image: Tensor) → Tensor | dict[str, Tensor]

Forward function used to trace the model during export.

load_from_otx_v1_ckpt(state_dict: dict, add_prefix: str = 'model.') → dict

Load a checkpoint from a previous OTX version (1.x) and convert it to the OTX 2.0 format.

class otx.algo.classification.multiclass_models.TVModelMulticlassCls(label_info: LabelInfoTypes, data_input_params: DataInputParams, model_name: str = 'efficientnet_v2_s', optimizer: OptimizerCallable = <function _default_optimizer_callable>, scheduler: LRSchedulerCallable | LRSchedulerListCallable = <function _default_scheduler_callable>, metric: MetricCallable = <function _multi_class_cls_metric_callable>, torch_compile: bool = False)

Bases: OTXMulticlassClsModel

Torchvision model for multiclass classification.

Parameters:
  • label_info (LabelInfoTypes) – Information about the labels.

  • data_input_params (DataInputParams) – Data input parameters such as input size and normalization.

  • model_name (str, optional) – Backbone model name for feature extraction. Defaults to “efficientnet_v2_s”.

  • optimizer (OptimizerCallable, optional) – Optimizer for model training. Defaults to DefaultOptimizerCallable.

  • scheduler (LRSchedulerCallable | LRSchedulerListCallable, optional) – Learning rate scheduler. Defaults to DefaultSchedulerCallable.

  • metric (MetricCallable, optional) – Metric for model evaluation. Defaults to MultiClassClsMetricCallable.

  • torch_compile (bool, optional) – Whether to compile the model with torch.compile. Defaults to False.
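The scheduler argument also accepts an LRSchedulerListCallable, which allows several schedules to be combined. Purely as an illustration of such a combination (the actual OTX default is not specified here), a sketch of linear warmup followed by step decay, expressed as the learning rate at a given step; lr_at and all values are hypothetical.

```python
def lr_at(step: int, base_lr: float, warmup_steps: int, decay_every: int, gamma: float) -> float:
    """Linear warmup up to base_lr, then multiply by gamma every decay_every steps."""
    if warmup_steps > 0 and step < warmup_steps:
        # Warmup phase: ramp linearly from base_lr / warmup_steps to base_lr.
        return base_lr * (step + 1) / warmup_steps
    # Decay phase: count completed decay periods after warmup.
    decays = (step - warmup_steps) // decay_every
    return base_lr * gamma**decays


print([round(lr_at(s, 0.1, 4, 10, 0.5), 4) for s in (0, 3, 4, 14)])  # [0.025, 0.1, 0.1, 0.05]
```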

forward_for_tracing(image: Tensor) → Tensor | dict[str, Tensor]

Forward function used to trace the model during export.

class otx.algo.classification.multiclass_models.TimmModelMulticlassCls(label_info: LabelInfoTypes, data_input_params: DataInputParams, model_name: str, optimizer: OptimizerCallable = <function _default_optimizer_callable>, scheduler: LRSchedulerCallable | LRSchedulerListCallable = <function _default_scheduler_callable>, metric: MetricCallable = <function _multi_class_cls_metric_callable>, torch_compile: bool = False)

Bases: OTXMulticlassClsModel

TimmModel for multi-class classification task.

Parameters:
  • label_info (LabelInfoTypes) – The label information for the classification task.

  • data_input_params (DataInputParams) – Data input parameters such as input size and normalization.

  • model_name (str) – The name of the timm model. Available names can be listed with timm.list_models() or timm.list_pretrained().

  • optimizer (OptimizerCallable, optional) – The optimizer callable for training the model. Defaults to DefaultOptimizerCallable.

  • scheduler (LRSchedulerCallable | LRSchedulerListCallable, optional) – The learning rate scheduler callable. Defaults to DefaultSchedulerCallable.

  • metric (MetricCallable, optional) – The metric callable for evaluating the model. Defaults to MultiClassClsMetricCallable.

  • torch_compile (bool, optional) – Whether to compile the model with torch.compile. Defaults to False.
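timm model names have the form architecture.pretrained_tag (for example tf_efficientnetv2_s.in21k), and timm.list_models() accepts a shell-style wildcard filter. A stdlib sketch of both ideas over a hard-coded list; the names are illustrative (query timm itself for the current list), and split_name and filter_names are hypothetical helpers, not timm functions.

```python
import fnmatch
from typing import List, Tuple

# Illustrative names only.
NAMES = [
    "tf_efficientnetv2_s.in21k",
    "tf_efficientnetv2_s.in21k_ft_in1k",
    "tf_efficientnetv2_m.in21k",
    "resnet50.a1_in1k",
]


def split_name(name: str) -> Tuple[str, str]:
    """Split 'architecture.pretrained_tag' at the first dot (tag may be empty)."""
    arch, _, tag = name.partition(".")
    return arch, tag


def filter_names(names: List[str], pattern: str) -> List[str]:
    """Case-sensitive shell-style wildcard filtering over model names."""
    return [name for name in names if fnmatch.fnmatchcase(name, pattern)]


print(filter_names(NAMES, "tf_efficientnetv2_s*"))
# ['tf_efficientnetv2_s.in21k', 'tf_efficientnetv2_s.in21k_ft_in1k']
print(split_name("tf_efficientnetv2_s.in21k"))  # ('tf_efficientnetv2_s', 'in21k')
```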

Example

  1. API
    >>> model = TimmModelMulticlassCls(
    ...     model_name="tf_efficientnetv2_s.in21k",
    ...     label_info=<Number-of-classes>,
    ...     data_input_params=<DataInputParams>,
    ... )

  2. CLI
    >>> otx train \
    ...     --model otx.algo.classification.multiclass_models.TimmModelMulticlassCls \
    ...     --model.model_name tf_efficientnetv2_s.in21k

forward_for_tracing(image: Tensor) → Tensor | dict[str, Tensor]

Forward function used to trace the model during export.

load_from_otx_v1_ckpt(state_dict: dict, add_prefix: str = 'model.') → dict

Load a checkpoint from a previous OTX version (1.x) and convert it to the OTX 2.0 format.

class otx.algo.classification.multiclass_models.VisionTransformerMulticlassCls(label_info: LabelInfoTypes, data_input_params: DataInputParams, model_name: str = 'vit-tiny', lora: bool = False, optimizer: OptimizerCallable = <function _default_optimizer_callable>, scheduler: LRSchedulerCallable | LRSchedulerListCallable = <function _default_scheduler_callable>, metric: MetricCallable = <function _multi_class_cls_metric_callable>, torch_compile: bool = False)

Bases: ForwardExplainMixInForViT, OTXMulticlassClsModel

DeitTiny Model for multi-class classification task.
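The signature exposes a lora flag. LoRA (low-rank adaptation) keeps a pretrained weight W frozen and trains only a low-rank update B·A, scaled by alpha / rank, so the effective weight is W + (alpha / rank)·B·A. A pure-Python sketch of that update under the standard formulation; which ViT layers OTX adapts and which rank and alpha it uses are not documented here, and all names and values are hypothetical.

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain nested-list matrix product."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def lora_weight(w: Matrix, a: Matrix, b: Matrix, alpha: float, rank: int) -> Matrix:
    """Effective weight W + (alpha / rank) * B @ A; W stays frozen, only A and B train."""
    delta = matmul(b, a)  # (out x rank) @ (rank x in) -> (out x in)
    scale = alpha / rank
    return [
        [wij + scale * dij for wij, dij in zip(w_row, d_row)]
        for w_row, d_row in zip(w, delta)
    ]


w = [[1.0, 0.0], [0.0, 1.0]]  # frozen 2x2 weight
b = [[1.0], [0.0]]            # out x rank (rank = 1)
a = [[0.5, 0.5]]              # rank x in
print(lora_weight(w, a, b, alpha=2.0, rank=1))  # [[2.0, 1.0], [0.0, 1.0]]
```

With B initialized to zeros, the effective weight equals W, so training starts from the pretrained model.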

load_from_otx_v1_ckpt(state_dict: dict, add_prefix: str = 'model.') → dict

Load a checkpoint from a previous OTX version (1.x) and convert it to the OTX 2.0 format.