Source code for otx.algo.classification.mobilenet_v3
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
"""MobileNetV3 model implementation."""
from __future__ import annotations
from copy import copy, deepcopy
from math import ceil
from typing import TYPE_CHECKING, Any, Literal
import torch
from torch import Tensor, nn
from otx.algo.classification.backbones import MobileNetV3Backbone
from otx.algo.classification.classifier import HLabelClassifier, ImageClassifier, SemiSLClassifier
from otx.algo.classification.heads import (
HierarchicalCBAMClsHead,
LinearClsHead,
MultiLabelNonLinearClsHead,
SemiSLLinearClsHead,
)
from otx.algo.classification.losses.asymmetric_angular_loss_with_ignore import AsymmetricAngularLossWithIgnore
from otx.algo.classification.necks.gap import GlobalAveragePooling
from otx.algo.classification.utils import get_classification_layers
from otx.algo.utils.support_otx_v1 import OTXv1Helper
from otx.core.data.entity.base import OTXBatchLossEntity
from otx.core.data.entity.classification import (
HlabelClsBatchDataEntity,
HlabelClsBatchPredEntity,
MulticlassClsBatchDataEntity,
MulticlassClsBatchPredEntity,
MultilabelClsBatchDataEntity,
MultilabelClsBatchPredEntity,
)
from otx.core.metrics import MetricInput
from otx.core.metrics.accuracy import HLabelClsMetricCallable, MultiClassClsMetricCallable, MultiLabelClsMetricCallable
from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable
from otx.core.model.classification import OTXHlabelClsModel, OTXMulticlassClsModel, OTXMultilabelClsModel
from otx.core.schedulers import LRSchedulerListCallable
from otx.core.types.label import HLabelInfo, LabelInfoTypes
from otx.core.types.task import OTXTrainType
if TYPE_CHECKING:
from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable
from otx.core.metrics import MetricCallable
[docs]
class MobileNetV3ForMulticlassCls(OTXMulticlassClsModel):
    """MobileNetV3 model for multi-class classification.

    Args:
        label_info (LabelInfoTypes): Label information for the task, forwarded to the base class.
        mode (Literal["large", "small"], optional): The MobileNetV3 variant, either "large" or "small".
            Defaults to "large".
        optimizer (OptimizerCallable, optional): The optimizer callable. Defaults to DefaultOptimizerCallable.
        scheduler (LRSchedulerCallable | LRSchedulerListCallable, optional): The learning rate scheduler callable.
            Defaults to DefaultSchedulerCallable.
        metric (MetricCallable, optional): The metric callable. Defaults to MultiClassClsMetricCallable.
        torch_compile (bool, optional): Whether to compile the model; forwarded to the base class.
            Defaults to False.
        input_size (tuple[int, int], optional):
            Model input size in the order of height and width. Defaults to (224, 224).
        train_type (Literal[OTXTrainType.SUPERVISED, OTXTrainType.SEMI_SUPERVISED], optional):
            ``OTXTrainType.SEMI_SUPERVISED`` builds a ``SemiSLClassifier`` with a ``SemiSLLinearClsHead``
            and an unreduced (``reduction="none"``) cross-entropy loss; any other value builds an
            ``ImageClassifier`` with a ``LinearClsHead``. Defaults to ``OTXTrainType.SUPERVISED``.
    """

    def __init__(
        self,
        label_info: LabelInfoTypes,
        mode: Literal["large", "small"] = "large",
        optimizer: OptimizerCallable = DefaultOptimizerCallable,
        scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
        metric: MetricCallable = MultiClassClsMetricCallable,
        torch_compile: bool = False,
        input_size: tuple[int, int] = (224, 224),
        train_type: Literal[OTXTrainType.SUPERVISED, OTXTrainType.SEMI_SUPERVISED] = OTXTrainType.SUPERVISED,
    ) -> None:
        # Assigned before super().__init__() because _build_model reads self.mode; presumably the
        # base class builds the model during its own initialization (TODO confirm).
        self.mode = mode
        super().__init__(
            label_info=label_info,
            optimizer=optimizer,
            scheduler=scheduler,
            metric=metric,
            torch_compile=torch_compile,
            input_size=input_size,
            train_type=train_type,
        )

    def _create_model(self) -> nn.Module:
        """Build the classifier and record its class-count-dependent layers.

        Returns:
            nn.Module: The classifier sized for ``self.num_classes`` with ``init_weights()`` applied.
        """
        # Get classification_layers for class-incremental learning: two throwaway models that differ
        # only in class count are built so get_classification_layers can compare their state dicts.
        sample_model_dict = self._build_model(num_classes=5).state_dict()
        incremental_model_dict = self._build_model(num_classes=6).state_dict()
        self.classification_layers = get_classification_layers(
            sample_model_dict,
            incremental_model_dict,
            prefix="model.",
        )

        model = self._build_model(num_classes=self.num_classes)
        model.init_weights()
        return model

    def _build_model(self, num_classes: int) -> nn.Module:
        """Assemble backbone, neck, head and loss for the configured train type.

        Args:
            num_classes (int): Number of output classes for the head.

        Returns:
            nn.Module: A ``SemiSLClassifier`` when ``self.train_type`` is semi-supervised,
            otherwise an ``ImageClassifier``.
        """
        backbone = MobileNetV3Backbone(mode=self.mode, input_size=self.input_size)
        backbone_out_channels = MobileNetV3Backbone.MV3_CFG[self.mode]["out_channels"]
        neck = GlobalAveragePooling(dim=2)

        if self.train_type == OTXTrainType.SEMI_SUPERVISED:
            return SemiSLClassifier(
                backbone=backbone,
                neck=neck,
                head=SemiSLLinearClsHead(
                    num_classes=num_classes,
                    in_channels=backbone_out_channels,
                ),
                # Per-sample loss; reduction is presumably handled by the semi-SL classifier.
                loss=nn.CrossEntropyLoss(reduction="none"),
            )

        return ImageClassifier(
            backbone=backbone,
            neck=neck,
            head=LinearClsHead(
                num_classes=num_classes,
                in_channels=backbone_out_channels,
            ),
            loss=nn.CrossEntropyLoss(),
        )

    def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.") -> dict:
        """Load the previous OTX ckpt according to OTX2.0.

        Args:
            state_dict (dict): State dict from an OTX 1.x checkpoint.
            add_prefix (str, optional): Prefix passed to the converter. Defaults to "model.".

        Returns:
            dict: The state dict converted by ``OTXv1Helper.load_cls_mobilenet_v3_ckpt`` for "multiclass".
        """
        return OTXv1Helper.load_cls_mobilenet_v3_ckpt(state_dict, "multiclass", add_prefix)

    def forward_explain(self, inputs: MulticlassClsBatchDataEntity) -> MulticlassClsBatchPredEntity:
        """Run the model in explain mode and wrap its outputs into a prediction entity.

        Args:
            inputs (MulticlassClsBatchDataEntity): The input batch.

        Returns:
            MulticlassClsBatchPredEntity: Predictions carrying labels, scores, saliency maps and
            feature vectors taken from the model's explain-mode output dict.
        """
        outputs = self.model(images=inputs.stacked_images, mode="explain")

        return MulticlassClsBatchPredEntity(
            batch_size=len(outputs["preds"]),
            images=inputs.images,
            imgs_info=inputs.imgs_info,
            labels=outputs["preds"],
            scores=outputs["scores"],
            saliency_map=outputs["saliency_map"],
            feature_vector=outputs["feature_vector"],
        )

    def forward_for_tracing(self, image: Tensor) -> Tensor | dict[str, Tensor]:
        """Model forward function used for the model tracing during model exportation.

        Args:
            image (Tensor): Batched input images.

        Returns:
            Tensor | dict[str, Tensor]: Explain-mode outputs when ``self.explain_mode`` is set,
            otherwise the model's "tensor"-mode output.
        """
        mode = "explain" if self.explain_mode else "tensor"
        return self.model(images=image, mode=mode)
[docs]
class MobileNetV3ForMultilabelCls(OTXMultilabelClsModel):
    """MobileNetV3 model for multi-label classification.

    Args:
        label_info (LabelInfoTypes): Label information for the task, forwarded to the base class.
        mode (Literal["large", "small"], optional): The MobileNetV3 variant, either "large" or "small".
            Defaults to "large".
        optimizer (OptimizerCallable, optional): The optimizer callable. Defaults to DefaultOptimizerCallable.
        scheduler (LRSchedulerCallable | LRSchedulerListCallable, optional): The learning rate scheduler callable.
            Defaults to DefaultSchedulerCallable.
        metric (MetricCallable, optional): The metric callable. Defaults to MultiLabelClsMetricCallable.
        torch_compile (bool, optional): Whether to compile the model; forwarded to the base class.
            Defaults to False.
        input_size (tuple[int, int], optional):
            Model input size in the order of height and width. Defaults to (224, 224).
    """

    def __init__(
        self,
        label_info: LabelInfoTypes,
        mode: Literal["large", "small"] = "large",
        optimizer: OptimizerCallable = DefaultOptimizerCallable,
        scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
        metric: MetricCallable = MultiLabelClsMetricCallable,
        torch_compile: bool = False,
        input_size: tuple[int, int] = (224, 224),
    ) -> None:
        # Assigned before super().__init__() because _build_model reads self.mode; presumably the
        # base class builds the model during its own initialization (TODO confirm).
        self.mode = mode
        super().__init__(
            label_info=label_info,
            optimizer=optimizer,
            scheduler=scheduler,
            metric=metric,
            torch_compile=torch_compile,
            input_size=input_size,
        )

    def _create_model(self) -> nn.Module:
        """Build the classifier and record its class-count-dependent layers.

        Returns:
            nn.Module: The classifier sized for ``self.num_classes`` with ``init_weights()`` applied.
        """
        # Get classification_layers for class-incremental learning: two throwaway models that differ
        # only in class count are built so get_classification_layers can compare their state dicts.
        sample_model_dict = self._build_model(num_classes=5).state_dict()
        incremental_model_dict = self._build_model(num_classes=6).state_dict()
        self.classification_layers = get_classification_layers(
            sample_model_dict,
            incremental_model_dict,
            prefix="model.",
        )

        model = self._build_model(num_classes=self.num_classes)
        model.init_weights()
        return model

    def _build_model(self, num_classes: int) -> nn.Module:
        """Assemble an ``ImageClassifier`` with a normalized non-linear multi-label head.

        Args:
            num_classes (int): Number of output classes for the head.

        Returns:
            nn.Module: The assembled ``ImageClassifier``.
        """
        backbone_cfg = MobileNetV3Backbone.MV3_CFG[self.mode]
        backbone = MobileNetV3Backbone(mode=self.mode, input_size=self.input_size)
        return ImageClassifier(
            backbone=backbone,
            neck=GlobalAveragePooling(dim=2),
            head=MultiLabelNonLinearClsHead(
                num_classes=num_classes,
                in_channels=backbone_cfg["out_channels"],
                hid_channels=backbone_cfg["hid_channels"],
                normalized=True,
                activation=nn.PReLU(),
            ),
            loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"),
            loss_scale=7.0,
        )

    def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.") -> dict:
        """Load the previous OTX ckpt according to OTX2.0.

        Args:
            state_dict (dict): State dict from an OTX 1.x checkpoint.
            add_prefix (str, optional): Prefix passed to the converter. Defaults to "model.".

        Returns:
            dict: The state dict converted by ``OTXv1Helper.load_cls_mobilenet_v3_ckpt`` for "multilabel".
        """
        return OTXv1Helper.load_cls_mobilenet_v3_ckpt(state_dict, "multilabel", add_prefix)

    def _customize_inputs(self, inputs: MultilabelClsBatchDataEntity) -> dict[str, Any]:
        """Convert a batch entity into keyword arguments for ``self.model``.

        The model mode is "loss" while training, "explain" when ``self.explain_mode`` is set,
        and "predict" otherwise.
        """
        if self.training:
            mode = "loss"
        else:
            mode = "explain" if self.explain_mode else "predict"

        return {
            "images": inputs.stacked_images,
            "labels": torch.stack(inputs.labels),
            "imgs_info": inputs.imgs_info,
            "mode": mode,
        }

    def _customize_outputs(
        self,
        outputs: Any,  # noqa: ANN401
        inputs: MultilabelClsBatchDataEntity,
    ) -> MultilabelClsBatchPredEntity | OTXBatchLossEntity:
        """Wrap raw model outputs into a loss entity (training) or a prediction entity.

        Args:
            outputs (Any): The loss while training; otherwise either a logits tensor or a dict
                holding it under "logits".
            inputs (MultilabelClsBatchDataEntity): The batch the outputs were produced from.

        Returns:
            MultilabelClsBatchPredEntity | OTXBatchLossEntity: The wrapped outputs.
        """
        if self.training:
            return OTXBatchLossEntity(loss=outputs)

        logits = outputs if isinstance(outputs, torch.Tensor) else outputs["logits"]
        # Split the batched logits into one tensor per sample.
        scores = torch.unbind(logits, 0)
        # NOTE(review): labels are the per-sample top-1 index, not a thresholded multi-label set;
        # presumably metrics consume `scores` instead — confirm before relying on `labels`.
        labels = logits.argmax(-1, keepdim=True).unbind(0)

        return MultilabelClsBatchPredEntity(
            batch_size=inputs.batch_size,
            images=inputs.images,
            imgs_info=inputs.imgs_info,
            scores=scores,
            labels=labels,
        )

    def forward_explain(self, inputs: MultilabelClsBatchDataEntity) -> MultilabelClsBatchPredEntity:
        """Run the model in explain mode and wrap its outputs into a prediction entity.

        Args:
            inputs (MultilabelClsBatchDataEntity): The input batch.

        Returns:
            MultilabelClsBatchPredEntity: Predictions carrying labels, scores, saliency maps and
            feature vectors taken from the model's explain-mode output dict.
        """
        outputs = self.model(images=inputs.stacked_images, mode="explain")

        return MultilabelClsBatchPredEntity(
            batch_size=len(outputs["preds"]),
            images=inputs.images,
            imgs_info=inputs.imgs_info,
            labels=outputs["preds"],
            scores=outputs["scores"],
            saliency_map=outputs["saliency_map"],
            feature_vector=outputs["feature_vector"],
        )

    def forward_for_tracing(self, image: Tensor) -> Tensor | dict[str, Tensor]:
        """Model forward function used for the model tracing during model exportation.

        Args:
            image (Tensor): Batched input images.

        Returns:
            Tensor | dict[str, Tensor]: Explain-mode outputs when ``self.explain_mode`` is set,
            otherwise the model's "tensor"-mode output.
        """
        mode = "explain" if self.explain_mode else "tensor"
        return self.model(images=image, mode=mode)
[docs]
class MobileNetV3ForHLabelCls(OTXHlabelClsModel):
    """MobileNetV3 model for hierarchical label classification.

    Args:
        label_info (HLabelInfo): Hierarchical label information; its head config dict
            (``as_head_config_dict()``) parameterizes the ``HierarchicalCBAMClsHead``.
        mode (Literal["large", "small"], optional): The MobileNetV3 variant, either "large" or "small".
            Defaults to "large".
        optimizer (OptimizerCallable, optional): The optimizer callable. Defaults to DefaultOptimizerCallable.
        scheduler (LRSchedulerCallable | LRSchedulerListCallable, optional): The learning rate scheduler callable.
            Defaults to DefaultSchedulerCallable.
        metric (MetricCallable, optional): The metric callable. Defaults to HLabelClsMetricCallable.
        torch_compile (bool, optional): Whether to compile the model; forwarded to the base class.
            Defaults to False.
        input_size (tuple[int, int], optional):
            Model input size in the order of height and width. Defaults to (224, 224).
    """

    label_info: HLabelInfo

    def __init__(
        self,
        label_info: HLabelInfo,
        mode: Literal["large", "small"] = "large",
        optimizer: OptimizerCallable = DefaultOptimizerCallable,
        scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
        metric: MetricCallable = HLabelClsMetricCallable,
        torch_compile: bool = False,
        input_size: tuple[int, int] = (224, 224),
    ) -> None:
        # Assigned before super().__init__() because _build_model reads self.mode; presumably the
        # base class builds the model during its own initialization (TODO confirm).
        self.mode = mode
        super().__init__(
            label_info=label_info,
            optimizer=optimizer,
            scheduler=scheduler,
            metric=metric,
            torch_compile=torch_compile,
            input_size=input_size,
        )

    def _create_model(self) -> nn.Module:
        """Build the classifier and record its class-count-dependent layers.

        Returns:
            nn.Module: The classifier built from ``self.label_info``'s head config with
            ``init_weights()`` applied.
        """
        # Get classification_layers for class-incremental learning: two throwaway models whose head
        # configs differ only in "num_classes" are built so their state dicts can be compared.
        # deepcopy keeps the mutations below from touching the dict returned by label_info.
        sample_config = deepcopy(self.label_info.as_head_config_dict())
        sample_config["num_classes"] = 5
        sample_model_dict = self._build_model(head_config=sample_config).state_dict()
        sample_config["num_classes"] = 6
        incremental_model_dict = self._build_model(head_config=sample_config).state_dict()
        self.classification_layers = get_classification_layers(
            sample_model_dict,
            incremental_model_dict,
            prefix="model.",
        )

        model = self._build_model(head_config=self.label_info.as_head_config_dict())
        model.init_weights()
        return model

    def _build_model(self, head_config: dict) -> nn.Module:
        """Assemble an ``HLabelClassifier`` with a ``HierarchicalCBAMClsHead``.

        Args:
            head_config (dict): Keyword arguments for the head; not mutated (a shallow copy
                receives the extra "step_size" entry).

        Returns:
            nn.Module: The assembled ``HLabelClassifier``.

        Raises:
            TypeError: If ``self.label_info`` is not an ``HLabelInfo``.
        """
        if not isinstance(self.label_info, HLabelInfo):
            msg = f"label_info must be an HLabelInfo, got {type(self.label_info).__name__}: {self.label_info!r}"
            raise TypeError(msg)

        copied_head_config = copy(head_config)
        # Spatial size of the backbone feature map; assumes an overall stride of 32 — TODO confirm.
        copied_head_config["step_size"] = (ceil(self.input_size[0] / 32), ceil(self.input_size[1] / 32))

        backbone = MobileNetV3Backbone(mode=self.mode, input_size=self.input_size)
        return HLabelClassifier(
            backbone=backbone,
            # No pooling neck: the head receives the backbone feature map directly.
            neck=nn.Identity(),
            head=HierarchicalCBAMClsHead(
                in_channels=MobileNetV3Backbone.MV3_CFG[self.mode]["out_channels"],
                **copied_head_config,
            ),
            multiclass_loss=nn.CrossEntropyLoss(),
            multilabel_loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"),
        )

    def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.") -> dict:
        """Load the previous OTX ckpt according to OTX2.0.

        Args:
            state_dict (dict): State dict from an OTX 1.x checkpoint.
            add_prefix (str, optional): Prefix passed to the converter. Defaults to "model.".

        Returns:
            dict: The state dict converted by ``OTXv1Helper.load_cls_mobilenet_v3_ckpt`` for "hlabel".
        """
        return OTXv1Helper.load_cls_mobilenet_v3_ckpt(state_dict, "hlabel", add_prefix)

    def _customize_inputs(self, inputs: HlabelClsBatchDataEntity) -> dict[str, Any]:
        """Convert a batch entity into keyword arguments for ``self.model``.

        The model mode is "loss" while training, "explain" when ``self.explain_mode`` is set,
        and "predict" otherwise.
        """
        if self.training:
            mode = "loss"
        else:
            mode = "explain" if self.explain_mode else "predict"

        return {
            "images": inputs.stacked_images,
            "labels": torch.stack(inputs.labels),
            "imgs_info": inputs.imgs_info,
            "mode": mode,
        }

    def _customize_outputs(
        self,
        outputs: Any,  # noqa: ANN401
        inputs: HlabelClsBatchDataEntity,
    ) -> HlabelClsBatchPredEntity | OTXBatchLossEntity:
        """Wrap raw model outputs into a loss entity (training) or a prediction entity.

        Args:
            outputs (Any): The loss while training; otherwise a dict with "scores" and "labels",
                or a score tensor from which labels are taken as the last-dim argmax.
            inputs (HlabelClsBatchDataEntity): The batch the outputs were produced from.

        Returns:
            HlabelClsBatchPredEntity | OTXBatchLossEntity: The wrapped outputs.
        """
        if self.training:
            return OTXBatchLossEntity(loss=outputs)

        if isinstance(outputs, dict):
            scores = outputs["scores"]
            labels = outputs["labels"]
        else:
            scores = outputs
            labels = outputs.argmax(-1, keepdim=True)

        return HlabelClsBatchPredEntity(
            batch_size=inputs.batch_size,
            images=inputs.images,
            imgs_info=inputs.imgs_info,
            scores=scores,
            labels=labels,
        )

    def _convert_pred_entity_to_compute_metric(
        self,
        preds: HlabelClsBatchPredEntity,
        inputs: HlabelClsBatchDataEntity,
    ) -> MetricInput:
        """Build the metric input from predictions and ground-truth labels.

        When multi-label classes exist, the first ``num_multiclass_heads`` columns come from the
        predicted labels and the remaining columns from the predicted scores. Otherwise the
        predicted labels are used as-is.
        """
        hlabel_info: HLabelInfo = self.label_info  # type: ignore[assignment]

        # Predictions may arrive as per-sample lists; stack them into batched tensors.
        _labels = torch.stack(preds.labels) if isinstance(preds.labels, list) else preds.labels
        _scores = torch.stack(preds.scores) if isinstance(preds.scores, list) else preds.scores

        if hlabel_info.num_multilabel_classes > 0:
            num_heads = hlabel_info.num_multiclass_heads
            preds_multiclass = _labels[:, :num_heads]
            preds_multilabel = _scores[:, num_heads:]
            pred_result = torch.cat([preds_multiclass, preds_multilabel], dim=1)
        else:
            pred_result = _labels

        return {
            "preds": pred_result,
            "target": torch.stack(inputs.labels),
        }

    def forward_explain(self, inputs: HlabelClsBatchDataEntity) -> HlabelClsBatchPredEntity:
        """Run the model in explain mode and wrap its outputs into a prediction entity.

        Args:
            inputs (HlabelClsBatchDataEntity): The input batch.

        Returns:
            HlabelClsBatchPredEntity: Predictions carrying labels, scores, saliency maps and
            feature vectors taken from the model's explain-mode output dict.
        """
        outputs = self.model(images=inputs.stacked_images, mode="explain")

        return HlabelClsBatchPredEntity(
            batch_size=len(outputs["preds"]),
            images=inputs.images,
            imgs_info=inputs.imgs_info,
            labels=outputs["preds"],
            scores=outputs["scores"],
            saliency_map=outputs["saliency_map"],
            feature_vector=outputs["feature_vector"],
        )

    def forward_for_tracing(self, image: Tensor) -> Tensor | dict[str, Tensor]:
        """Model forward function used for the model tracing during model exportation.

        Args:
            image (Tensor): Batched input images.

        Returns:
            Tensor | dict[str, Tensor]: Explain-mode outputs when ``self.explain_mode`` is set,
            otherwise the model's "tensor"-mode output.
        """
        mode = "explain" if self.explain_mode else "tensor"
        return self.model(images=image, mode=mode)