Source code for otx.algo.classification.heads.linear_head

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) OpenMMLab. All rights reserved.

"""Linear Head Implementation.

The original source code is mmpretrain.models.heads.linear_head.LinearClsHead.
You can refer to https://github.com/open-mmlab/mmpretrain/blob/main/mmpretrain/models/heads/linear_head.py

"""


from __future__ import annotations

import copy

import torch
from torch import nn
from torch.nn import functional

from otx.algo.modules.base_module import BaseModule


class LinearClsHead(BaseModule):
    """Linear classifier head.

    Args:
        num_classes (int): Number of categories excluding the background category.
        in_channels (int): Number of channels in the input feature map.
        init_cfg (dict, optional): The config to control the initialization.
            Defaults to ``dict(type='Normal', layer='Linear', std=0.01)``.
    """

    def __init__(
        self,
        num_classes: int,
        in_channels: int,
        init_cfg: dict = {"type": "Normal", "layer": "Linear", "std": 0.01},  # noqa: B006
        **kwargs,
    ):
        super().__init__(init_cfg=init_cfg)
        self._is_init = False

        # Keep a private copy so later mutation of the argument does not affect this head.
        self.init_cfg = copy.deepcopy(init_cfg)

        self.in_channels = in_channels
        self.num_classes = num_classes

        if self.num_classes <= 0:
            msg = f"num_classes={num_classes} must be a positive integer"
            raise ValueError(msg)

        # Single fully-connected layer mapping backbone features to class logits.
        self.fc = nn.Linear(self.in_channels, self.num_classes)

    def forward(self, feats: tuple[torch.Tensor] | torch.Tensor) -> torch.Tensor:
        """The forward process."""
        if isinstance(feats, tuple):
            # Multiple stage inputs are acceptable; only the last stage is classified.
            feats = feats[-1]
        # The final classification head.
        return self.fc(feats)

    def predict(
        self,
        feats: tuple[torch.Tensor],
        **kwargs,
    ) -> torch.Tensor:
        """Inference without augmentation.

        Args:
            feats (tuple[Tensor]): The features extracted from the backbone.
                Multiple stage inputs are acceptable but only the last stage
                will be used to classify. The shape of every item should be
                ``(num_samples, in_channels)``.

        Returns:
            torch.Tensor: A tensor of softmax results.
        """
        # This part can be traced by torch.fx.
        cls_score = self(feats)
        # This part can not be traced by torch.fx.
        return self._get_predictions(cls_score)

    def _get_predictions(self, cls_score: torch.Tensor) -> torch.Tensor:
        """Get the softmax scores from the classification logits."""
        return functional.softmax(cls_score, dim=1)
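

# --- Usage sketch (not part of the original module) ---------------------------
# A minimal, hypothetical example of how this head could be exercised on its own.
# The backbone is simulated with random feature tensors, and the feature width
# (64), class count (10), and batch size (4) are illustrative assumptions rather
# than values taken from the original source.
if __name__ == "__main__":
    head = LinearClsHead(num_classes=10, in_channels=64)

    # Simulated multi-stage backbone output; only the last stage is classified.
    feats = (torch.randn(4, 128), torch.randn(4, 64))

    logits = head(feats)         # raw class scores, shape (4, 10)
    probs = head.predict(feats)  # softmax probabilities, shape (4, 10)
    assert probs.shape == (4, 10)
    assert torch.allclose(probs.sum(dim=1), torch.ones(4))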