Source code for otx.algo.classification.timm_model
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
"""TIMM wrapper model class for OTX."""
from __future__ import annotations
from copy import copy, deepcopy
from math import ceil
from typing import TYPE_CHECKING, Literal
import torch
from torch import nn
from otx.algo.classification.backbones.timm import TimmBackbone
from otx.algo.classification.classifier import HLabelClassifier, ImageClassifier, SemiSLClassifier
from otx.algo.classification.heads import (
HierarchicalCBAMClsHead,
LinearClsHead,
MultiLabelLinearClsHead,
SemiSLLinearClsHead,
)
from otx.algo.classification.losses.asymmetric_angular_loss_with_ignore import AsymmetricAngularLossWithIgnore
from otx.algo.classification.necks.gap import GlobalAveragePooling
from otx.algo.classification.utils import get_classification_layers
from otx.algo.utils.support_otx_v1 import OTXv1Helper
from otx.core.data.entity.classification import (
HlabelClsBatchDataEntity,
HlabelClsBatchPredEntity,
MulticlassClsBatchDataEntity,
MulticlassClsBatchPredEntity,
MultilabelClsBatchDataEntity,
MultilabelClsBatchPredEntity,
)
from otx.core.metrics.accuracy import HLabelClsMetricCallable, MultiClassClsMetricCallable, MultiLabelClsMetricCallable
from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable
from otx.core.model.classification import (
OTXHlabelClsModel,
OTXMulticlassClsModel,
OTXMultilabelClsModel,
)
from otx.core.schedulers import LRSchedulerListCallable
from otx.core.types.label import HLabelInfo, LabelInfoTypes
from otx.core.types.task import OTXTrainType

if TYPE_CHECKING:
from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable
from otx.core.metrics import MetricCallable


class TimmModelForMulticlassCls(OTXMulticlassClsModel):
    """TimmModel for the multi-class classification task.

    Args:
        label_info (LabelInfoTypes): The label information for the classification task.
        model_name (str): The name of the model.
            You can find available models at timm.list_models() or timm.list_pretrained().
        input_size (tuple[int, int], optional): Model input size in the order of height and width.
            Defaults to (224, 224).
        pretrained (bool, optional): Whether to load pretrained weights. Defaults to True.
        optimizer (OptimizerCallable, optional): The optimizer callable for training the model.
        scheduler (LRSchedulerCallable | LRSchedulerListCallable, optional): The learning rate scheduler callable.
        metric (MetricCallable, optional): The metric callable for evaluating the model.
            Defaults to MultiClassClsMetricCallable.
        torch_compile (bool, optional): Whether to compile the model using torch.compile. Defaults to False.
        train_type (Literal[OTXTrainType.SUPERVISED, OTXTrainType.SEMI_SUPERVISED], optional): The training type.
            Defaults to OTXTrainType.SUPERVISED.

    Example:
        1. API
            >>> model = TimmModelForMulticlassCls(
            ...     model_name="tf_efficientnetv2_s.in21k",
            ...     label_info=<Number-of-classes>,
            ... )
        2. CLI
            >>> otx train \
            ... --model otx.algo.classification.timm_model.TimmModelForMulticlassCls \
            ... --model.model_name tf_efficientnetv2_s.in21k
    """

def __init__(
self,
label_info: LabelInfoTypes,
model_name: str,
input_size: tuple[int, int] = (224, 224), # input size of default classification data recipe
pretrained: bool = True,
optimizer: OptimizerCallable = DefaultOptimizerCallable,
scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
metric: MetricCallable = MultiClassClsMetricCallable,
torch_compile: bool = False,
train_type: Literal[OTXTrainType.SUPERVISED, OTXTrainType.SEMI_SUPERVISED] = OTXTrainType.SUPERVISED,
) -> None:
self.model_name = model_name
self.pretrained = pretrained
super().__init__(
label_info=label_info,
input_size=input_size,
optimizer=optimizer,
scheduler=scheduler,
metric=metric,
torch_compile=torch_compile,
train_type=train_type,
)

    def _create_model(self) -> nn.Module:
# Get classification_layers for class-incr learning
sample_model_dict = self._build_model(num_classes=5).state_dict()
incremental_model_dict = self._build_model(num_classes=6).state_dict()
self.classification_layers = get_classification_layers(
sample_model_dict,
incremental_model_dict,
prefix="model.",
)
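        # Any parameter whose shape changed between the two dummy models above
        # (e.g. a final fc weight going from (5, C) to (6, C)) is recorded as a
        # classification layer, so it can be re-initialized or partially copied
        # when a checkpoint with a different number of classes is loaded.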
model = self._build_model(num_classes=self.num_classes)
model.init_weights()
return model

    def _build_model(self, num_classes: int) -> nn.Module:
backbone = TimmBackbone(model_name=self.model_name, pretrained=self.pretrained)
neck = GlobalAveragePooling(dim=2)
if self.train_type == OTXTrainType.SEMI_SUPERVISED:
return SemiSLClassifier(
backbone=backbone,
neck=neck,
head=SemiSLLinearClsHead(
num_classes=num_classes,
in_channels=backbone.num_features,
),
                loss=nn.CrossEntropyLoss(reduction="none"),  # keep per-sample losses so the semi-supervised routine can weight them
)
return ImageClassifier(
backbone=backbone,
neck=neck,
head=LinearClsHead(
num_classes=num_classes,
in_channels=backbone.num_features,
),
loss=nn.CrossEntropyLoss(),
)

    def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.") -> dict:
"""Load the previous OTX ckpt according to OTX2.0."""
return OTXv1Helper.load_cls_effnet_v2_ckpt(state_dict, "multiclass", add_prefix)

    def forward_explain(self, inputs: MulticlassClsBatchDataEntity) -> MulticlassClsBatchPredEntity:
"""Model forward explain function."""
outputs = self.model(images=inputs.stacked_images, mode="explain")
return MulticlassClsBatchPredEntity(
batch_size=len(outputs["preds"]),
images=inputs.images,
imgs_info=inputs.imgs_info,
labels=outputs["preds"],
scores=outputs["scores"],
saliency_map=outputs["saliency_map"],
feature_vector=outputs["feature_vector"],
)

    def forward_for_tracing(self, image: torch.Tensor) -> torch.Tensor | dict[str, torch.Tensor]:
"""Model forward function used for the model tracing during model exportation."""
if self.explain_mode:
return self.model(images=image, mode="explain")
return self.model(images=image, mode="tensor")


class TimmModelForMultilabelCls(OTXMultilabelClsModel):
    """TimmModel for the multi-label classification task.

    Args:
        label_info (LabelInfoTypes): The label information for the classification task.
        model_name (str): The name of the model.
            You can find available models at timm.list_models() or timm.list_pretrained().
        input_size (tuple[int, int], optional): Model input size in the order of height and width.
            Defaults to (224, 224).
        pretrained (bool, optional): Whether to load pretrained weights. Defaults to True.
        optimizer (OptimizerCallable, optional): The optimizer callable for training the model.
        scheduler (LRSchedulerCallable | LRSchedulerListCallable, optional): The learning rate scheduler callable.
        metric (MetricCallable, optional): The metric callable for evaluating the model.
            Defaults to MultiLabelClsMetricCallable.
        torch_compile (bool, optional): Whether to compile the model using torch.compile. Defaults to False.

    Example:
        1. API
            >>> model = TimmModelForMultilabelCls(
            ...     model_name="tf_efficientnetv2_s.in21k",
            ...     label_info=<Number-of-classes>,
            ... )
        2. CLI
            >>> otx train \
            ... --model otx.algo.classification.timm_model.TimmModelForMultilabelCls \
            ... --model.model_name tf_efficientnetv2_s.in21k
    """

def __init__(
self,
label_info: LabelInfoTypes,
model_name: str,
input_size: tuple[int, int] = (224, 224), # input size of default classification data recipe
pretrained: bool = True,
optimizer: OptimizerCallable = DefaultOptimizerCallable,
scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
metric: MetricCallable = MultiLabelClsMetricCallable,
torch_compile: bool = False,
) -> None:
self.model_name = model_name
self.pretrained = pretrained
super().__init__(
label_info=label_info,
input_size=input_size,
optimizer=optimizer,
scheduler=scheduler,
metric=metric,
torch_compile=torch_compile,
)

    def _create_model(self) -> nn.Module:
# Get classification_layers for class-incr learning
sample_model_dict = self._build_model(num_classes=5).state_dict()
incremental_model_dict = self._build_model(num_classes=6).state_dict()
self.classification_layers = get_classification_layers(
sample_model_dict,
incremental_model_dict,
prefix="model.",
)
model = self._build_model(num_classes=self.num_classes)
model.init_weights()
return model

    def _build_model(self, num_classes: int) -> nn.Module:
backbone = TimmBackbone(model_name=self.model_name, pretrained=self.pretrained)
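        # The asymmetric loss down-weights easy negatives, which dominate in
        # multi-label data where most classes are absent from any given image.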
return ImageClassifier(
backbone=backbone,
neck=GlobalAveragePooling(dim=2),
head=MultiLabelLinearClsHead(
num_classes=num_classes,
in_channels=backbone.num_features,
normalized=True,
),
loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"),
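            # scale applied to the normalized head's cosine logits before the loss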
loss_scale=7.0,
)

    def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.") -> dict:
"""Load the previous OTX ckpt according to OTX2.0."""
return OTXv1Helper.load_cls_effnet_v2_ckpt(state_dict, "multilabel", add_prefix)

    def forward_explain(self, inputs: MultilabelClsBatchDataEntity) -> MultilabelClsBatchPredEntity:
"""Model forward explain function."""
outputs = self.model(images=inputs.stacked_images, mode="explain")
return MultilabelClsBatchPredEntity(
batch_size=len(outputs["preds"]),
images=inputs.images,
imgs_info=inputs.imgs_info,
labels=outputs["preds"],
scores=outputs["scores"],
saliency_map=outputs["saliency_map"],
feature_vector=outputs["feature_vector"],
)

    def forward_for_tracing(self, image: torch.Tensor) -> torch.Tensor | dict[str, torch.Tensor]:
"""Model forward function used for the model tracing during model exportation."""
if self.explain_mode:
return self.model(images=image, mode="explain")
return self.model(images=image, mode="tensor")


class TimmModelForHLabelCls(OTXHlabelClsModel):
    """Timm model for the hierarchical-label classification task.

    Args:
        label_info (HLabelInfo): The label information for the classification task.
        model_name (str): The name of the model.
            You can find available models at timm.list_models() or timm.list_pretrained().
        input_size (tuple[int, int], optional): Model input size in the order of height and width.
            Defaults to (224, 224).
        pretrained (bool, optional): Whether to load pretrained weights. Defaults to True.
        optimizer (OptimizerCallable, optional): The optimizer callable for training the model.
        scheduler (LRSchedulerCallable | LRSchedulerListCallable, optional): The learning rate scheduler callable.
        metric (MetricCallable, optional): The metric callable for evaluating the model.
            Defaults to HLabelClsMetricCallable.
        torch_compile (bool, optional): Whether to compile the model using torch.compile. Defaults to False.

    Example:
        1. API
            >>> model = TimmModelForHLabelCls(
            ...     model_name="tf_efficientnetv2_s.in21k",
            ...     label_info=<h-label-info>,
            ... )
        2. CLI
            >>> otx train \
            ... --model otx.algo.classification.timm_model.TimmModelForHLabelCls \
            ... --model.model_name tf_efficientnetv2_s.in21k
    """

label_info: HLabelInfo

    def __init__(
self,
label_info: HLabelInfo,
model_name: str,
input_size: tuple[int, int] = (224, 224), # input size of default classification data recipe
pretrained: bool = True,
optimizer: OptimizerCallable = DefaultOptimizerCallable,
scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
metric: MetricCallable = HLabelClsMetricCallable,
torch_compile: bool = False,
) -> None:
self.model_name = model_name
self.pretrained = pretrained
super().__init__(
label_info=label_info,
input_size=input_size,
optimizer=optimizer,
scheduler=scheduler,
metric=metric,
torch_compile=torch_compile,
)

    def _create_model(self) -> nn.Module:
# Get classification_layers for class-incr learning
sample_config = deepcopy(self.label_info.as_head_config_dict())
sample_config["num_classes"] = 5
sample_model_dict = self._build_model(head_config=sample_config).state_dict()
sample_config["num_classes"] = 6
incremental_model_dict = self._build_model(head_config=sample_config).state_dict()
self.classification_layers = get_classification_layers(
sample_model_dict,
incremental_model_dict,
prefix="model.",
)
model = self._build_model(head_config=self.label_info.as_head_config_dict())
model.init_weights()
return model

    def _build_model(self, head_config: dict) -> nn.Module:
backbone = TimmBackbone(model_name=self.model_name, pretrained=self.pretrained)
copied_head_config = copy(head_config)
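        # The CBAM head needs the spatial size of the backbone feature map;
        # the timm backbones used here downsample the input by a factor of 32.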
copied_head_config["step_size"] = (ceil(self.input_size[0] / 32), ceil(self.input_size[1] / 32))
return HLabelClassifier(
backbone=backbone,
neck=nn.Identity(),
head=HierarchicalCBAMClsHead(
in_channels=backbone.num_features,
**copied_head_config,
),
multiclass_loss=nn.CrossEntropyLoss(),
multilabel_loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"),
)

    def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.") -> dict:
"""Load the previous OTX ckpt according to OTX2.0."""
return OTXv1Helper.load_cls_effnet_v2_ckpt(state_dict, "hlabel", add_prefix)

    def forward_explain(self, inputs: HlabelClsBatchDataEntity) -> HlabelClsBatchPredEntity:
"""Model forward explain function."""
outputs = self.model(images=inputs.stacked_images, mode="explain")
return HlabelClsBatchPredEntity(
batch_size=len(outputs["preds"]),
images=inputs.images,
imgs_info=inputs.imgs_info,
labels=outputs["preds"],
scores=outputs["scores"],
saliency_map=outputs["saliency_map"],
feature_vector=outputs["feature_vector"],
)

    def forward_for_tracing(self, image: torch.Tensor) -> torch.Tensor | dict[str, torch.Tensor]:
"""Model forward function used for the model tracing during model exportation."""
if self.explain_mode:
return self.model(images=image, mode="explain")
return self.model(images=image, mode="tensor")