Source code for otx.algo.classification.multilabel_models.torchvision_model
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
"""Torchvision model for the OTX classification."""
from __future__ import annotations
from typing import TYPE_CHECKING
from otx.algo.classification.backbones.torchvision import TorchvisionBackbone
from otx.algo.classification.classifier import ImageClassifier
from otx.algo.classification.heads import (
MultiLabelLinearClsHead,
)
from otx.algo.classification.losses import AsymmetricAngularLossWithIgnore
from otx.algo.classification.necks.gap import GlobalAveragePooling
from otx.core.metrics.accuracy import MultiLabelClsMetricCallable
from otx.core.model.base import DataInputParams, DefaultOptimizerCallable, DefaultSchedulerCallable
from otx.core.model.multilabel_classification import (
OTXMultilabelClsModel,
)
from otx.core.schedulers import LRSchedulerListCallable
from otx.core.types.label import LabelInfoTypes
if TYPE_CHECKING:
import torch
from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable
from torch import nn
from otx.core.metrics import MetricCallable
[docs]
class TVModelMultilabelCls(OTXMultilabelClsModel):
    """Torchvision model for multilabel classification.

    Args:
        label_info (LabelInfoTypes): Information about the labels.
        data_input_params (DataInputParams): Parameters describing the model input.
            Forwarded unchanged to the base class.
        model_name (str, optional): Name of the torchvision backbone used for feature extraction;
            it is handed to ``TorchvisionBackbone`` when the model is created.
            Defaults to "efficientnet_v2_s".
        optimizer (OptimizerCallable, optional): Optimizer for model training. Defaults to DefaultOptimizerCallable.
        scheduler (LRSchedulerCallable | LRSchedulerListCallable, optional): Learning rate scheduler.
            Defaults to DefaultSchedulerCallable.
        metric (MetricCallable, optional): Metric for model evaluation. Defaults to MultiLabelClsMetricCallable.
        torch_compile (bool, optional): Whether to compile the model. Forwarded to the base class.
            Defaults to False.
    """

    def __init__(
        self,
        label_info: LabelInfoTypes,
        data_input_params: DataInputParams,
        model_name: str = "efficientnet_v2_s",
        optimizer: OptimizerCallable = DefaultOptimizerCallable,
        scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
        metric: MetricCallable = MultiLabelClsMetricCallable,
        torch_compile: bool = False,
    ) -> None:
        # Every argument is passed through as-is; this override only pins the
        # multilabel-specific defaults (backbone name and metric).
        super().__init__(
            label_info=label_info,
            data_input_params=data_input_params,
            model_name=model_name,
            optimizer=optimizer,
            scheduler=scheduler,
            metric=metric,
            torch_compile=torch_compile,
        )

    def _create_model(self, num_classes: int | None = None) -> nn.Module:
        """Build the multilabel image classifier.

        Args:
            num_classes (int | None, optional): Number of output classes for the head.
                When ``None``, ``self.num_classes`` is used. Defaults to None.

        Returns:
            nn.Module: An ``ImageClassifier`` made of a torchvision backbone, a global average
            pooling neck, a normalized multilabel linear head and an asymmetric angular loss.
        """
        num_classes = num_classes if num_classes is not None else self.num_classes
        # self.model_name is expected to be populated by the base class from the constructor argument.
        backbone = TorchvisionBackbone(backbone=self.model_name)
        return ImageClassifier(
            backbone=backbone,
            neck=GlobalAveragePooling(dim=2),
            head=MultiLabelLinearClsHead(
                num_classes=num_classes,
                # The head input width follows the feature width reported by the backbone.
                in_channels=backbone.in_features,
                normalized=True,
            ),
            loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"),
            loss_scale=7.0,
        )

    def forward_for_tracing(self, image: torch.Tensor) -> torch.Tensor | dict[str, torch.Tensor]:
        """Model forward function used for the model tracing during model exportation.

        Args:
            image (torch.Tensor): Input image tensor passed to the underlying model as ``images``.

        Returns:
            torch.Tensor | dict[str, torch.Tensor]: Output of the underlying model, called with
            ``mode="explain"`` when ``self.explain_mode`` is truthy and ``mode="tensor"`` otherwise.
        """
        if self.explain_mode:
            return self.model(images=image, mode="explain")

        return self.model(images=image, mode="tensor")