# Source code for otx.algo.classification.mobilenet_v3

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""MobileNetV3 model implementation."""

from __future__ import annotations

from copy import copy, deepcopy
from math import ceil
from typing import TYPE_CHECKING, Any, Literal

import torch
from torch import Tensor, nn

from otx.algo.classification.backbones import MobileNetV3Backbone
from otx.algo.classification.classifier import HLabelClassifier, ImageClassifier, SemiSLClassifier
from otx.algo.classification.heads import (
    HierarchicalCBAMClsHead,
    LinearClsHead,
    MultiLabelNonLinearClsHead,
    SemiSLLinearClsHead,
)
from otx.algo.classification.losses.asymmetric_angular_loss_with_ignore import AsymmetricAngularLossWithIgnore
from otx.algo.classification.necks.gap import GlobalAveragePooling
from otx.algo.classification.utils import get_classification_layers
from otx.algo.utils.support_otx_v1 import OTXv1Helper
from otx.core.data.entity.base import OTXBatchLossEntity
from otx.core.data.entity.classification import (
    HlabelClsBatchDataEntity,
    HlabelClsBatchPredEntity,
    MulticlassClsBatchDataEntity,
    MulticlassClsBatchPredEntity,
    MultilabelClsBatchDataEntity,
    MultilabelClsBatchPredEntity,
)
from otx.core.metrics import MetricInput
from otx.core.metrics.accuracy import HLabelClsMetricCallable, MultiClassClsMetricCallable, MultiLabelClsMetricCallable
from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable
from otx.core.model.classification import OTXHlabelClsModel, OTXMulticlassClsModel, OTXMultilabelClsModel
from otx.core.schedulers import LRSchedulerListCallable
from otx.core.types.label import HLabelInfo, LabelInfoTypes
from otx.core.types.task import OTXTrainType

if TYPE_CHECKING:
    from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable

    from otx.core.metrics import MetricCallable


class MobileNetV3ForMulticlassCls(OTXMulticlassClsModel):
    """MobileNetV3 model for multiclass classification.

    Args:
        label_info (LabelInfoTypes): Label information, forwarded to the base model.
        mode (Literal["large", "small"], optional): Key into ``MobileNetV3Backbone.MV3_CFG`` that selects
            the backbone variant. Defaults to "large".
        optimizer (OptimizerCallable, optional): The optimizer callable. Defaults to DefaultOptimizerCallable.
        scheduler (LRSchedulerCallable | LRSchedulerListCallable, optional): The learning rate scheduler
            callable. Defaults to DefaultSchedulerCallable.
        metric (MetricCallable, optional): The metric callable. Defaults to MultiClassClsMetricCallable.
        torch_compile (bool, optional): Compilation flag, forwarded unchanged to the base model.
            Defaults to False.
        input_size (tuple[int, int], optional): Model input size in the order of height and width.
            Defaults to (224, 224).
        train_type (Literal[OTXTrainType.SUPERVISED, OTXTrainType.SEMI_SUPERVISED], optional): Training
            regime, forwarded to the base model. ``_build_model`` reads ``self.train_type`` to choose between
            ``ImageClassifier`` and ``SemiSLClassifier``. Defaults to OTXTrainType.SUPERVISED.
    """

    def __init__(
        self,
        label_info: LabelInfoTypes,
        mode: Literal["large", "small"] = "large",
        optimizer: OptimizerCallable = DefaultOptimizerCallable,
        scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
        metric: MetricCallable = MultiClassClsMetricCallable,
        torch_compile: bool = False,
        input_size: tuple[int, int] = (224, 224),
        train_type: Literal[OTXTrainType.SUPERVISED, OTXTrainType.SEMI_SUPERVISED] = OTXTrainType.SUPERVISED,
    ) -> None:
        # NOTE(review): `mode` is stored before delegating to the base constructor, presumably because the
        # base class builds the model during __init__ and `_build_model` reads `self.mode` - confirm.
        self.mode = mode

        super().__init__(
            label_info=label_info,
            optimizer=optimizer,
            scheduler=scheduler,
            metric=metric,
            torch_compile=torch_compile,
            input_size=input_size,
            train_type=train_type,
        )

    def _create_model(self) -> nn.Module:
        """Build the classifier and record its classification layers.

        Returns:
            nn.Module: The model built for ``self.num_classes`` classes, after ``init_weights()`` was called.
        """
        # Get classification_layers for class-incr learning: two throw-away models that differ only in
        # their class count (5 vs. 6) are built and their state dicts handed to get_classification_layers.
        sample_model_dict = self._build_model(num_classes=5).state_dict()
        incremental_model_dict = self._build_model(num_classes=6).state_dict()
        self.classification_layers = get_classification_layers(
            sample_model_dict,
            incremental_model_dict,
            prefix="model.",
        )

        model = self._build_model(num_classes=self.num_classes)
        model.init_weights()
        return model

    def _build_model(self, num_classes: int) -> nn.Module:
        """Assemble backbone, neck, head and loss into a classifier.

        Args:
            num_classes (int): Number of output classes of the classification head.

        Returns:
            nn.Module: A ``SemiSLClassifier`` when ``self.train_type`` is semi-supervised, otherwise an
            ``ImageClassifier``.
        """
        backbone = MobileNetV3Backbone(mode=self.mode, input_size=self.input_size)
        backbone_out_channels = MobileNetV3Backbone.MV3_CFG[self.mode]["out_channels"]
        neck = GlobalAveragePooling(dim=2)

        if self.train_type == OTXTrainType.SEMI_SUPERVISED:
            # The semi-supervised classifier is given an unreduced (per-element) cross-entropy loss.
            return SemiSLClassifier(
                backbone=backbone,
                neck=neck,
                head=SemiSLLinearClsHead(
                    num_classes=num_classes,
                    in_channels=backbone_out_channels,
                ),
                loss=nn.CrossEntropyLoss(reduction="none"),
            )

        return ImageClassifier(
            backbone=backbone,
            neck=neck,
            head=LinearClsHead(
                num_classes=num_classes,
                in_channels=backbone_out_channels,
            ),
            loss=nn.CrossEntropyLoss(),
        )

    def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.") -> dict:
        """Load the previous OTX ckpt according to OTX2.0.

        Args:
            state_dict (dict): State dict of an OTX v1 checkpoint.
            add_prefix (str, optional): Prefix passed on to the conversion helper. Defaults to "model.".

        Returns:
            dict: The result of ``OTXv1Helper.load_cls_mobilenet_v3_ckpt`` for the "multiclass" label type.
        """
        return OTXv1Helper.load_cls_mobilenet_v3_ckpt(state_dict, "multiclass", add_prefix)

    def forward_explain(self, inputs: MulticlassClsBatchDataEntity) -> MulticlassClsBatchPredEntity:
        """Model forward explain function.

        Runs the wrapped model in "explain" mode on the stacked input images and repackages its
        ``preds``, ``scores``, ``saliency_map`` and ``feature_vector`` outputs into a prediction entity.
        """
        outputs = self.model(images=inputs.stacked_images, mode="explain")

        return MulticlassClsBatchPredEntity(
            batch_size=len(outputs["preds"]),
            images=inputs.images,
            imgs_info=inputs.imgs_info,
            labels=outputs["preds"],
            scores=outputs["scores"],
            saliency_map=outputs["saliency_map"],
            feature_vector=outputs["feature_vector"],
        )

    def forward_for_tracing(self, image: Tensor) -> Tensor | dict[str, Tensor]:
        """Model forward function used for the model tracing during model exportation.

        The wrapped model is called in "explain" mode when ``self.explain_mode`` is set, otherwise in
        "tensor" mode.
        """
        if self.explain_mode:
            return self.model(images=image, mode="explain")

        return self.model(images=image, mode="tensor")
class MobileNetV3ForMultilabelCls(OTXMultilabelClsModel):
    """MobileNetV3 Model for multi-label classification task.

    Args:
        label_info (LabelInfoTypes): Label information, forwarded to the base model.
        mode (Literal["large", "small"], optional): Key into ``MobileNetV3Backbone.MV3_CFG`` that selects
            the backbone variant. Defaults to "large".
        optimizer (OptimizerCallable, optional): The optimizer callable. Defaults to DefaultOptimizerCallable.
        scheduler (LRSchedulerCallable | LRSchedulerListCallable, optional): The learning rate scheduler
            callable. Defaults to DefaultSchedulerCallable.
        metric (MetricCallable, optional): The metric callable. Defaults to MultiLabelClsMetricCallable.
        torch_compile (bool, optional): Compilation flag, forwarded unchanged to the base model.
            Defaults to False.
        input_size (tuple[int, int], optional): Model input size in the order of height and width.
            Defaults to (224, 224).
    """

    def __init__(
        self,
        label_info: LabelInfoTypes,
        mode: Literal["large", "small"] = "large",
        optimizer: OptimizerCallable = DefaultOptimizerCallable,
        scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
        metric: MetricCallable = MultiLabelClsMetricCallable,
        torch_compile: bool = False,
        input_size: tuple[int, int] = (224, 224),
    ) -> None:
        # NOTE(review): `mode` is stored before delegating to the base constructor, presumably because the
        # base class builds the model during __init__ and `_build_model` reads `self.mode` - confirm.
        self.mode = mode

        super().__init__(
            label_info=label_info,
            optimizer=optimizer,
            scheduler=scheduler,
            metric=metric,
            torch_compile=torch_compile,
            input_size=input_size,
        )

    def _create_model(self) -> nn.Module:
        """Build the classifier and record its classification layers.

        Returns:
            nn.Module: The model built for ``self.num_classes`` classes, after ``init_weights()`` was called.
        """
        # Get classification_layers for class-incr learning: two throw-away models that differ only in
        # their class count (5 vs. 6) are built and their state dicts handed to get_classification_layers.
        sample_model_dict = self._build_model(num_classes=5).state_dict()
        incremental_model_dict = self._build_model(num_classes=6).state_dict()
        self.classification_layers = get_classification_layers(
            sample_model_dict,
            incremental_model_dict,
            prefix="model.",
        )

        model = self._build_model(num_classes=self.num_classes)
        model.init_weights()
        return model

    def _build_model(self, num_classes: int) -> nn.Module:
        """Assemble backbone, neck, non-linear multi-label head and loss into an ``ImageClassifier``.

        Args:
            num_classes (int): Number of output classes of the classification head.

        Returns:
            nn.Module: The assembled ``ImageClassifier``.
        """
        backbone = MobileNetV3Backbone(mode=self.mode, input_size=self.input_size)
        return ImageClassifier(
            backbone=backbone,
            neck=GlobalAveragePooling(dim=2),
            head=MultiLabelNonLinearClsHead(
                num_classes=num_classes,
                in_channels=MobileNetV3Backbone.MV3_CFG[self.mode]["out_channels"],
                hid_channels=MobileNetV3Backbone.MV3_CFG[self.mode]["hid_channels"],
                normalized=True,
                activation=nn.PReLU(),
            ),
            loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"),
            loss_scale=7.0,
        )

    def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.") -> dict:
        """Load the previous OTX ckpt according to OTX2.0.

        Args:
            state_dict (dict): State dict of an OTX v1 checkpoint.
            add_prefix (str, optional): Prefix passed on to the conversion helper. Defaults to "model.".

        Returns:
            dict: The result of ``OTXv1Helper.load_cls_mobilenet_v3_ckpt`` for the "multilabel" label type.
        """
        return OTXv1Helper.load_cls_mobilenet_v3_ckpt(state_dict, "multilabel", add_prefix)

    def _customize_inputs(self, inputs: MultilabelClsBatchDataEntity) -> dict[str, Any]:
        """Turn a batch entity into keyword arguments for the wrapped model.

        The forward mode is "loss" while training, "explain" when ``self.explain_mode`` is set, and
        "predict" otherwise.
        """
        if self.training:
            mode = "loss"
        elif self.explain_mode:
            mode = "explain"
        else:
            mode = "predict"

        return {
            "images": inputs.stacked_images,
            "labels": torch.stack(inputs.labels),
            "imgs_info": inputs.imgs_info,
            "mode": mode,
        }

    def _customize_outputs(
        self,
        outputs: Any,  # noqa: ANN401
        inputs: MultilabelClsBatchDataEntity,
    ) -> MultilabelClsBatchPredEntity | OTXBatchLossEntity:
        """Wrap raw model outputs into a loss entity (training) or a prediction entity (otherwise).

        Args:
            outputs (Any): Raw model output; a tensor, or a mapping carrying a "logits" entry.
            inputs (MultilabelClsBatchDataEntity): The batch that produced ``outputs``.

        Returns:
            MultilabelClsBatchPredEntity | OTXBatchLossEntity: Loss entity while training, else predictions.
        """
        if self.training:
            return OTXBatchLossEntity(loss=outputs)

        # To list, batch-wise
        logits = outputs if isinstance(outputs, torch.Tensor) else outputs["logits"]
        scores = torch.unbind(logits, 0)

        # NOTE(review): `labels` holds only the single arg-max index per sample, although this is a
        # multi-label model - presumably consumers rely on `scores` instead; confirm before changing.
        return MultilabelClsBatchPredEntity(
            batch_size=inputs.batch_size,
            images=inputs.images,
            imgs_info=inputs.imgs_info,
            scores=scores,
            labels=logits.argmax(-1, keepdim=True).unbind(0),
        )

    def forward_explain(self, inputs: MultilabelClsBatchDataEntity) -> MultilabelClsBatchPredEntity:
        """Model forward explain function.

        Runs the wrapped model in "explain" mode on the stacked input images and repackages its
        ``preds``, ``scores``, ``saliency_map`` and ``feature_vector`` outputs into a prediction entity.
        """
        outputs = self.model(images=inputs.stacked_images, mode="explain")

        return MultilabelClsBatchPredEntity(
            batch_size=len(outputs["preds"]),
            images=inputs.images,
            imgs_info=inputs.imgs_info,
            labels=outputs["preds"],
            scores=outputs["scores"],
            saliency_map=outputs["saliency_map"],
            feature_vector=outputs["feature_vector"],
        )

    def forward_for_tracing(self, image: Tensor) -> Tensor | dict[str, Tensor]:
        """Model forward function used for the model tracing during model exportation.

        The wrapped model is called in "explain" mode when ``self.explain_mode`` is set, otherwise in
        "tensor" mode.
        """
        if self.explain_mode:
            return self.model(images=image, mode="explain")

        return self.model(images=image, mode="tensor")
class MobileNetV3ForHLabelCls(OTXHlabelClsModel):
    """MobileNetV3 classifier for hierarchical-label (H-label) tasks.

    The network is a ``MobileNetV3Backbone`` followed by an identity neck and a
    ``HierarchicalCBAMClsHead``, wrapped in an ``HLabelClassifier``.
    """

    label_info: HLabelInfo

    def __init__(
        self,
        label_info: HLabelInfo,
        mode: Literal["large", "small"] = "large",
        optimizer: OptimizerCallable = DefaultOptimizerCallable,
        scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
        metric: MetricCallable = HLabelClsMetricCallable,
        torch_compile: bool = False,
        input_size: tuple[int, int] = (224, 224),
    ) -> None:
        # NOTE(review): assigned ahead of the base constructor, presumably because the base class builds
        # the model during __init__ and `_build_model` needs `self.mode` - confirm.
        self.mode = mode

        super().__init__(
            label_info=label_info,
            optimizer=optimizer,
            scheduler=scheduler,
            metric=metric,
            torch_compile=torch_compile,
            input_size=input_size,
        )

    def _create_model(self) -> nn.Module:
        """Create the H-label classifier and store its classification layers on ``self``."""
        # Get classification_layers for class-incr learning: probe models with 5 and then 6 classes are
        # built from one mutable config copy, and their state dicts are compared.
        probe_config = deepcopy(self.label_info.as_head_config_dict())
        probe_state_dicts = []
        for probe_num_classes in (5, 6):
            probe_config["num_classes"] = probe_num_classes
            probe_state_dicts.append(self._build_model(head_config=probe_config).state_dict())

        self.classification_layers = get_classification_layers(
            probe_state_dicts[0],
            probe_state_dicts[1],
            prefix="model.",
        )

        classifier = self._build_model(head_config=self.label_info.as_head_config_dict())
        classifier.init_weights()
        return classifier

    def _build_model(self, head_config: dict) -> nn.Module:
        """Assemble an ``HLabelClassifier`` from the given head configuration.

        Args:
            head_config (dict): Keyword arguments for ``HierarchicalCBAMClsHead``. It is shallow-copied,
                and a ``step_size`` entry is set on the copy, so the caller's dict is left unmodified.

        Returns:
            nn.Module: The assembled classifier.

        Raises:
            TypeError: If ``self.label_info`` is not an ``HLabelInfo``.
        """
        if not isinstance(self.label_info, HLabelInfo):
            raise TypeError(self.label_info)

        # step_size is each input side divided by 32, rounded up.
        step_size = (ceil(self.input_size[0] / 32), ceil(self.input_size[1] / 32))
        head_kwargs = copy(head_config)
        head_kwargs["step_size"] = step_size

        backbone = MobileNetV3Backbone(mode=self.mode, input_size=self.input_size)
        neck = nn.Identity()
        head = HierarchicalCBAMClsHead(
            in_channels=MobileNetV3Backbone.MV3_CFG[self.mode]["out_channels"],
            **head_kwargs,
        )
        return HLabelClassifier(
            backbone=backbone,
            neck=neck,
            head=head,
            multiclass_loss=nn.CrossEntropyLoss(),
            multilabel_loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"),
        )

    def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.") -> dict:
        """Convert an OTX v1 checkpoint through ``OTXv1Helper`` using the "hlabel" label type."""
        converted = OTXv1Helper.load_cls_mobilenet_v3_ckpt(state_dict, "hlabel", add_prefix)
        return converted

    def _customize_inputs(self, inputs: HlabelClsBatchDataEntity) -> dict[str, Any]:
        """Map a batch entity onto the keyword arguments expected by the wrapped model.

        The mode is "loss" during training, "explain" when ``self.explain_mode`` is set, else "predict".
        """
        mode = "loss" if self.training else ("explain" if self.explain_mode else "predict")

        return {
            "images": inputs.stacked_images,
            "labels": torch.stack(inputs.labels),
            "imgs_info": inputs.imgs_info,
            "mode": mode,
        }

    def _customize_outputs(
        self,
        outputs: Any,  # noqa: ANN401
        inputs: HlabelClsBatchDataEntity,
    ) -> HlabelClsBatchPredEntity | OTXBatchLossEntity:
        """Convert raw model outputs into a loss entity (training) or a prediction entity."""
        if self.training:
            return OTXBatchLossEntity(loss=outputs)

        # To list, batch-wise: dict outputs already carry scores and labels, anything else is treated as
        # a score tensor whose arg-max over the last axis gives the labels.
        if isinstance(outputs, dict):
            scores, labels = outputs["scores"], outputs["labels"]
        else:
            scores, labels = outputs, outputs.argmax(-1, keepdim=True)

        return HlabelClsBatchPredEntity(
            batch_size=inputs.batch_size,
            images=inputs.images,
            imgs_info=inputs.imgs_info,
            scores=scores,
            labels=labels,
        )

    def _convert_pred_entity_to_compute_metric(
        self,
        preds: HlabelClsBatchPredEntity,
        inputs: HlabelClsBatchDataEntity,
    ) -> MetricInput:
        """Build the ``preds``/``target`` mapping passed to the metric.

        When the label info has multilabel classes, the prediction is the first ``num_multiclass_heads``
        label columns concatenated with the score columns that follow them; otherwise it is the labels.
        """

        def _stack_if_list(value: Any) -> Any:  # noqa: ANN401
            # Lists are stacked into one tensor; everything else is passed through untouched.
            return torch.stack(value) if isinstance(value, list) else value

        hlabel_info: HLabelInfo = self.label_info  # type: ignore[assignment]

        label_tensor = _stack_if_list(preds.labels)
        score_tensor = _stack_if_list(preds.scores)

        if hlabel_info.num_multilabel_classes > 0:
            num_heads = hlabel_info.num_multiclass_heads
            pred_result = torch.cat([label_tensor[:, :num_heads], score_tensor[:, num_heads:]], dim=1)
        else:
            pred_result = label_tensor

        return {
            "preds": pred_result,
            "target": torch.stack(inputs.labels),
        }

    def forward_explain(self, inputs: HlabelClsBatchDataEntity) -> HlabelClsBatchPredEntity:
        """Run the wrapped model in "explain" mode and package its outputs as a prediction entity."""
        explained = self.model(images=inputs.stacked_images, mode="explain")
        predictions = explained["preds"]

        return HlabelClsBatchPredEntity(
            batch_size=len(predictions),
            images=inputs.images,
            imgs_info=inputs.imgs_info,
            labels=predictions,
            scores=explained["scores"],
            saliency_map=explained["saliency_map"],
            feature_vector=explained["feature_vector"],
        )

    def forward_for_tracing(self, image: Tensor) -> Tensor | dict[str, Tensor]:
        """Forward pass used while tracing the model for export ("explain" or "tensor" mode)."""
        tracing_mode = "explain" if self.explain_mode else "tensor"
        return self.model(images=image, mode=tracing_mode)