Source code for otx.algo.classification.heads.hlabel_cls_head

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""Module for defining h-label linear classification head."""

from __future__ import annotations

from typing import Callable, Sequence

import torch
from torch import nn

from otx.algo.modules.base_module import BaseModule
from otx.algo.utils.weight_init import constant_init, normal_init


class HierarchicalClsHead(BaseModule):
    """The classification head for hierarchical classification.

    This class defines the methods for pre-processing the features,
    calculating the loss, and making predictions for hierarchical classification.
    """

    def __init__(
        self,
        num_multiclass_heads: int,
        num_multilabel_classes: int,
        head_idx_to_logits_range: dict[str, tuple[int, int]],
        num_single_label_classes: int,
        empty_multiclass_head_indices: list[int],
        in_channels: int,
        num_classes: int,
        thr: float = 0.5,
        init_cfg: dict | None = None,
        **kwargs,
    ):
        super().__init__(init_cfg=init_cfg)
        self.num_multiclass_heads = num_multiclass_heads
        self.num_multilabel_classes = num_multilabel_classes
        self.head_idx_to_logits_range = head_idx_to_logits_range
        self.num_single_label_classes = num_single_label_classes
        self.empty_multiclass_head_indices = empty_multiclass_head_indices
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.thr = thr

        if self.num_multiclass_heads == 0:
            msg = "num_multiclass_head should be larger than 0"
            raise ValueError(msg)

    def pre_logits(self, feats: tuple[torch.Tensor] | torch.Tensor) -> torch.Tensor:
        """The process before the final classification head."""
        if isinstance(feats, Sequence):
            return feats[-1]
        return feats

    def _get_head_idx_to_logits_range(self, idx: int) -> tuple[int, int]:
        """Get head_idx_to_logits_range information from hlabel information."""
        return (
            self.head_idx_to_logits_range[str(idx)][0],
            self.head_idx_to_logits_range[str(idx)][1],
        )

    def predict(
        self,
        feats: tuple[torch.Tensor],
        **kwargs,
    ) -> dict[str, torch.Tensor]:
        """Inference without augmentation.

        Args:
            feats (tuple[Tensor]): The features extracted from the backbone.
                Multiple stage inputs are acceptable but only the last stage
                will be used to classify. The shape of every item should be
                ``(num_samples, in_channels)``.

        Returns:
            dict[str, torch.Tensor]: A dictionary containing the prediction
            ``scores`` and ``labels``.
        """
        cls_scores = self(feats)
        return self._get_predictions(cls_scores)

    def _get_predictions(
        self,
        cls_scores: torch.Tensor,
    ) -> dict[str, torch.Tensor]:
        """Post-process the output of head.

        Including softmax and set ``pred_label`` of data samples.
        """
        # Multiclass
        multiclass_pred_scores: list | torch.Tensor = []
        multiclass_pred_labels: list | torch.Tensor = []
        for i in range(self.num_multiclass_heads):
            logit_range = self._get_head_idx_to_logits_range(i)
            multiclass_logit = cls_scores[:, logit_range[0] : logit_range[1]]
            multiclass_pred = torch.softmax(multiclass_logit, dim=1)
            multiclass_pred_score, multiclass_pred_label = torch.max(multiclass_pred, dim=1)

            multiclass_pred_scores.append(multiclass_pred_score.view(-1, 1))
            multiclass_pred_labels.append(multiclass_pred_label.view(-1, 1))

        multiclass_pred_scores = torch.cat(multiclass_pred_scores, dim=1)
        multiclass_pred_labels = torch.cat(multiclass_pred_labels, dim=1)

        if self.num_multilabel_classes > 0:
            multilabel_logits = cls_scores[:, self.num_single_label_classes :]

            multilabel_pred = torch.sigmoid(multilabel_logits)
            multilabel_pred_labels = (multilabel_pred >= self.thr).int()

            pred_scores = torch.cat([multiclass_pred_scores, multilabel_pred], dim=1)
            pred_labels = torch.cat([multiclass_pred_labels, multilabel_pred_labels], dim=1)
        else:
            pred_scores = multiclass_pred_scores
            pred_labels = multiclass_pred_labels

        return {
            "scores": pred_scores,
            "labels": pred_labels,
        }


class HierarchicalLinearClsHead(HierarchicalClsHead):
    """Custom classification linear head for hierarchical classification task.

    Args:
        num_multiclass_heads (int): Number of multi-class heads.
        num_multilabel_classes (int): Number of multi-label classes.
        head_idx_to_logits_range (dict[str, tuple[int, int]]): The logit range of each head.
        num_single_label_classes (int): The number of single-label classes.
        empty_multiclass_head_indices (list[int]): The indices of multi-class heads
            that contain no labels due to label removal.
        in_channels (int): Number of channels in the input feature map.
        num_classes (int): Number of total classes.
        thr (float): Predictions with scores under the threshold are considered
            negative. Defaults to 0.5.
    """

    def __init__(
        self,
        num_multiclass_heads: int,
        num_multilabel_classes: int,
        head_idx_to_logits_range: dict[str, tuple[int, int]],
        num_single_label_classes: int,
        empty_multiclass_head_indices: list[int],
        in_channels: int,
        num_classes: int,
        thr: float = 0.5,
        init_cfg: dict | None = None,
        **kwargs,
    ):
        super().__init__(
            num_multiclass_heads=num_multiclass_heads,
            num_multilabel_classes=num_multilabel_classes,
            head_idx_to_logits_range=head_idx_to_logits_range,
            num_single_label_classes=num_single_label_classes,
            empty_multiclass_head_indices=empty_multiclass_head_indices,
            in_channels=in_channels,
            num_classes=num_classes,
            thr=thr,
            init_cfg=init_cfg,
            **kwargs,
        )
        self.fc = nn.Linear(self.in_channels, self.num_classes)
        self._init_layers()

    def _init_layers(self) -> None:
        """Initialize weights of the layers."""
        normal_init(self.fc, mean=0, std=0.01, bias=0)

    def forward(self, feats: tuple[torch.Tensor] | torch.Tensor) -> torch.Tensor:
        """The forward process."""
        pre_logits = self.pre_logits(feats)
        return self.fc(pre_logits)
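

# A minimal hedged usage sketch, not part of the original module: the label
# layout below is an assumption. Two exclusive label groups of sizes 3 and 2
# occupy logits 0:3 and 3:5 (decoded per group with softmax), and one extra
# multi-label class occupies logit 5 (decoded with sigmoid against `thr`).
def _example_linear_head() -> None:
    head = HierarchicalLinearClsHead(
        num_multiclass_heads=2,
        num_multilabel_classes=1,
        head_idx_to_logits_range={"0": (0, 3), "1": (3, 5)},
        num_single_label_classes=5,
        empty_multiclass_head_indices=[],
        in_channels=64,
        num_classes=6,
    )
    feats = torch.randn(4, 64)  # (num_samples, in_channels)
    preds = head.predict(feats)
    # One column per multiclass head plus one per multilabel class.
    assert preds["scores"].shape == (4, 3)
    assert preds["labels"].shape == (4, 3)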


class HierarchicalNonLinearClsHead(HierarchicalClsHead):
    """Custom classification non-linear head for hierarchical classification task.

    Args:
        num_multiclass_heads (int): Number of multi-class heads.
        num_multilabel_classes (int): Number of multi-label classes.
        head_idx_to_logits_range (dict[str, tuple[int, int]]): The logit range of each head.
        num_single_label_classes (int): The number of single-label classes.
        empty_multiclass_head_indices (list[int]): The indices of multi-class heads
            that contain no labels due to label removal.
        in_channels (int): Number of channels in the input feature map.
        num_classes (int): Number of total classes.
        thr (float): Predictions with scores under the threshold are considered
            negative. Defaults to 0.5.
        hid_channels (int): Number of channels in the hidden feature map of the
            classifier. Defaults to 1280.
        activation (Callable[[], nn.Module]): Activation layer of the classifier.
            Defaults to ``nn.ReLU``.
        dropout (bool): Flag to enable dropout in the classifier. Defaults to False.
    """

    def __init__(
        self,
        num_multiclass_heads: int,
        num_multilabel_classes: int,
        head_idx_to_logits_range: dict[str, tuple[int, int]],
        num_single_label_classes: int,
        empty_multiclass_head_indices: list[int],
        in_channels: int,
        num_classes: int,
        thr: float = 0.5,
        hid_channels: int = 1280,
        activation: Callable[[], nn.Module] = nn.ReLU,
        dropout: bool = False,
        init_cfg: dict | None = None,
        **kwargs,
    ):
        super().__init__(
            num_multiclass_heads=num_multiclass_heads,
            num_multilabel_classes=num_multilabel_classes,
            head_idx_to_logits_range=head_idx_to_logits_range,
            num_single_label_classes=num_single_label_classes,
            empty_multiclass_head_indices=empty_multiclass_head_indices,
            in_channels=in_channels,
            num_classes=num_classes,
            thr=thr,
            init_cfg=init_cfg,
            **kwargs,
        )
        self.hid_channels = hid_channels
        self.dropout = dropout
        self.activation = activation

        classifier_modules = [
            nn.Linear(in_channels, hid_channels),
            nn.BatchNorm1d(hid_channels),
            self.activation if isinstance(self.activation, nn.Module) else self.activation(),
        ]
        if self.dropout:
            classifier_modules.append(nn.Dropout(p=0.2))
        classifier_modules.append(nn.Linear(hid_channels, num_classes))

        self.classifier = nn.Sequential(*classifier_modules)
        self._init_layers()

    def _init_layers(self) -> None:
        """Initialize weights of the classification head."""
        for module in self.classifier:
            if isinstance(module, nn.Linear):
                normal_init(module, mean=0, std=0.01, bias=0)
            elif isinstance(module, nn.BatchNorm1d):
                constant_init(module, 1)

    def forward(self, feats: tuple[torch.Tensor] | torch.Tensor) -> torch.Tensor:
        """The forward process."""
        pre_logits = self.pre_logits(feats)
        return self.classifier(pre_logits)
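

# A hedged construction sketch with assumed values: the non-linear head replaces
# the single nn.Linear with a Linear -> BatchNorm1d -> activation [-> Dropout]
# -> Linear stack. `activation` takes the module class (e.g. nn.GELU), which is
# instantiated inside __init__.
def _example_nonlinear_head() -> None:
    head = HierarchicalNonLinearClsHead(
        num_multiclass_heads=1,
        num_multilabel_classes=0,
        head_idx_to_logits_range={"0": (0, 4)},
        num_single_label_classes=4,
        empty_multiclass_head_indices=[],
        in_channels=64,
        num_classes=4,
        hid_channels=256,
        activation=nn.GELU,
        dropout=True,
    )
    head.eval()  # freeze BatchNorm statistics and disable Dropout for inference
    preds = head.predict(torch.randn(4, 64))
    assert preds["labels"].shape == (4, 1)  # one argmax label for the single head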


class ChannelAttention(nn.Module):
    """Channel attention module that uses average and max pooling to enhance important channels."""

    def __init__(self, in_channels: int, reduction: int = 16):
        """Initializes the ChannelAttention module."""
        super().__init__()
        self.fc1 = nn.Conv2d(in_channels, in_channels // reduction, kernel_size=1, bias=False)
        self.fc2 = nn.Conv2d(in_channels // reduction, in_channels, kernel_size=1, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Applies channel attention to the input tensor."""
        avg_out = self.fc2(torch.relu(self.fc1(torch.mean(x, dim=2, keepdim=True).mean(dim=3, keepdim=True))))
        max_out = self.fc2(torch.relu(self.fc1(torch.max(x, dim=2, keepdim=True)[0].max(dim=3, keepdim=True)[0])))
        return torch.sigmoid(avg_out + max_out)


class SpatialAttention(nn.Module):
    """Spatial attention module that uses average and max pooling to enhance important spatial locations."""

    def __init__(self, kernel_size: int = 7):
        """Initializes the SpatialAttention module."""
        super().__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Applies spatial attention to the input tensor."""
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out = torch.max(x, dim=1, keepdim=True)[0]
        x = torch.cat([avg_out, max_out], dim=1)
        return torch.sigmoid(self.conv(x))


class CBAM(nn.Module):
    """CBAM module that applies channel and spatial attention sequentially."""

    def __init__(self, in_channels: int, reduction: int = 16, kernel_size: int = 7):
        """Initializes the CBAM module with channel and spatial attention."""
        super().__init__()
        self.channel_attention = ChannelAttention(in_channels, reduction)
        self.spatial_attention = SpatialAttention(kernel_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Applies channel and spatial attention to the input tensor."""
        x = x * self.channel_attention(x)
        return x * self.spatial_attention(x)
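

# A small hedged sketch of the attention blocks above; the shapes are assumed.
# CBAM rescales a 4D feature map without changing its shape: first per channel,
# then per spatial location.
def _example_cbam() -> None:
    cbam = CBAM(in_channels=32, reduction=16, kernel_size=7)
    x = torch.randn(2, 32, 7, 7)  # (batch, channels, height, width)
    out = cbam(x)
    assert out.shape == x.shape  # attention only rescales activations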


class HierarchicalCBAMClsHead(HierarchicalClsHead):
    """Custom classification CBAM head for hierarchical classification task.

    Args:
        num_multiclass_heads (int): Number of multi-class heads.
        num_multilabel_classes (int): Number of multi-label classes.
        head_idx_to_logits_range (dict[str, tuple[int, int]]): The logit range of each head.
        num_single_label_classes (int): The number of single-label classes.
        empty_multiclass_head_indices (list[int]): The indices of multi-class heads
            that contain no labels due to label removal.
        in_channels (int): Number of channels in the input feature map.
        num_classes (int): Number of total classes.
        thr (float, optional): Predictions with scores under the threshold are
            considered negative. Defaults to 0.5.
        init_cfg (dict | None, optional): Initialize configuration key-values.
            Defaults to None.
        step_size (int | tuple[int, int], optional): Spatial size of the input
            feature map. Defaults to 7.
    """

    def __init__(
        self,
        num_multiclass_heads: int,
        num_multilabel_classes: int,
        head_idx_to_logits_range: dict[str, tuple[int, int]],
        num_single_label_classes: int,
        empty_multiclass_head_indices: list[int],
        in_channels: int,
        num_classes: int,
        thr: float = 0.5,
        init_cfg: dict | None = None,
        step_size: int | tuple[int, int] = 7,
        **kwargs,
    ):
        super().__init__(
            num_multiclass_heads=num_multiclass_heads,
            num_multilabel_classes=num_multilabel_classes,
            head_idx_to_logits_range=head_idx_to_logits_range,
            num_single_label_classes=num_single_label_classes,
            empty_multiclass_head_indices=empty_multiclass_head_indices,
            in_channels=in_channels,
            num_classes=num_classes,
            thr=thr,
            init_cfg=init_cfg,
            **kwargs,
        )
        self.step_size = (step_size, step_size) if isinstance(step_size, int) else tuple(step_size)
        self.fc_superclass = nn.Linear(in_channels * self.step_size[0] * self.step_size[1], num_multiclass_heads)
        self.attention_fc = nn.Linear(num_multiclass_heads, in_channels * self.step_size[0] * self.step_size[1])
        self.cbam = CBAM(in_channels)
        self.fc_subclass = nn.Linear(in_channels * self.step_size[0] * self.step_size[1], num_classes)

        self._init_layers()

    def pre_logits(self, feats: tuple[torch.Tensor] | torch.Tensor) -> torch.Tensor:
        """The process before the final classification head."""
        if isinstance(feats, Sequence):
            feats = feats[-1]
        return feats.view(feats.size(0), self.in_channels * self.step_size[0] * self.step_size[1])

    def _init_layers(self) -> None:
        """Initialize weights of the classification head."""
        normal_init(self.fc_superclass, mean=0, std=0.01, bias=0)
        normal_init(self.fc_subclass, mean=0, std=0.01, bias=0)

    def forward(self, feats: tuple[torch.Tensor] | torch.Tensor) -> torch.Tensor:
        """The forward process."""
        pre_logits = self.pre_logits(feats)
        out_superclass = self.fc_superclass(pre_logits)

        attention_weights = torch.sigmoid(self.attention_fc(out_superclass))
        attended_features = pre_logits * attention_weights

        attended_features = attended_features.view(
            pre_logits.size(0),
            self.in_channels,
            self.step_size[0],
            self.step_size[1],
        )
        attended_features = self.cbam(attended_features)
        attended_features = attended_features.view(
            pre_logits.size(0),
            self.in_channels * self.step_size[0] * self.step_size[1],
        )
        return self.fc_subclass(attended_features)
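

# A hedged end-to-end sketch; all sizes are illustrative assumptions. Unlike the
# heads above, this head expects spatial features of shape
# (num_samples, in_channels, step_size, step_size), which pre_logits flattens.
def _example_cbam_head() -> None:
    head = HierarchicalCBAMClsHead(
        num_multiclass_heads=2,
        num_multilabel_classes=0,
        head_idx_to_logits_range={"0": (0, 3), "1": (3, 5)},
        num_single_label_classes=5,
        empty_multiclass_head_indices=[],
        in_channels=32,
        num_classes=5,
        step_size=7,
    )
    feats = torch.randn(4, 32, 7, 7)  # (num_samples, in_channels, H, W)
    preds = head.predict(feats)
    assert preds["labels"].shape == (4, 2)  # one argmax label per multiclass head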