Source code for otx.algo.classification.heads.vision_transformer_head

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) OpenMMLab. All rights reserved.

"""Copy from mmpretrain/models/heads/vision_transformer_head.py."""

from __future__ import annotations

import math
from collections import OrderedDict

import torch
from torch import nn
from torch.nn import functional

from otx.algo.modules.base_module import BaseModule, Sequential
from otx.algo.utils.weight_init import trunc_normal_


class VisionTransformerClsHead(BaseModule):
    """Vision Transformer classifier head.

    Args:
        num_classes (int): Number of categories excluding the background category.
        in_channels (int): Number of channels in the input feature map.
        hidden_dim (int, optional): Number of dimensions of the hidden layer.
            Defaults to None, which means no extra hidden layer.
        init_cfg (dict): The extra initialization configs.
            Defaults to ``dict(type='Constant', layer='Linear', val=0)``.
    """

    def __init__(
        self,
        num_classes: int,
        in_channels: int,
        hidden_dim: int | None = None,
        init_cfg: dict = {"type": "Constant", "layer": "Linear", "val": 0},  # noqa: B006
        **kwargs,
    ):
        super().__init__(init_cfg=init_cfg, **kwargs)
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.hidden_dim = hidden_dim

        if self.num_classes <= 0:
            msg = f"num_classes={num_classes} must be a positive integer"
            raise ValueError(msg)

        self._init_layers()

    def _init_layers(self) -> None:
        """Initialize the hidden layer if it exists."""
        layers: list[tuple[str, nn.Module]]
        if self.hidden_dim is None:
            layers = [("head", nn.Linear(self.in_channels, self.num_classes))]
        else:
            layers = [
                ("pre_logits", nn.Linear(self.in_channels, self.hidden_dim)),
                ("act", nn.Tanh()),
                ("head", nn.Linear(self.hidden_dim, self.num_classes)),
            ]
        self.layers = Sequential(OrderedDict(layers))

    def init_weights(self) -> None:
        """Initialize the weights of the hidden layer if it exists."""
        super().init_weights()
        # Modified from ClassyVision
        if hasattr(self.layers, "pre_logits"):
            # Lecun norm
            trunc_normal_(
                self.layers.pre_logits.weight,
                std=math.sqrt(1 / self.layers.pre_logits.in_features),
            )
            nn.init.zeros_(self.layers.pre_logits.bias)

    def pre_logits(self, feats: tuple[list[torch.Tensor]]) -> torch.Tensor:
        """The process before the final classification head.

        The input ``feats`` is a tuple of lists of tensors, and each tensor is
        the feature of a backbone stage. In ``VisionTransformerClsHead``, we
        take the feature of the last stage and forward it through the hidden
        layer if it exists.
        """
        feat = feats[-1]  # Obtain the feature of the last stage.
        # For backward-compatibility with the previous ViT output
        cls_token = feat[-1] if isinstance(feat, list) else feat
        if self.hidden_dim is None:
            return cls_token
        x = self.layers.pre_logits(cls_token)
        return self.layers.act(x)

    def forward(self, feats: tuple[list[torch.Tensor]]) -> torch.Tensor:
        """The forward process."""
        pre_logits = self.pre_logits(feats)
        # The final classification head.
        return self.layers.head(pre_logits)

    def predict(
        self,
        feats: tuple[torch.Tensor],
    ) -> torch.Tensor:
        """Inference without augmentation.

        Args:
            feats (tuple[Tensor]): The features extracted from the backbone.
                Multiple stage inputs are acceptable but only the last stage
                will be used to classify. The shape of every item should be
                ``(num_samples, num_classes)``.

        Returns:
            torch.Tensor: A tensor of softmax results.
        """
        # This part can be traced by torch.fx
        cls_score = self(feats)
        # This part cannot be traced by torch.fx
        return self._get_predictions(cls_score)

    def _get_predictions(self, cls_score: torch.Tensor) -> torch.Tensor:
        """Get the score from the classification score."""
        return functional.softmax(cls_score, dim=-1)
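

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of how this head could be exercised on its own, assuming
# ViT-Base-like dimensions (in_channels=768, hidden_dim=3072) and dummy
# tensors; the shapes and values below are assumptions for illustration only.
if __name__ == "__main__":
    head = VisionTransformerClsHead(num_classes=10, in_channels=768, hidden_dim=3072)
    head.init_weights()

    batch_size = 4
    patch_tokens = torch.rand(batch_size, 197, 768)  # per-patch features (ignored by this head)
    cls_token = torch.rand(batch_size, 768)          # class token consumed by ``pre_logits``

    # ``feats`` is a tuple over backbone stages; only the last stage is used,
    # and the class token is taken as the last element of that stage's list.
    feats = ([patch_tokens, cls_token],)

    logits = head(feats)          # shape: (batch_size, num_classes)
    scores = head.predict(feats)  # softmax scores, shape: (batch_size, num_classes)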