otx.algo.classification.mobilenet_v3
MobileNetV3 model implementation.
Classes
MobileNetV3ForHLabelCls | MobileNetV3 Model for hierarchical label classification task.
MobileNetV3ForMulticlassCls | MobileNetV3 Model for multiclass classification.
MobileNetV3ForMultilabelCls | MobileNetV3 Model for multi-label classification task.
- class otx.algo.classification.mobilenet_v3.MobileNetV3ForHLabelCls(label_info: HLabelInfo, mode: Literal['large', 'small'] = 'large', optimizer: OptimizerCallable = <function _default_optimizer_callable>, scheduler: LRSchedulerCallable | LRSchedulerListCallable = <function _default_scheduler_callable>, metric: MetricCallable = <function _mixed_hlabel_accuracy>, torch_compile: bool = False, input_size: tuple[int, int] = (224, 224))
Bases: OTXHlabelClsModel
MobileNetV3 Model for hierarchical label classification task.
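Hierarchical label classification is commonly set up as several mutually exclusive multiclass heads plus a group of independent multilabel outputs. The plain-Python sketch below shows how such a flat logit vector could be decoded, assuming each head occupies a contiguous `(start, end)` slice and the multilabel logits follow the last head, with a 0.5 sigmoid threshold. `decode_hlabel` and `head_ranges` are hypothetical names, not part of the OTX API; the layout that `HLabelInfo` actually describes may differ.

```python
import math
from typing import Sequence


def softmax(values: Sequence[float]) -> list[float]:
    """Numerically stable softmax over one head's logits."""
    peak = max(values)
    exps = [math.exp(v - peak) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def sigmoid(value: float) -> float:
    """Logistic function, written to avoid overflow for large |value|."""
    if value >= 0:
        return 1.0 / (1.0 + math.exp(-value))
    z = math.exp(value)
    return z / (1.0 + z)


def decode_hlabel(
    logits: Sequence[float],
    head_ranges: Sequence[tuple[int, int]],
    threshold: float = 0.5,
) -> tuple[list[int], list[int]]:
    """Split a flat logit vector into per-head and multilabel predictions.

    Hypothetical layout: each (start, end) pair in head_ranges marks one
    multiclass head whose labels are mutually exclusive; every logit after
    the last head is an independent multilabel output.
    """
    head_preds = []
    for start, end in head_ranges:
        probs = softmax(logits[start:end])
        # Winning label index *within* this head (argmax of the softmax).
        head_preds.append(max(range(len(probs)), key=probs.__getitem__))
    # Multilabel logits start right after the furthest head range.
    ml_start = max((end for _, end in head_ranges), default=0)
    ml_preds = [1 if sigmoid(v) >= threshold else 0 for v in logits[ml_start:]]
    return head_preds, ml_preds


# Two multiclass heads (3 and 2 labels) followed by 2 multilabel outputs.
print(decode_hlabel([2.0, 0.1, -1.0, 0.3, 1.5, 3.0, -2.0], [(0, 3), (3, 5)]))
# ([0, 1], [1, 0])
```

Softmax is applied per head rather than over the whole vector so that labels compete only against their siblings in the same group.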
- forward_explain(inputs: HlabelClsBatchDataEntity) → HlabelClsBatchPredEntity
Model forward function in explain mode: returns the batch predictions together with explanation outputs such as saliency maps.
- class otx.algo.classification.mobilenet_v3.MobileNetV3ForMulticlassCls(label_info: LabelInfoTypes, mode: Literal['large', 'small'] = 'large', optimizer: OptimizerCallable = <function _default_optimizer_callable>, scheduler: LRSchedulerCallable | LRSchedulerListCallable = <function _default_scheduler_callable>, metric: MetricCallable = <function _multi_class_cls_metric_callable>, torch_compile: bool = False, input_size: tuple[int, int] = (224, 224), train_type: Literal[OTXTrainType.SUPERVISED, OTXTrainType.SEMI_SUPERVISED] = OTXTrainType.SUPERVISED)
Bases: OTXMulticlassClsModel
MobileNetV3 Model for multiclass classification.
- Parameters:
label_info (LabelInfoTypes) – Label information for the classification task, which determines the number of output classes.
mode (Literal["large", "small"], optional) – The MobileNetV3 variant, either “large” or “small”. Defaults to “large”.
optimizer (OptimizerCallable, optional) – The optimizer callable. Defaults to _default_optimizer_callable.
scheduler (LRSchedulerCallable | LRSchedulerListCallable, optional) – The learning rate scheduler callable. Defaults to _default_scheduler_callable.
metric (MetricCallable, optional) – The metric callable. Defaults to _multi_class_cls_metric_callable.
torch_compile (bool, optional) – Whether to compile the model with torch.compile. Defaults to False.
input_size (tuple[int, int], optional) – Model input size in (height, width) order. Defaults to (224, 224).
train_type (Literal[OTXTrainType.SUPERVISED, OTXTrainType.SEMI_SUPERVISED], optional) – Training type, either supervised or semi-supervised. Defaults to OTXTrainType.SUPERVISED.
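In multiclass classification each image gets exactly one label. A minimal sketch of the usual decoding step, assuming the head emits one raw logit per class, is a softmax followed by argmax. `predict_multiclass` is a hypothetical helper written in plain Python for illustration; it is not how OTX exposes predictions.

```python
import math
from typing import Sequence


def predict_multiclass(logits: Sequence[float]) -> tuple[int, float]:
    """Return (class index, probability) for one sample's class logits.

    Softmax normalizes the scores so they compete: probabilities sum to 1
    and exactly one class is selected by argmax.
    """
    peak = max(logits)  # subtract the max so math.exp cannot overflow
    exps = [math.exp(v - peak) for v in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    best = max(range(len(probs)), key=probs.__getitem__)
    return best, probs[best]


label, prob = predict_multiclass([0.5, 2.5, -1.0])
print(label, round(prob, 3))  # 1 0.858
```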
- forward_explain(inputs: MulticlassClsBatchDataEntity) → MulticlassClsBatchPredEntity
Model forward function in explain mode: returns the batch predictions together with explanation outputs such as saliency maps.
- class otx.algo.classification.mobilenet_v3.MobileNetV3ForMultilabelCls(label_info: LabelInfoTypes, mode: Literal['large', 'small'] = 'large', optimizer: OptimizerCallable = <function _default_optimizer_callable>, scheduler: LRSchedulerCallable | LRSchedulerListCallable = <function _default_scheduler_callable>, metric: MetricCallable = <function _multi_label_cls_metric_callable>, torch_compile: bool = False, input_size: tuple[int, int] = (224, 224))
Bases: OTXMultilabelClsModel
MobileNetV3 Model for multi-label classification task.
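In multi-label classification each label is an independent yes/no decision, so an image may receive zero, one, or several labels. The plain-Python sketch below illustrates this with a per-label sigmoid and a fixed threshold; `predict_multilabel` and the 0.5 default threshold are assumptions for illustration, not OTX's actual post-processing.

```python
import math
from typing import Sequence


def _sigmoid(value: float) -> float:
    """Logistic function, written to avoid overflow for large |value|."""
    if value >= 0:
        return 1.0 / (1.0 + math.exp(-value))
    z = math.exp(value)
    return z / (1.0 + z)


def predict_multilabel(logits: Sequence[float], threshold: float = 0.5) -> list[int]:
    """Return the indices of all labels whose sigmoid score reaches threshold.

    Each label is scored independently, so an image may receive zero, one,
    or several labels (unlike the softmax-based multiclass case).
    """
    return [i for i, v in enumerate(logits) if _sigmoid(v) >= threshold]


print(predict_multilabel([2.0, -1.0, 0.0, 3.0]))  # [0, 2, 3]
```

Raising the threshold trades recall for precision; it can be tuned per label if the classes are imbalanced.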
- forward_explain(inputs: MultilabelClsBatchDataEntity) → MultilabelClsBatchPredEntity
Model forward function in explain mode: returns the batch predictions together with explanation outputs such as saliency maps.