# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
"""ViT model implementation."""

from __future__ import annotations

import types
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable
from urllib.parse import urlparse

import numpy as np
import torch
from torch import nn
from torch.hub import download_url_to_file

from otx.algo.classification.backbones.vision_transformer import VisionTransformer
from otx.algo.classification.classifier import ImageClassifier
from otx.algo.classification.heads import (
VisionTransformerClsHead,
)
from otx.algo.explain.explain_algo import ViTReciproCAM, feature_vector_fn
from otx.algo.utils.support_otx_v1 import OTXv1Helper
from otx.core.data.entity.base import T_OTXBatchDataEntity, T_OTXBatchPredEntity
from otx.core.metrics.accuracy import MultiClassClsMetricCallable
from otx.core.model.base import DataInputParams, DefaultOptimizerCallable, DefaultSchedulerCallable
from otx.core.model.multiclass_classification import (
OTXMulticlassClsModel,
)
from otx.core.schedulers import LRSchedulerListCallable
from otx.core.types.label import LabelInfoTypes

if TYPE_CHECKING:
from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable
from otx.core.metrics import MetricCallable

augreg_url = "https://storage.googleapis.com/vit_models/augreg/"
dinov2_url = "https://dl.fbaipublicfiles.com/dinov2/"

pretrained_urls = {
"vit-tiny": augreg_url
+ "Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz",
"vit-small": augreg_url
+ "S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz",
"vit-base": augreg_url
+ "B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz",
"vit-large": augreg_url
+ "L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz",
"dinov2-small": dinov2_url + "dinov2_vits14/dinov2_vits14_reg4_pretrain.pth",
"dinov2-base": dinov2_url + "dinov2_vitb14/dinov2_vitb14_reg4_pretrain.pth",
"dinov2-large": dinov2_url + "dinov2_vitl14/dinov2_vitl14_reg4_pretrain.pth",
"dinov2-giant": dinov2_url + "dinov2_vitg14/dinov2_vitg14_reg4_pretrain.pth",
}
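
# Illustrative sketch only (the real checkpoint handling lives in the concrete
# model classes): one of these URLs can be resolved into a local file with the
# imports above. The cache location used here is a hypothetical example.
#
#   url = pretrained_urls["vit-tiny"]
#   dst = Path.home() / ".cache" / "otx" / Path(urlparse(url).path).name
#   if not dst.exists():
#       dst.parent.mkdir(parents=True, exist_ok=True)
#       download_url_to_file(url, str(dst), progress=True)

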
class ForwardExplainMixInForViT:
"""ViT model which can attach a XAI (Explainable AI) branch."""
explain_mode: bool
num_classes: int
model: ImageClassifier

    @torch.no_grad()
def head_forward_fn(self, x: torch.Tensor) -> torch.Tensor:
"""Performs model's neck and head forward."""
if not hasattr(self.model.backbone, "blocks"):
raise ValueError
# Part of the last transformer_encoder block (except first LayerNorm)
target_layer = self.model.backbone.blocks[-1]
x = x + target_layer.attn(x)
x = target_layer.mlp(target_layer.norm2(x))
# Final LayerNorm and neck
x = self.model.backbone.norm(x)
if self.model.neck is not None:
x = self.model.neck(x)
# Head
        cls_token = x[:, 0]
        # The head consumes a sequence of layer outputs and reads only the last
        # entry (the cls token); ``None`` stands in for the unused earlier output.
        layer_output = [None, cls_token]
        logit = self.model.head.forward(layer_output)
if isinstance(logit, list):
logit = torch.from_numpy(np.array(logit))
return logit
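
    # Illustrative sketch only: the ViTReciproCAM explainer invokes this method
    # on token features of shape (batch, 1 + num_patches, embed_dim) captured
    # ahead of the last block, e.g.:
    #
    #   logits = model.head_forward_fn(tokens)  # -> (batch, num_classes)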
@staticmethod
def _forward_explain_image_classifier(
self: ImageClassifier,
images: torch.Tensor,
mode: str = "tensor",
**kwargs, # noqa: ARG004
) -> dict[str, torch.Tensor]:
"""Forward func of the ImageClassifier instance, which located in is in OTXModel().model."""
        backbone = self.backbone
        # Raw token sequence from the last backbone stage; index 0 along the
        # token axis is the cls token.
        feat = backbone.forward(images, out_type="raw")[-1]
        x = (feat[:, 0],)
        saliency_map = self.explain_fn(feat)
if self.neck is not None:
x = self.neck(x)
feature_vector = x[-1]
if mode in ("tensor", "explain"):
logits = self.head(x)
elif mode == "predict":
logits = self.head.predict(x)
else:
msg = f'Invalid mode "{mode}".'
raise RuntimeError(msg)
        # H-label classification heads return a dict of predictions;
        # other heads return a plain tensor.
        pred_results = self.head._get_predictions(logits)  # noqa: SLF001
if isinstance(pred_results, dict):
scores = pred_results["scores"]
labels = pred_results["labels"]
else:
scores = pred_results.unbind(0)
labels = logits.argmax(-1, keepdim=True).unbind(0)
outputs = {
"logits": logits,
"feature_vector": feature_vector,
"saliency_map": saliency_map,
}
if not torch.jit.is_tracing():
outputs["scores"] = scores
outputs["labels"] = labels
return outputs
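
    # For orientation, a successful explain-mode call returns a dict roughly of
    # the form (shapes are illustrative, not guaranteed):
    #
    #   {
    #       "logits": Tensor(batch, num_classes),
    #       "feature_vector": Tensor(batch, embed_dim),
    #       "saliency_map": per-class saliency maps from ViTReciproCAM,
    #       "scores": per-image score tensors,   # omitted while torch.jit tracing
    #       "labels": per-image label tensors,   # omitted while torch.jit tracing
    #   }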
def get_explain_fn(self) -> Callable:
"""Returns explain function."""
explainer = ViTReciproCAM(
self.head_forward_fn,
num_classes=self.num_classes,
)
return explainer.func
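
    # Hedged usage sketch: the returned callable maps raw token features to
    # saliency maps, mirroring how it is attached in ``_register`` below:
    #
    #   explain_fn = model.get_explain_fn()
    #   saliency_map = explain_fn(raw_token_features)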
@property
def _optimization_config(self) -> dict[str, Any]:
"""PTQ config for DeitTinyForMultilabelCls."""
return {"model_type": "transformer"}
@property
def has_gap(self) -> bool:
"""Defines if GAP is used right after backbone.
Note:
Can be redefined at the model's level.
"""
return True

    def _register(self) -> None:
        """Attach ``feature_vector_fn`` and ``explain_fn`` to the inner model once."""
        if getattr(self, "_registered", False):
            return
self.model.feature_vector_fn = feature_vector_fn
self.model.explain_fn = self.get_explain_fn()
self._registered = True

    def forward_explain(
        self,
        inputs: T_OTXBatchDataEntity,
    ) -> T_OTXBatchPredEntity:
        """Model forward function that also returns feature vectors and saliency maps."""
self._register()
orig_model_forward = self.model.forward
try:
self.model.forward = types.MethodType(self._forward_explain_image_classifier, self.model) # type: ignore[method-assign, assignment]
forward_func: Callable[[T_OTXBatchDataEntity], T_OTXBatchPredEntity] | None = getattr(self, "forward", None)
if forward_func is None:
                msg = (
                    "This instance has no forward function. "
                    "Did you attach this mix-in to a class derived from OTXModel?"
                )
raise RuntimeError(msg)
return forward_func(inputs)
finally:
self.model.forward = orig_model_forward # type: ignore[method-assign, assignment]
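
    # Hedged usage sketch (the concrete batch entity type depends on the task):
    #
    #   preds = model.forward_explain(batch)  # saliency maps and feature
    #                                         # vectors travel with the predictions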

    def forward_for_tracing(self, image: torch.Tensor) -> torch.Tensor | dict[str, torch.Tensor]:
        """Model forward function used for tracing during model export."""
if self.explain_mode:
self._register()
forward_explain = types.MethodType(self._forward_explain_image_classifier, self.model)
return forward_explain(images=image, mode="tensor")
return self.model(images=image, mode="tensor")
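

# A hedged sketch of how this mix-in is typically composed with a task model
# (the subclass name is illustrative; the mix-in is listed first so that its
# forward_explain/forward_for_tracing take precedence in the MRO):
#
#   class VisionTransformerForMulticlassCls(  # illustrative name
#       ForwardExplainMixInForViT,
#       OTXMulticlassClsModel,
#   ):
#       ...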