Source code for otx.algo.classification.efficientnet
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""EfficientNet-B0 model implementation."""

from __future__ import annotations

from copy import copy, deepcopy
from math import ceil
from typing import TYPE_CHECKING, Literal

from torch import Tensor, nn

from otx.algo.classification.backbones.efficientnet import EFFICIENTNET_VERSION, EfficientNetBackbone
from otx.algo.classification.classifier import HLabelClassifier, ImageClassifier, SemiSLClassifier
from otx.algo.classification.heads import (
    HierarchicalCBAMClsHead,
    LinearClsHead,
    MultiLabelLinearClsHead,
    SemiSLLinearClsHead,
)
from otx.algo.classification.losses.asymmetric_angular_loss_with_ignore import AsymmetricAngularLossWithIgnore
from otx.algo.classification.necks.gap import GlobalAveragePooling
from otx.algo.classification.utils import get_classification_layers
from otx.algo.utils.support_otx_v1 import OTXv1Helper
from otx.core.data.entity.classification import (
    HlabelClsBatchDataEntity,
    HlabelClsBatchPredEntity,
    MulticlassClsBatchDataEntity,
    MulticlassClsBatchPredEntity,
    MultilabelClsBatchDataEntity,
    MultilabelClsBatchPredEntity,
)
from otx.core.metrics.accuracy import HLabelClsMetricCallable, MultiClassClsMetricCallable, MultiLabelClsMetricCallable
from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable
from otx.core.model.classification import OTXHlabelClsModel, OTXMulticlassClsModel, OTXMultilabelClsModel
from otx.core.schedulers import LRSchedulerListCallable
from otx.core.types.label import HLabelInfo, LabelInfoTypes
from otx.core.types.task import OTXTrainType

if TYPE_CHECKING:
    from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable

    from otx.core.metrics import MetricCallable
class EfficientNetForMulticlassCls(OTXMulticlassClsModel):
    """EfficientNet Model for multi-class classification task."""

    def __init__(
        self,
        label_info: LabelInfoTypes,
        version: EFFICIENTNET_VERSION = "b0",
        pretrained: bool = True,
        optimizer: OptimizerCallable = DefaultOptimizerCallable,
        scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
        metric: MetricCallable = MultiClassClsMetricCallable,
        torch_compile: bool = False,
        input_size: tuple[int, int] = (224, 224),
        train_type: Literal[OTXTrainType.SUPERVISED, OTXTrainType.SEMI_SUPERVISED] = OTXTrainType.SUPERVISED,
    ) -> None:
        self.version = version
        self.pretrained = pretrained

        super().__init__(
            label_info=label_info,
            input_size=input_size,
            optimizer=optimizer,
            scheduler=scheduler,
            metric=metric,
            torch_compile=torch_compile,
            train_type=train_type,
        )

    def _create_model(self) -> nn.Module:
        # Get classification_layers for class-incremental learning: build two throwaway
        # models with different class counts (5 and 6) and diff their state_dicts; the
        # layers whose parameters differ are the ones that depend on the number of classes.
        sample_model_dict = self._build_model(num_classes=5).state_dict()
        incremental_model_dict = self._build_model(num_classes=6).state_dict()
        self.classification_layers = get_classification_layers(
            sample_model_dict,
            incremental_model_dict,
            prefix="model.",
        )

        model = self._build_model(num_classes=self.num_classes)
        model.init_weights()
        return model

    def _build_model(self, num_classes: int) -> nn.Module:
        backbone = EfficientNetBackbone(version=self.version, input_size=self.input_size, pretrained=self.pretrained)
        neck = GlobalAveragePooling(dim=2)
        if self.train_type == OTXTrainType.SEMI_SUPERVISED:
            return SemiSLClassifier(
                backbone=backbone,
                neck=neck,
                head=SemiSLLinearClsHead(
                    num_classes=num_classes,
                    in_channels=backbone.num_features,
                ),
                loss=nn.CrossEntropyLoss(reduction="none"),
            )

        return ImageClassifier(
            backbone=backbone,
            neck=neck,
            head=LinearClsHead(
                num_classes=num_classes,
                in_channels=backbone.num_features,
            ),
            loss=nn.CrossEntropyLoss(),
        )

    def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.") -> dict:
        """Load a previous OTX v1 checkpoint and convert it to the OTX v2.0 format."""
        return OTXv1Helper.load_cls_effnet_b0_ckpt(state_dict, "multiclass", add_prefix)

    def forward_explain(self, inputs: MulticlassClsBatchDataEntity) -> MulticlassClsBatchPredEntity:
        """Model forward explain function."""
        outputs = self.model(images=inputs.stacked_images, mode="explain")

        return MulticlassClsBatchPredEntity(
            batch_size=len(outputs["preds"]),
            images=inputs.images,
            imgs_info=inputs.imgs_info,
            labels=outputs["preds"],
            scores=outputs["scores"],
            saliency_map=outputs["saliency_map"],
            feature_vector=outputs["feature_vector"],
        )

    def forward_for_tracing(self, image: Tensor) -> Tensor | dict[str, Tensor]:
        """Model forward function used for the model tracing during model exportation."""
        if self.explain_mode:
            return self.model(images=image, mode="explain")

        return self.model(images=image, mode="tensor")
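
# A minimal usage sketch for the class above (not part of the original module; the
# integer `label_info`, the `torch.randn` dummy input, and the top-level `torch`
# import are illustrative assumptions, not APIs defined in this file):
#
#     import torch
#
#     model = EfficientNetForMulticlassCls(label_info=10, version="b0", input_size=(224, 224))
#     logits = model.forward_for_tracing(torch.randn(1, 3, 224, 224))  # mode="tensor" path
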
class EfficientNetForMultilabelCls(OTXMultilabelClsModel):
    """EfficientNet Model for multi-label classification task."""

    def __init__(
        self,
        label_info: LabelInfoTypes,
        version: EFFICIENTNET_VERSION = "b0",
        pretrained: bool = True,
        optimizer: OptimizerCallable = DefaultOptimizerCallable,
        scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
        metric: MetricCallable = MultiLabelClsMetricCallable,
        torch_compile: bool = False,
        input_size: tuple[int, int] = (224, 224),
    ) -> None:
        self.version = version
        self.pretrained = pretrained

        super().__init__(
            label_info=label_info,
            optimizer=optimizer,
            scheduler=scheduler,
            metric=metric,
            torch_compile=torch_compile,
            input_size=input_size,
        )

    def _create_model(self) -> nn.Module:
        # Get classification_layers for class-incr learning
        sample_model_dict = self._build_model(num_classes=5).state_dict()
        incremental_model_dict = self._build_model(num_classes=6).state_dict()
        self.classification_layers = get_classification_layers(
            sample_model_dict,
            incremental_model_dict,
            prefix="model.",
        )

        model = self._build_model(num_classes=self.num_classes)
        model.init_weights()
        return model

    def _build_model(self, num_classes: int) -> nn.Module:
        backbone = EfficientNetBackbone(version=self.version, input_size=self.input_size, pretrained=self.pretrained)
        return ImageClassifier(
            backbone=backbone,
            neck=GlobalAveragePooling(dim=2),
            head=MultiLabelLinearClsHead(
                num_classes=num_classes,
                in_channels=backbone.num_features,
                normalized=True,
            ),
            loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"),
            loss_scale=7.0,
        )

    def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.") -> dict:
        """Load a previous OTX v1 checkpoint and convert it to the OTX v2.0 format."""
        return OTXv1Helper.load_cls_effnet_b0_ckpt(state_dict, "multilabel", add_prefix)

    def forward_explain(self, inputs: MultilabelClsBatchDataEntity) -> MultilabelClsBatchPredEntity:
        """Model forward explain function."""
        outputs = self.model(images=inputs.stacked_images, mode="explain")

        return MultilabelClsBatchPredEntity(
            batch_size=len(outputs["preds"]),
            images=inputs.images,
            imgs_info=inputs.imgs_info,
            labels=outputs["preds"],
            scores=outputs["scores"],
            saliency_map=outputs["saliency_map"],
            feature_vector=outputs["feature_vector"],
        )

    def forward_for_tracing(self, image: Tensor) -> Tensor | dict[str, Tensor]:
        """Model forward function used for the model tracing during model exportation."""
        if self.explain_mode:
            return self.model(images=image, mode="explain")

        return self.model(images=image, mode="tensor")
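
# As with the multi-class variant, a rough instantiation sketch (illustrative only;
# whether a plain list of label names is accepted as `label_info` is an assumption
# about `LabelInfoTypes`, not something defined in this file):
#
#     model = EfficientNetForMultilabelCls(label_info=["person", "car", "bicycle"], version="b1")
#     # _build_model() wires a normalized multi-label linear head to
#     # AsymmetricAngularLossWithIgnore with loss_scale=7.0.
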
class EfficientNetForHLabelCls(OTXHlabelClsModel):
    """EfficientNet Model for hierarchical label classification task."""

    label_info: HLabelInfo

    def __init__(
        self,
        label_info: HLabelInfo,
        version: EFFICIENTNET_VERSION = "b0",
        pretrained: bool = True,
        optimizer: OptimizerCallable = DefaultOptimizerCallable,
        scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
        metric: MetricCallable = HLabelClsMetricCallable,
        torch_compile: bool = False,
        input_size: tuple[int, int] = (224, 224),
    ) -> None:
        self.version = version
        self.pretrained = pretrained

        super().__init__(
            label_info=label_info,
            optimizer=optimizer,
            scheduler=scheduler,
            metric=metric,
            torch_compile=torch_compile,
            input_size=input_size,
        )

    def _create_model(self) -> nn.Module:
        # Get classification_layers for class-incr learning
        sample_config = deepcopy(self.label_info.as_head_config_dict())
        sample_config["num_classes"] = 5
        sample_model_dict = self._build_model(head_config=sample_config).state_dict()
        sample_config["num_classes"] = 6
        incremental_model_dict = self._build_model(head_config=sample_config).state_dict()
        self.classification_layers = get_classification_layers(
            sample_model_dict,
            incremental_model_dict,
            prefix="model.",
        )

        model = self._build_model(head_config=self.label_info.as_head_config_dict())
        model.init_weights()
        return model

    def _build_model(self, head_config: dict) -> nn.Module:
        if not isinstance(self.label_info, HLabelInfo):
            raise TypeError(self.label_info)

        backbone = EfficientNetBackbone(version=self.version, input_size=self.input_size, pretrained=self.pretrained)
        copied_head_config = copy(head_config)
        # The backbone downsamples by an overall factor of 32, so the CBAM head's spatial
        # step size is the input resolution divided by 32 (rounded up).
        copied_head_config["step_size"] = (ceil(self.input_size[0] / 32), ceil(self.input_size[1] / 32))
        return HLabelClassifier(
            backbone=backbone,
            neck=nn.Identity(),
            head=HierarchicalCBAMClsHead(
                in_channels=backbone.num_features,
                **copied_head_config,
            ),
            multiclass_loss=nn.CrossEntropyLoss(),
            multilabel_loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"),
        )

    def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.") -> dict:
        """Load a previous OTX v1 checkpoint and convert it to the OTX v2.0 format."""
        return OTXv1Helper.load_cls_effnet_b0_ckpt(state_dict, "hlabel", add_prefix)

    def forward_explain(self, inputs: HlabelClsBatchDataEntity) -> HlabelClsBatchPredEntity:
        """Model forward explain function."""
        outputs = self.model(images=inputs.stacked_images, mode="explain")

        return HlabelClsBatchPredEntity(
            batch_size=len(outputs["preds"]),
            images=inputs.images,
            imgs_info=inputs.imgs_info,
            labels=outputs["preds"],
            scores=outputs["scores"],
            saliency_map=outputs["saliency_map"],
            feature_vector=outputs["feature_vector"],
        )

    def forward_for_tracing(self, image: Tensor) -> Tensor | dict[str, Tensor]:
        """Model forward function used for the model tracing during model exportation."""
        if self.explain_mode:
            return self.model(images=image, mode="explain")

        return self.model(images=image, mode="tensor")
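
# A rough sketch of how the tracing entrypoints above behave when the `explain_mode`
# flag checked in `forward_for_tracing` is set (illustrative only; `torch` and the
# dummy input are assumptions, and the dict contents are assumed to mirror the
# outputs consumed by `forward_explain`):
#
#     import torch
#
#     model = EfficientNetForMulticlassCls(label_info=3)
#     model.explain_mode = True
#     out = model.forward_for_tracing(torch.randn(1, 3, 224, 224))
#     # `out` is a dict (mode="explain") rather than a plain tensor (mode="tensor"),
#     # carrying predictions alongside saliency maps and feature vectors.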