Source code for otx.algo.detection.heads.atss_head

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) OpenMMLab. All rights reserved.
"""Implementation modified from mmdet.models.dense_heads.atss_head.py.

Reference : https://github.com/open-mmlab/mmdetection/blob/v3.2.0/mmdet/models/dense_heads/atss_head.py
"""

from __future__ import annotations

from functools import partial
from typing import Any, Callable, ClassVar

import torch
from torch import Tensor, nn

from otx.algo.common.utils.coders import BaseBBoxCoder
from otx.algo.common.utils.prior_generators import BasePriorGenerator
from otx.algo.common.utils.utils import multi_apply, reduce_mean
from otx.algo.detection.heads.anchor_head import AnchorHead
from otx.algo.detection.heads.class_incremental_mixin import (
    ClassIncrementalMixin,
)
from otx.algo.detection.utils.prior_generators.utils import anchor_inside_flags
from otx.algo.detection.utils.utils import unmap
from otx.algo.modules.conv_module import Conv2dModule
from otx.algo.modules.norm import build_norm_layer
from otx.algo.modules.scale import Scale
from otx.algo.utils.mmengine_utils import InstanceData
from otx.core.data.entity.detection import DetBatchDataEntity

EPS = 1e-12


class ATSSHeadModule(ClassIncrementalMixin, AnchorHead):
    """Detection Head of `ATSS <https://arxiv.org/abs/1912.02424>`_.

    The ATSS head structure is similar to FCOS; however, ATSS uses anchor boxes
    and assigns labels by Adaptive Training Sample Selection instead of max-IoU.

    Args:
        num_classes (int): Number of categories excluding the background
            category.
        in_channels (int): Number of channels in the input feature map.
        pred_kernel_size (int): Kernel size of the prediction ``nn.Conv2d``
            layers (``atss_cls``, ``atss_reg``, ``atss_centerness``).
            Defaults to 3.
        stacked_convs (int): Number of stacked 3x3 convs in each of the
            classification and regression branches. Defaults to 4.
        normalization (Callable[..., nn.Module]): Normalization layer builder used
            by the stacked convs.
            Defaults to ``partial(build_norm_layer, nn.GroupNorm, num_groups=32, requires_grad=True)``.
        reg_decoded_bbox (bool): If true, the regression loss would be
            applied directly on decoded bounding boxes, converting both
            the predicted boxes and regression targets to absolute
            coordinates format. Defaults to True. It should be `True` when
            using `IoULoss`, `GIoULoss`, or `DIoULoss` in the bbox head.
        init_cfg (dict, optional): Initialization config dict. When not given,
            ``Conv2d`` layers use a Normal(std=0.01) init and ``atss_cls``
            additionally uses ``bias_prob=0.01``.
        use_sigmoid_cls (bool): Whether to use a sigmoid activation function
            for classification prediction. Defaults to True.
        **kwargs: Extra keyword arguments forwarded to the parent constructor.
    """

    def __init__(
        self,
        num_classes: int,
        in_channels: int,
        pred_kernel_size: int = 3,
        stacked_convs: int = 4,
        normalization: Callable[..., nn.Module] = partial(
            build_norm_layer,
            nn.GroupNorm,
            num_groups=32,
            requires_grad=True,
        ),
        reg_decoded_bbox: bool = True,
        init_cfg: dict | None = None,
        use_sigmoid_cls: bool = True,
        **kwargs,
    ) -> None:
        # These are assigned before super().__init__(), presumably because the
        # parent constructor builds the layers via _init_layers(), which reads
        # them — confirm against AnchorHead.
        self.pred_kernel_size = pred_kernel_size
        self.stacked_convs = stacked_convs
        self.normalization = normalization
        init_cfg = init_cfg or {
            "type": "Normal",
            "layer": "Conv2d",
            "std": 0.01,
            "override": {"type": "Normal", "name": "atss_cls", "std": 0.01, "bias_prob": 0.01},
        }
        super().__init__(
            num_classes=num_classes,
            in_channels=in_channels,
            reg_decoded_bbox=reg_decoded_bbox,
            init_cfg=init_cfg,
            use_sigmoid_cls=use_sigmoid_cls,
            **kwargs,
        )

        self.sampling = False

    def _init_layers(self) -> None:
        """Initialize layers of the head.

        Builds two parallel stacks of ``stacked_convs`` normalized 3x3 convs
        (classification and regression), followed by the ``atss_cls``,
        ``atss_reg`` and ``atss_centerness`` prediction convs and one learnable
        ``Scale`` per prior-generator stride.
        """
        self.relu = nn.ReLU(inplace=True)
        self.cls_convs = nn.ModuleList()
        self.reg_convs = nn.ModuleList()
        for i in range(self.stacked_convs):
            # Only the first conv consumes the raw input channels.
            chn = self.in_channels if i == 0 else self.feat_channels
            self.cls_convs.append(
                Conv2dModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    normalization=build_norm_layer(self.normalization, num_features=self.feat_channels),
                ),
            )
            self.reg_convs.append(
                Conv2dModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    normalization=build_norm_layer(self.normalization, num_features=self.feat_channels),
                ),
            )
        # "Same" padding so prediction maps keep the feature-map spatial size.
        pred_pad_size = self.pred_kernel_size // 2
        # Use num_base_priors consistently for all three prediction convs
        # (num_anchors is the deprecated alias in the upstream mmdet AnchorHead).
        self.atss_cls = nn.Conv2d(
            self.feat_channels,
            self.num_base_priors * self.cls_out_channels,
            self.pred_kernel_size,
            padding=pred_pad_size,
        )
        self.atss_reg = nn.Conv2d(
            self.feat_channels,
            self.num_base_priors * 4,
            self.pred_kernel_size,
            padding=pred_pad_size,
        )
        self.atss_centerness = nn.Conv2d(
            self.feat_channels,
            self.num_base_priors * 1,
            self.pred_kernel_size,
            padding=pred_pad_size,
        )
        # One learnable scale per pyramid level (one per prior-generator stride).
        self.scales = nn.ModuleList([Scale(1.0) for _ in self.prior_generator.strides])

    def forward(self, x: tuple[Tensor]) -> tuple[list[Tensor], ...]:
        """Forward features from the upstream network.

        Args:
            x (tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor.

        Returns:
            tuple: A tuple of per-level lists, as produced by
                ``multi_apply(self.forward_single, ...)``:

                - cls_scores (list[Tensor]): Classification scores for all scale
                  levels, each is a 4D-tensor, the channels number is
                  num_base_priors * cls_out_channels.
                - bbox_preds (list[Tensor]): Box energies / deltas for all scale
                  levels, each is a 4D-tensor, the channels number is
                  num_base_priors * 4.
                - centernesses (list[Tensor]): Centerness predictions for all
                  scale levels, each is a 4D-tensor, the channels number is
                  num_base_priors.
        """
        return multi_apply(self.forward_single, x, self.scales)

    def forward_single(self, x: Tensor, scale: Scale) -> tuple[Tensor, ...]:  # type: ignore[override]
        """Forward feature of a single scale level.

        Args:
            x (Tensor): Features of a single scale level.
            scale (Scale): Learnable scale module to resize the bbox prediction.

        Returns:
            tuple:
                cls_score (Tensor): Cls scores for a single scale level
                    the channels number is num_base_priors * cls_out_channels.
                bbox_pred (Tensor): Box energies / deltas for a single scale
                    level, the channels number is num_base_priors * 4.
                centerness (Tensor): Centerness for a single scale level, the
                    channels number is num_base_priors.
        """
        cls_feat = x
        reg_feat = x
        for cls_conv in self.cls_convs:
            cls_feat = cls_conv(cls_feat)
        for reg_conv in self.reg_convs:
            reg_feat = reg_conv(reg_feat)
        cls_score = self.atss_cls(cls_feat)
        # we just follow atss, not apply exp in bbox_pred
        bbox_pred = scale(self.atss_reg(reg_feat))
        # Centerness shares the regression branch features.
        centerness = self.atss_centerness(reg_feat)
        return cls_score, bbox_pred, centerness

    def prepare_loss_inputs(
        self,
        x: tuple[Tensor],
        entity: DetBatchDataEntity,
    ) -> dict | tuple:
        """Perform forward propagation of the detection head and prepare for loss calculation.

        Args:
            x (tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor.
            entity (DetBatchDataEntity): Entity from OTX dataset.

        Returns:
            dict: A dictionary of components for loss calculation with keys
                ``anchors``, ``cls_score``, ``bbox_pred``, ``centerness``,
                ``labels``, ``label_weights``, ``bbox_targets``,
                ``valid_label_mask`` and ``avg_factor``.

        Raises:
            ValueError: If the number of feature levels differs from
                ``self.prior_generator.num_levels``.
        """
        (cls_scores, bbox_preds, centernesses), batch_gt_instances, batch_img_metas = super().prepare_loss_inputs(
            x,
            entity,
        )
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        if len(featmap_sizes) != self.prior_generator.num_levels:
            msg = "featmap_sizes and self.prior_generator.num_levels have different levels."
            raise ValueError(msg)

        device = cls_scores[0].device
        anchor_list, valid_flag_list = self.get_anchors(featmap_sizes, batch_img_metas, device=device)

        cls_reg_targets = self.get_targets(
            anchor_list,
            valid_flag_list,
            batch_gt_instances,
            batch_img_metas,
        )

        (
            anchor_list,
            labels_list,
            label_weights_list,
            bbox_targets_list,
            _bbox_weights_list,  # not needed by the loss inputs below
            avg_factor,
            valid_label_mask,
        ) = cls_reg_targets
        # reduce_mean presumably averages the value across distributed ranks — confirm.
        avg_factor = reduce_mean(torch.tensor(avg_factor, dtype=torch.float, device=device)).item()

        return {
            "anchors": anchor_list,
            "cls_score": cls_scores,
            "bbox_pred": bbox_preds,
            "centerness": centernesses,
            "labels": labels_list,
            "label_weights": label_weights_list,
            "bbox_targets": bbox_targets_list,
            "valid_label_mask": valid_label_mask,
            "avg_factor": avg_factor,
        }

    def get_targets(
        self,
        anchor_list: list[list[Tensor]],
        valid_flag_list: list[list[Tensor]],
        batch_gt_instances: list[InstanceData],
        batch_img_metas: list[dict],
        batch_gt_instances_ignore: list[InstanceData] | None = None,
        unmap_outputs: bool = True,
    ) -> tuple:
        """Get targets for Detection head.

        This method is almost the same as `AnchorHead.get_targets()`. Besides
        returning the targets as the parent method does, it also returns the
        anchors as the first element of the returned tuple. All work is
        delegated to ``self.get_atss_targets``.
        """
        return self.get_atss_targets(
            anchor_list,
            valid_flag_list,
            batch_gt_instances,
            batch_img_metas,
            batch_gt_instances_ignore,
            unmap_outputs,
        )

    def _get_targets_single(  # type: ignore[override]
        self,
        flat_anchors: Tensor,
        valid_flags: Tensor,
        num_level_anchors: list[int],
        gt_instances: InstanceData,
        img_meta: dict,
        gt_instances_ignore: InstanceData | None = None,
        unmap_outputs: bool = True,
    ) -> tuple:
        """Compute regression, classification targets for anchors in a single image.

        Args:
            flat_anchors (Tensor): Multi-level anchors of the image, which are
                concatenated into a single tensor of shape (num_anchors, 4).
            valid_flags (Tensor): Multi level valid flags of the image,
                which are concatenated into a single tensor of
                shape (num_anchors,).
            num_level_anchors (List[int]): Number of anchors of each scale
                level.
            gt_instances (InstanceData): Ground truth of instance
                annotations. It usually includes ``bboxes`` and ``labels``
                attributes.
            img_meta (dict): Meta information for current image.
            gt_instances_ignore (InstanceData, optional): Instances
                to be ignored during training. It includes ``bboxes`` attribute
                data that is ignored during training and testing.
                Defaults to None.
            unmap_outputs (bool): Whether to map outputs back to the original
                set of anchors.

        Returns:
            tuple: N is the number of total anchors in the image when
                ``unmap_outputs`` is True, otherwise the number of anchors
                inside the image.
                anchors (Tensor): Anchors with shape (N, 4).
                labels (Tensor): Labels of all anchors in the image with shape
                    (N,). Unassigned anchors get ``num_classes``.
                label_weights (Tensor): Label weights of all anchor in the
                    image with shape (N,).
                bbox_targets (Tensor): BBox targets of all anchors in the
                    image with shape (N, 4).
                bbox_weights (Tensor): BBox weights of all anchors in the
                    image with shape (N, 4)
                pos_inds (Tensor): Indices of positive anchors with shape
                    (num_pos,), relative to the inside-image anchors.
                neg_inds (Tensor): Indices of negative anchors with shape
                    (num_neg,), relative to the inside-image anchors.
                sampling_result (`SamplingResult`): Sampling results.

        Raises:
            ValueError: If no anchor lies inside the image boundary.
        """
        inside_flags = anchor_inside_flags(
            flat_anchors,
            valid_flags,
            img_meta["img_shape"][:2],
            self.train_cfg["allowed_border"],
        )
        if not inside_flags.any():
            # Plain string concatenation (no trailing comma): the previous
            # version built a 1-tuple, so the error carried a tuple, not text.
            msg = (
                "There is no valid anchor inside the image boundary. Please "
                "check the image size and anchor sizes, or set "
                "``allowed_border`` to -1 to skip the condition."
            )
            raise ValueError(msg)
        # assign gt and sample anchors
        anchors = flat_anchors[inside_flags, :]

        # ATSS selects candidates per level, so the assigner needs the
        # per-level counts of the anchors that survived the inside filter.
        num_level_anchors_inside = self.get_num_level_anchors_inside(num_level_anchors, inside_flags)
        pred_instances = InstanceData(priors=anchors)
        assign_result = self.assigner.assign(  # type: ignore[call-arg]
            pred_instances,
            num_level_anchors_inside,  # type: ignore[arg-type]
            gt_instances,
            gt_instances_ignore,
        )

        sampling_result = self.sampler.sample(assign_result, pred_instances, gt_instances)

        num_valid_anchors = anchors.shape[0]
        bbox_targets = torch.zeros_like(anchors)
        bbox_weights = torch.zeros_like(anchors)
        # Default every anchor to label ``num_classes`` (background).
        labels = anchors.new_full((num_valid_anchors,), self.num_classes, dtype=torch.long)
        label_weights = anchors.new_zeros(num_valid_anchors, dtype=torch.float)

        pos_inds = sampling_result.pos_inds
        neg_inds = sampling_result.neg_inds
        if len(pos_inds) > 0:
            if self.reg_decoded_bbox:
                # Regress absolute boxes directly (e.g. for IoU-based losses).
                pos_bbox_targets = sampling_result.pos_gt_bboxes
            else:
                pos_bbox_targets = self.bbox_coder.encode(sampling_result.pos_priors, sampling_result.pos_gt_bboxes)

            bbox_targets[pos_inds, :] = pos_bbox_targets
            bbox_weights[pos_inds, :] = 1.0

            labels[pos_inds] = sampling_result.pos_gt_labels
            # A non-positive pos_weight means "use the default weight of 1".
            if self.train_cfg["pos_weight"] <= 0:
                label_weights[pos_inds] = 1.0
            else:
                label_weights[pos_inds] = self.train_cfg["pos_weight"]
        if len(neg_inds) > 0:
            label_weights[neg_inds] = 1.0

        # map up to original set of anchors
        if unmap_outputs:
            num_total_anchors = flat_anchors.size(0)
            anchors = unmap(anchors, num_total_anchors, inside_flags)
            labels = unmap(labels, num_total_anchors, inside_flags, fill=self.num_classes)
            label_weights = unmap(label_weights, num_total_anchors, inside_flags)
            bbox_targets = unmap(bbox_targets, num_total_anchors, inside_flags)
            bbox_weights = unmap(bbox_weights, num_total_anchors, inside_flags)

        return (anchors, labels, label_weights, bbox_targets, bbox_weights, pos_inds, neg_inds, sampling_result)

    def get_num_level_anchors_inside(self, num_level_anchors: list[int], inside_flags: Tensor) -> list[int]:
        """Get the number of valid anchors in every level.

        Args:
            num_level_anchors (list[int]): Number of anchors per level; the
                sizes used to split ``inside_flags``.
            inside_flags (Tensor): Flat boolean mask over all levels' anchors.

        Returns:
            list[int]: Count of ``True`` flags in each level's slice.
        """
        split_inside_flags = torch.split(inside_flags, num_level_anchors)
        return [int(flags.sum()) for flags in split_inside_flags]


class ATSSHead:
    """ATSSHead factory for detection.

    Calling ``ATSSHead(model_name, ...)`` does not produce an ``ATSSHead``
    instance. ``__new__`` looks up the per-model keyword arguments in
    ``ATSSHEAD_CFG`` and returns a configured :class:`ATSSHeadModule` instead.
    """

    # Per-model keyword arguments unpacked into the ATSSHeadModule constructor.
    ATSSHEAD_CFG: ClassVar[dict[str, Any]] = {
        "atss_mobilenetv2": {
            "in_channels": 64,
            "feat_channels": 64,
        },
        "atss_resnext101": {
            "in_channels": 256,
        },
    }

    def __new__(
        cls,
        model_name: str,
        num_classes: int,
        anchor_generator: BasePriorGenerator,
        bbox_coder: BaseBBoxCoder,
        train_cfg: dict,
        test_cfg: dict | None = None,
    ) -> ATSSHeadModule:
        """Constructor for ATSSHead.

        Args:
            model_name (str): Key into ``ATSSHEAD_CFG`` selecting the head configuration.
            num_classes (int): Number of categories excluding the background category.
            anchor_generator (BasePriorGenerator): Anchor generator passed to the head.
            bbox_coder (BaseBBoxCoder): Box coder passed to the head.
            train_cfg (dict): Training config passed to the head.
            test_cfg (dict | None): Test config passed to the head. Defaults to None.

        Returns:
            ATSSHeadModule: The configured head module.

        Raises:
            KeyError: If ``model_name`` is not a key of ``ATSSHEAD_CFG``.
        """
        if model_name not in cls.ATSSHEAD_CFG:
            msg = f"model type '{model_name}' is not supported"
            raise KeyError(msg)
        return ATSSHeadModule(
            **cls.ATSSHEAD_CFG[model_name],
            num_classes=num_classes,
            anchor_generator=anchor_generator,
            bbox_coder=bbox_coder,
            train_cfg=train_cfg,  # TODO (sungchul, kirill): remove
            test_cfg=test_cfg,  # TODO (sungchul, kirill): remove
        )