Source code for otx.algo.classification.multiclass_models.efficientnet
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
"""EfficientNet-B0 model implementation."""
from __future__ import annotations
from typing import TYPE_CHECKING
from torch import Tensor, nn
from otx.algo.classification.backbones.efficientnet import EfficientNetBackbone
from otx.algo.classification.classifier import ImageClassifier
from otx.algo.classification.heads import LinearClsHead
from otx.algo.classification.necks.gap import GlobalAveragePooling
from otx.algo.utils.support_otx_v1 import OTXv1Helper
from otx.core.metrics.accuracy import MultiClassClsMetricCallable
from otx.core.model.base import DataInputParams, DefaultOptimizerCallable, DefaultSchedulerCallable
from otx.core.model.multiclass_classification import OTXMulticlassClsModel
from otx.core.schedulers import LRSchedulerListCallable
from otx.core.types.label import LabelInfoTypes
if TYPE_CHECKING:
from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable
from otx.core.metrics import MetricCallable
[docs]
class EfficientNetMulticlassCls(OTXMulticlassClsModel):
    """EfficientNet Model for multi-class classification task.

    Args:
        label_info (LabelInfoTypes): Information about the labels.
        data_input_params (DataInputParams): Parameters for data input.
        model_name (str, optional): Name of the EfficientNet model variant.
            Defaults to "efficientnet_b0".
        freeze_backbone (bool, optional): Flag forwarded unchanged to the
            ``OTXMulticlassClsModel`` constructor, which owns its handling
            (by its name, presumably freezes the backbone weights; confirm in
            the base class). Defaults to False.
        optimizer (OptimizerCallable, optional): Callable for the optimizer.
            Defaults to DefaultOptimizerCallable.
        scheduler (LRSchedulerCallable | LRSchedulerListCallable, optional): Callable for the learning rate scheduler.
            Defaults to DefaultSchedulerCallable.
        metric (MetricCallable, optional): Callable for the evaluation metric.
            Defaults to MultiClassClsMetricCallable.
        torch_compile (bool, optional): Flag to indicate whether to use torch.compile. Defaults to False.
    """

    def __init__(
        self,
        label_info: LabelInfoTypes,
        data_input_params: DataInputParams,
        model_name: str = "efficientnet_b0",
        freeze_backbone: bool = False,
        optimizer: OptimizerCallable = DefaultOptimizerCallable,
        scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
        metric: MetricCallable = MultiClassClsMetricCallable,
        torch_compile: bool = False,
    ) -> None:
        # All configuration is stored and handled by the base class.
        super().__init__(
            label_info=label_info,
            data_input_params=data_input_params,
            model_name=model_name,
            freeze_backbone=freeze_backbone,
            optimizer=optimizer,
            scheduler=scheduler,
            metric=metric,
            torch_compile=torch_compile,
        )

    def _create_model(self, num_classes: int | None = None) -> nn.Module:
        """Build the classifier: EfficientNet backbone, GAP neck and linear head.

        Args:
            num_classes (int | None, optional): Number of outputs of the linear
                head. When None, ``self.num_classes`` is used.

        Returns:
            nn.Module: An ``ImageClassifier`` trained with
            ``nn.CrossEntropyLoss``, on which ``init_weights()`` has already
            been called.
        """
        num_classes = num_classes if num_classes is not None else self.num_classes
        backbone = EfficientNetBackbone(model_name=self.model_name, input_size=self.data_input_params.input_size)
        neck = GlobalAveragePooling(dim=2)
        model = ImageClassifier(
            backbone=backbone,
            neck=neck,
            head=LinearClsHead(
                num_classes=num_classes,
                # Size the head's input from the backbone so the two stay in
                # sync for whichever EfficientNet variant was selected.
                in_channels=backbone.num_features,
            ),
            loss=nn.CrossEntropyLoss(),
        )
        model.init_weights()
        return model

    def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.") -> dict:
        """Load the previous OTX ckpt according to OTX2.0.

        Delegates to ``OTXv1Helper.load_cls_effnet_b0_ckpt`` with the
        ``"multiclass"`` task label.

        Args:
            state_dict (dict): State dict taken from an OTX v1 checkpoint.
            add_prefix (str, optional): Prefix handed to the converter for the
                resulting keys. Defaults to "model.".

        Returns:
            dict: The state dict returned by the converter.
        """
        # NOTE(review): the B0-specific converter is used regardless of
        # ``self.model_name``; confirm other variants never need v1 loading.
        return OTXv1Helper.load_cls_effnet_b0_ckpt(state_dict, "multiclass", add_prefix)

    def forward_for_tracing(self, image: Tensor) -> Tensor | dict[str, Tensor]:
        """Model forward function used for the model tracing during model exportation.

        Args:
            image (Tensor): Input image batch passed to the wrapped model.

        Returns:
            Tensor | dict[str, Tensor]: Output of the wrapped model, called
            with ``mode="explain"`` when ``self.explain_mode`` is truthy and
            with ``mode="tensor"`` otherwise.
        """
        if self.explain_mode:
            return self.model(images=image, mode="explain")
        return self.model(images=image, mode="tensor")