otx.algo.classification.mobilenet_v3

MobileNetV3 model implementation.

Classes

MobileNetV3ForHLabelCls(label_info, mode, ...)

MobileNetV3 Model for hierarchical label classification task.

MobileNetV3ForMulticlassCls(label_info, ...)

MobileNetV3 Model for multi-class classification task.

MobileNetV3ForMultilabelCls(label_info, ...)

MobileNetV3 Model for multi-label classification task.

class otx.algo.classification.mobilenet_v3.MobileNetV3ForHLabelCls(label_info: HLabelInfo, mode: Literal['large', 'small'] = 'large', optimizer: OptimizerCallable = <function _default_optimizer_callable>, scheduler: LRSchedulerCallable | LRSchedulerListCallable = <function _default_scheduler_callable>, metric: MetricCallable = <function _mixed_hlabel_accuracy>, torch_compile: bool = False, input_size: tuple[int, int] = (224, 224))

Bases: OTXHlabelClsModel

MobileNetV3 Model for hierarchical label classification task.
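Hierarchical label classification typically organizes labels into groups of mutually exclusive labels and predicts one label per group, possibly alongside independent multi-label outputs (omitted here). The pure-Python sketch below is a generic illustration, not OTX code: it decodes one flat logit vector by taking the arg-max inside each group. The function name `decode_exclusive_groups`, the group layout, and the label names are hypothetical.

```python
# Generic sketch (not the OTX implementation): decode grouped, mutually
# exclusive predictions from a single flat logit vector.
# Group names, label names, and index ranges are hypothetical examples.


def decode_exclusive_groups(
    logits: list[float],
    groups: dict[str, tuple[int, int]],
    labels: list[str],
) -> dict[str, str]:
    """Pick the highest-scoring label inside each mutually exclusive group.

    ``groups`` maps a group name to a non-empty, half-open ``(start, end)``
    range of indices into both ``logits`` and ``labels``.
    """
    predictions = {}
    for group_name, (start, end) in groups.items():
        # Arg-max restricted to this group's slice of the logit vector.
        best = max(range(start, end), key=lambda i: logits[i])
        predictions[group_name] = labels[best]
    return predictions


if __name__ == "__main__":
    labels = ["cat", "dog", "indoor", "outdoor", "day", "night"]
    groups = {"animal": (0, 2), "scene": (2, 4), "time": (4, 6)}
    logits = [0.2, 1.5, 2.0, -1.0, -0.3, 0.4]
    print(decode_exclusive_groups(logits, groups, labels))
```

Restricting the arg-max to each group's slice is what keeps labels from different groups from competing with each other.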

forward_explain(inputs: HlabelClsBatchDataEntity) → HlabelClsBatchPredEntity

Model forward function for explanation (XAI): returns the predictions together with explanation outputs such as saliency maps.

forward_for_tracing(image: Tensor) → Tensor | dict[str, Tensor]

Model forward function used to trace the model during export.

load_from_otx_v1_ckpt(state_dict: dict, add_prefix: str = 'model.') → dict

Load a checkpoint saved by OTX 1.x and convert its state dict to the OTX 2.0 format.
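The `add_prefix` argument defaults to `'model.'`. As a rough sketch of what prefixing means for a state dict, assuming keys are plain strings, the hypothetical helper `prefix_state_dict_keys` below prepends the prefix to every key. It shows only this one step; the real conversion from OTX 1.x may also rename or drop keys.

```python
# Hypothetical helper: illustrates only the key-prefixing step.
# The actual load_from_otx_v1_ckpt conversion may also rename or remove keys.


def prefix_state_dict_keys(state_dict: dict, add_prefix: str = "model.") -> dict:
    """Return a new dict whose keys are ``add_prefix`` + original key."""
    return {add_prefix + key: value for key, value in state_dict.items()}


if __name__ == "__main__":
    v1_state = {"backbone.conv1.weight": 1, "head.fc.bias": 2}
    print(prefix_state_dict_keys(v1_state))
```

Returning a new dict rather than mutating the input keeps the original checkpoint contents available for inspection.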

class otx.algo.classification.mobilenet_v3.MobileNetV3ForMulticlassCls(label_info: LabelInfoTypes, mode: Literal['large', 'small'] = 'large', optimizer: OptimizerCallable = <function _default_optimizer_callable>, scheduler: LRSchedulerCallable | LRSchedulerListCallable = <function _default_scheduler_callable>, metric: MetricCallable = <function _multi_class_cls_metric_callable>, torch_compile: bool = False, input_size: tuple[int, int] = (224, 224), train_type: Literal[OTXTrainType.SUPERVISED, OTXTrainType.SEMI_SUPERVISED] = OTXTrainType.SUPERVISED)

Bases: OTXMulticlassClsModel

MobileNetV3 Model for multi-class classification task.

Parameters:
  • label_info (LabelInfoTypes) – Information about the target labels (classes).

  • mode (Literal["large", "small"], optional) – The mode of the MobileNetV3 model, either “large” or “small”. Defaults to “large”.

  • optimizer (OptimizerCallable, optional) – The optimizer callable. Defaults to _default_optimizer_callable.

  • scheduler (LRSchedulerCallable | LRSchedulerListCallable, optional) – The learning rate scheduler callable. Defaults to _default_scheduler_callable.

  • metric (MetricCallable, optional) – The metric callable. Defaults to _multi_class_cls_metric_callable.

  • torch_compile (bool, optional) – Whether to compile the model with torch.compile. Defaults to False.

  • input_size (tuple[int, int], optional) – Model input size in the order of height and width. Defaults to (224, 224).

  • train_type (Literal[OTXTrainType.SUPERVISED, OTXTrainType.SEMI_SUPERVISED], optional) – The training type, either supervised or semi-supervised. Defaults to OTXTrainType.SUPERVISED.
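In multi-class classification each image receives exactly one label. A common, generic way to turn raw scores (logits) into that label is a softmax followed by an arg-max. The pure-Python sketch below illustrates this; it is not taken from OTX, and the names `softmax` and `predict_single_label` are illustrative.

```python
import math


def softmax(logits: list[float]) -> list[float]:
    """Numerically stable softmax over a list of logits."""
    peak = max(logits)
    # Subtracting the max leaves the result unchanged but avoids overflow in exp().
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def predict_single_label(logits: list[float]) -> tuple[int, float]:
    """Return (class index, probability) of the most likely class."""
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)
    return best, probs[best]


if __name__ == "__main__":
    index, prob = predict_single_label([1.0, 3.0, 0.5])
    print(index, round(prob, 3))
```

Because the softmax probabilities sum to one, the classes compete with each other, which matches the one-label-per-image assumption of this task.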

forward_explain(inputs: MulticlassClsBatchDataEntity) → MulticlassClsBatchPredEntity

Model forward function for explanation (XAI): returns the predictions together with explanation outputs such as saliency maps.

forward_for_tracing(image: Tensor) → Tensor | dict[str, Tensor]

Model forward function used to trace the model during export.

load_from_otx_v1_ckpt(state_dict: dict, add_prefix: str = 'model.') → dict

Load a checkpoint saved by OTX 1.x and convert its state dict to the OTX 2.0 format.

class otx.algo.classification.mobilenet_v3.MobileNetV3ForMultilabelCls(label_info: LabelInfoTypes, mode: Literal['large', 'small'] = 'large', optimizer: OptimizerCallable = <function _default_optimizer_callable>, scheduler: LRSchedulerCallable | LRSchedulerListCallable = <function _default_scheduler_callable>, metric: MetricCallable = <function _multi_label_cls_metric_callable>, torch_compile: bool = False, input_size: tuple[int, int] = (224, 224))

Bases: OTXMultilabelClsModel

MobileNetV3 Model for multi-label classification task.
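In multi-label classification an image may carry any number of labels, so each class is scored independently. A common, generic decoding, sketched below in pure Python and not taken from OTX, applies a sigmoid to every logit and keeps the classes whose probability reaches a threshold. The 0.5 threshold and the function names are assumptions for illustration.

```python
import math


def sigmoid(x: float) -> float:
    """Numerically stable logistic function mapping a logit into [0, 1]."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    # For negative x, rewrite to avoid overflow in exp(-x).
    z = math.exp(x)
    return z / (1.0 + z)


def predict_multi_label(logits: list[float], threshold: float = 0.5) -> list[int]:
    """Return the indices of all classes whose sigmoid probability >= threshold."""
    return [i for i, x in enumerate(logits) if sigmoid(x) >= threshold]


if __name__ == "__main__":
    # Classes 0 and 2 pass the threshold; class 1 does not.
    print(predict_multi_label([2.0, -1.0, 0.3]))
```

Unlike the softmax used for multi-class outputs, independent sigmoids let zero, one, or several classes be active at once.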

forward_explain(inputs: MultilabelClsBatchDataEntity) → MultilabelClsBatchPredEntity

Model forward function for explanation (XAI): returns the predictions together with explanation outputs such as saliency maps.

forward_for_tracing(image: Tensor) → Tensor | dict[str, Tensor]

Model forward function used to trace the model during export.

load_from_otx_v1_ckpt(state_dict: dict, add_prefix: str = 'model.') → dict

Load a checkpoint saved by OTX 1.x and convert its state dict to the OTX 2.0 format.