Source code for otx.algo.classification.vit
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
"""ViT model implementation."""
from __future__ import annotations
import types
from copy import deepcopy
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Generic, Literal
from urllib.parse import urlparse
import numpy as np
import torch
from torch import nn
from torch.hub import download_url_to_file
from otx.algo.classification.backbones.vision_transformer import VIT_ARCH_TYPE, VisionTransformer
from otx.algo.classification.classifier import HLabelClassifier, ImageClassifier, SemiSLClassifier
from otx.algo.classification.heads import (
HierarchicalCBAMClsHead,
MultiLabelLinearClsHead,
SemiSLVisionTransformerClsHead,
VisionTransformerClsHead,
)
from otx.algo.classification.losses import AsymmetricAngularLossWithIgnore
from otx.algo.classification.utils import get_classification_layers
from otx.algo.explain.explain_algo import ViTReciproCAM, feature_vector_fn
from otx.algo.utils.support_otx_v1 import OTXv1Helper
from otx.core.data.entity.base import T_OTXBatchDataEntity, T_OTXBatchPredEntity
from otx.core.metrics.accuracy import HLabelClsMetricCallable, MultiClassClsMetricCallable, MultiLabelClsMetricCallable
from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable
from otx.core.model.classification import (
OTXHlabelClsModel,
OTXMulticlassClsModel,
OTXMultilabelClsModel,
)
from otx.core.schedulers import LRSchedulerListCallable
from otx.core.types.label import HLabelInfo, LabelInfoTypes
from otx.core.types.task import OTXTrainType
if TYPE_CHECKING:
from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable
from otx.core.metrics import MetricCallable
augreg_url = "https://storage.googleapis.com/vit_models/augreg/"
dinov2_url = "https://dl.fbaipublicfiles.com/dinov2/"
pretrained_urls = {
"vit-tiny": augreg_url
+ "Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz",
"vit-small": augreg_url
+ "S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz",
"vit-base": augreg_url
+ "B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz",
"vit-large": augreg_url
+ "L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz",
"dinov2-small": dinov2_url + "dinov2_vits14/dinov2_vits14_reg4_pretrain.pth",
"dinov2-base": dinov2_url + "dinov2_vitb14/dinov2_vitb14_reg4_pretrain.pth",
"dinov2-large": dinov2_url + "dinov2_vitl14/dinov2_vitl14_reg4_pretrain.pth",
"dinov2-giant": dinov2_url + "dinov2_vitg14/dinov2_vitg14_reg4_pretrain.pth",
}
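
# Illustrative usage note (not part of the upstream module): the keys of
# ``pretrained_urls`` are the accepted values for the ``arch`` argument of the
# model classes defined below. A minimal sketch, assuming an integer class count
# is a valid ``LabelInfoTypes`` value:
#
#     model = VisionTransformerForMulticlassCls(label_info=10, arch="vit-tiny")
#
# With ``pretrained=True`` (the default), the matching checkpoint is downloaded
# once into ``~/.cache/torch/hub/checkpoints`` and reused on later runs.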
class ForwardExplainMixInForViT(Generic[T_OTXBatchPredEntity, T_OTXBatchDataEntity]):
"""ViT model which can attach a XAI (Explainable AI) branch."""
explain_mode: bool
num_classes: int
model: ImageClassifier
@torch.no_grad()
def head_forward_fn(self, x: torch.Tensor) -> torch.Tensor:
"""Performs model's neck and head forward."""
if not hasattr(self.model.backbone, "blocks"):
raise ValueError
# Part of the last transformer_encoder block (except first LayerNorm)
target_layer = self.model.backbone.blocks[-1]
x = x + target_layer.attn(x)
x = target_layer.mlp(target_layer.norm2(x))
# Final LayerNorm and neck
x = self.model.backbone.norm(x)
if self.model.neck is not None:
x = self.model.neck(x)
# Head
cls_token = x[:, 0]
layer_output = [None, cls_token]
logit = self.model.head.forward(layer_output)
if isinstance(logit, list):
logit = torch.from_numpy(np.array(logit))
return logit
@staticmethod
def _forward_explain_image_classifier(
self: ImageClassifier,
images: torch.Tensor,
mode: str = "tensor",
**kwargs, # noqa: ARG004
) -> dict[str, torch.Tensor]:
"""Forward func of the ImageClassifier instance, which located in is in OTXModel().model."""
backbone = self.backbone
feat = backbone.forward(images, out_type="raw")[-1]
x = (feat[:, 0],)
saliency_map = self.explain_fn(feat)
if self.neck is not None:
x = self.neck(x)
feature_vector = x[-1]
if mode in ("tensor", "explain"):
logits = self.head(x)
elif mode == "predict":
logits = self.head.predict(x)
else:
msg = f'Invalid mode "{mode}".'
raise RuntimeError(msg)
# H-Label Classification Case
pred_results = self.head._get_predictions(logits) # noqa: SLF001
if isinstance(pred_results, dict):
scores = pred_results["scores"]
labels = pred_results["labels"]
else:
scores = pred_results.unbind(0)
labels = logits.argmax(-1, keepdim=True).unbind(0)
outputs = {
"logits": logits,
"feature_vector": feature_vector,
"saliency_map": saliency_map,
}
if not torch.jit.is_tracing():
outputs["scores"] = scores
outputs["labels"] = labels
return outputs
def get_explain_fn(self) -> Callable:
"""Returns explain function."""
explainer = ViTReciproCAM(
self.head_forward_fn,
num_classes=self.num_classes,
)
return explainer.func
@property
def _optimization_config(self) -> dict[str, Any]:
"""PTQ config for DeitTinyForMultilabelCls."""
return {"model_type": "transformer"}
@property
def has_gap(self) -> bool:
"""Defines if GAP is used right after backbone.
Note:
Can be redefined at the model's level.
"""
return True
def _register(self) -> None:
if getattr(self, "_registered", False):
return
self.model.feature_vector_fn = feature_vector_fn
self.model.explain_fn = self.get_explain_fn()
self._registered = True
def forward_explain(
self,
inputs: T_OTXBatchDataEntity,
) -> T_OTXBatchPredEntity:
"""Model forward function."""
self._register()
orig_model_forward = self.model.forward
try:
self.model.forward = types.MethodType(self._forward_explain_image_classifier, self.model) # type: ignore[method-assign, assignment]
forward_func: Callable[[T_OTXBatchDataEntity], T_OTXBatchPredEntity] | None = getattr(self, "forward", None)
if forward_func is None:
msg = (
"This instance has no forward function. "
"Did you attach this mixin into a class derived from OTXModel?"
)
raise RuntimeError(msg)
return forward_func(inputs)
finally:
self.model.forward = orig_model_forward # type: ignore[method-assign, assignment]
def forward_for_tracing(self, image: torch.Tensor) -> torch.Tensor | dict[str, torch.Tensor]:
"""Model forward function used for the model tracing during model exportation."""
if self.explain_mode:
self._register()
forward_explain = types.MethodType(self._forward_explain_image_classifier, self.model)
return forward_explain(images=image, mode="tensor")
return self.model(images=image, mode="tensor")
class VisionTransformerForMulticlassCls(ForwardExplainMixInForViT, OTXMulticlassClsModel):
"""DeitTiny Model for multi-class classification task."""
model: ImageClassifier
def __init__(
self,
label_info: LabelInfoTypes,
arch: VIT_ARCH_TYPE = "vit-tiny",
lora: bool = False,
pretrained: bool = True,
optimizer: OptimizerCallable = DefaultOptimizerCallable,
scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
metric: MetricCallable = MultiClassClsMetricCallable,
torch_compile: bool = False,
input_size: tuple[int, int] = (224, 224),
train_type: Literal[OTXTrainType.SUPERVISED, OTXTrainType.SEMI_SUPERVISED] = OTXTrainType.SUPERVISED,
) -> None:
self.arch = arch
self.lora = lora
self.pretrained = pretrained
super().__init__(
label_info=label_info,
optimizer=optimizer,
scheduler=scheduler,
metric=metric,
torch_compile=torch_compile,
input_size=input_size,
train_type=train_type,
)
def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.") -> dict:
"""Load the previous OTX ckpt according to OTX2.0."""
for key in list(state_dict.keys()):
new_key = key.replace("patch_embed.projection", "patch_embed.proj")
new_key = new_key.replace("backbone.ln1", "backbone.norm")
new_key = new_key.replace("ffn.layers.0.0", "mlp.fc1")
new_key = new_key.replace("ffn.layers.1", "mlp.fc2")
new_key = new_key.replace("layers", "blocks")
new_key = new_key.replace("ln", "norm")
if new_key != key:
state_dict[new_key] = state_dict.pop(key)
return OTXv1Helper.load_cls_effnet_b0_ckpt(state_dict, "multiclass", add_prefix)
def _create_model(self) -> nn.Module:
# Get classification_layers for class-incr learning
sample_model_dict = self._build_model(num_classes=5).state_dict()
incremental_model_dict = self._build_model(num_classes=6).state_dict()
self.classification_layers = get_classification_layers(
sample_model_dict,
incremental_model_dict,
prefix="model.",
)
model = self._build_model(num_classes=self.num_classes)
model.init_weights()
if self.pretrained and self.arch in pretrained_urls:
print(f"init weight - {pretrained_urls[self.arch]}")
parts = urlparse(pretrained_urls[self.arch])
filename = Path(parts.path).name
cache_dir = Path.home() / ".cache" / "torch" / "hub" / "checkpoints"
cache_file = cache_dir / filename
            if not cache_file.exists():
download_url_to_file(pretrained_urls[self.arch], cache_file, "", progress=True)
model.backbone.load_pretrained(checkpoint_path=cache_file)
return model
def _build_model(self, num_classes: int) -> nn.Module:
init_cfg = [
{"std": 0.2, "layer": "Linear", "type": "TruncNormal"},
{"bias": 0.0, "val": 1.0, "layer": "LayerNorm", "type": "Constant"},
]
vit_backbone = VisionTransformer(arch=self.arch, img_size=self.input_size, lora=self.lora)
if self.train_type == OTXTrainType.SEMI_SUPERVISED:
return SemiSLClassifier(
backbone=vit_backbone,
neck=None,
head=SemiSLVisionTransformerClsHead(
num_classes=num_classes,
in_channels=vit_backbone.embed_dim,
),
loss=nn.CrossEntropyLoss(reduction="none"),
)
return ImageClassifier(
backbone=vit_backbone,
neck=None,
head=VisionTransformerClsHead(
num_classes=num_classes,
in_channels=vit_backbone.embed_dim,
),
loss=nn.CrossEntropyLoss(),
init_cfg=init_cfg,
)
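
# Minimal sketch of selecting the semi-supervised variant built above (the integer
# ``label_info`` shorthand and the rest of the training setup are assumptions; the
# OTX engine normally supplies them):
#
#     model = VisionTransformerForMulticlassCls(
#         label_info=10,
#         arch="vit-small",
#         train_type=OTXTrainType.SEMI_SUPERVISED,
#     )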
class VisionTransformerForMultilabelCls(ForwardExplainMixInForViT, OTXMultilabelClsModel):
"""DeitTiny Model for multi-class classification task."""
model: ImageClassifier
def __init__(
self,
label_info: LabelInfoTypes,
arch: VIT_ARCH_TYPE = "vit-tiny",
lora: bool = False,
pretrained: bool = True,
optimizer: OptimizerCallable = DefaultOptimizerCallable,
scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
metric: MetricCallable = MultiLabelClsMetricCallable,
torch_compile: bool = False,
input_size: tuple[int, int] = (224, 224),
) -> None:
self.arch = arch
self.lora = lora
self.pretrained = pretrained
super().__init__(
label_info=label_info,
optimizer=optimizer,
scheduler=scheduler,
metric=metric,
torch_compile=torch_compile,
input_size=input_size,
)
def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.") -> dict:
"""Load the previous OTX ckpt according to OTX2.0."""
for key in list(state_dict.keys()):
new_key = key.replace("patch_embed.projection", "patch_embed.proj")
new_key = new_key.replace("backbone.ln1", "backbone.norm")
new_key = new_key.replace("ffn.layers.0.0", "mlp.fc1")
new_key = new_key.replace("ffn.layers.1", "mlp.fc2")
new_key = new_key.replace("layers", "blocks")
new_key = new_key.replace("ln", "norm")
if new_key != key:
state_dict[new_key] = state_dict.pop(key)
return OTXv1Helper.load_cls_effnet_b0_ckpt(state_dict, "multiclass", add_prefix)
def _create_model(self) -> nn.Module:
# Get classification_layers for class-incr learning
sample_model_dict = self._build_model(num_classes=5).state_dict()
incremental_model_dict = self._build_model(num_classes=6).state_dict()
self.classification_layers = get_classification_layers(
sample_model_dict,
incremental_model_dict,
prefix="model.",
)
model = self._build_model(num_classes=self.num_classes)
model.init_weights()
if self.pretrained and self.arch in pretrained_urls:
print(f"init weight - {pretrained_urls[self.arch]}")
parts = urlparse(pretrained_urls[self.arch])
filename = Path(parts.path).name
cache_dir = Path.home() / ".cache" / "torch" / "hub" / "checkpoints"
cache_file = cache_dir / filename
            if not cache_file.exists():
download_url_to_file(pretrained_urls[self.arch], cache_file, "", progress=True)
model.backbone.load_pretrained(checkpoint_path=cache_file)
return model
def _build_model(self, num_classes: int) -> nn.Module:
vit_backbone = VisionTransformer(arch=self.arch, img_size=self.input_size, lora=self.lora)
return ImageClassifier(
backbone=vit_backbone,
neck=None,
head=MultiLabelLinearClsHead(
num_classes=num_classes,
in_channels=vit_backbone.embed_dim,
),
loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"),
)
class VisionTransformerForHLabelCls(ForwardExplainMixInForViT, OTXHlabelClsModel):
"""DeitTiny Model for hierarchical label classification task."""
model: ImageClassifier
label_info: HLabelInfo
def __init__(
self,
label_info: HLabelInfo,
arch: VIT_ARCH_TYPE = "vit-tiny",
lora: bool = False,
pretrained: bool = True,
optimizer: OptimizerCallable = DefaultOptimizerCallable,
scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
metric: MetricCallable = HLabelClsMetricCallable,
torch_compile: bool = False,
input_size: tuple[int, int] = (224, 224),
) -> None:
self.arch = arch
self.lora = lora
self.pretrained = pretrained
super().__init__(
label_info=label_info,
optimizer=optimizer,
scheduler=scheduler,
metric=metric,
torch_compile=torch_compile,
input_size=input_size,
)
def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.") -> dict:
"""Load the previous OTX ckpt according to OTX2.0."""
for key in list(state_dict.keys()):
new_key = key.replace("patch_embed.projection", "patch_embed.proj")
new_key = new_key.replace("backbone.ln1", "backbone.norm")
new_key = new_key.replace("ffn.layers.0.0", "mlp.fc1")
new_key = new_key.replace("ffn.layers.1", "mlp.fc2")
new_key = new_key.replace("layers", "blocks")
new_key = new_key.replace("ln", "norm")
if new_key != key:
state_dict[new_key] = state_dict.pop(key)
return OTXv1Helper.load_cls_effnet_b0_ckpt(state_dict, "multiclass", add_prefix)
def _create_model(self) -> nn.Module:
# Get classification_layers for class-incr learning
sample_config = deepcopy(self.label_info.as_head_config_dict())
sample_config["num_classes"] = 5
sample_model_dict = self._build_model(head_config=sample_config).state_dict()
sample_config["num_classes"] = 6
incremental_model_dict = self._build_model(head_config=sample_config).state_dict()
self.classification_layers = get_classification_layers(
sample_model_dict,
incremental_model_dict,
prefix="model.",
)
model = self._build_model(head_config=self.label_info.as_head_config_dict())
model.init_weights()
if self.pretrained and self.arch in pretrained_urls:
print(f"init weight - {pretrained_urls[self.arch]}")
parts = urlparse(pretrained_urls[self.arch])
filename = Path(parts.path).name
cache_dir = Path.home() / ".cache" / "torch" / "hub" / "checkpoints"
cache_file = cache_dir / filename
            if not cache_file.exists():
download_url_to_file(pretrained_urls[self.arch], cache_file, "", progress=True)
model.backbone.load_pretrained(checkpoint_path=cache_file)
return model
def _build_model(self, head_config: dict) -> nn.Module:
        if not isinstance(self.label_info, HLabelInfo):
            msg = f"label_info should be an instance of HLabelInfo, but got {type(self.label_info)}."
            raise TypeError(msg)
init_cfg = [
{"std": 0.2, "layer": "Linear", "type": "TruncNormal"},
{"bias": 0.0, "val": 1.0, "layer": "LayerNorm", "type": "Constant"},
]
vit_backbone = VisionTransformer(arch=self.arch, img_size=self.input_size, lora=self.lora)
return HLabelClassifier(
backbone=vit_backbone,
neck=None,
head=HierarchicalCBAMClsHead(
in_channels=vit_backbone.embed_dim,
step_size=1,
**head_config,
),
multiclass_loss=nn.CrossEntropyLoss(),
multilabel_loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"),
init_cfg=init_cfg,
)
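
# Sketch only (``hlabel_info`` is a placeholder): unlike the other model classes,
# the H-label model requires a full HLabelInfo describing the label hierarchy,
# typically taken from the dataset's label metadata rather than a plain class count:
#
#     model = VisionTransformerForHLabelCls(label_info=hlabel_info, arch="vit-base")
#     # The head configuration is derived internally via label_info.as_head_config_dict().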