Source code for otx.algo.detection.losses.atss_loss

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) OpenMMLab. All rights reserved.
#
"""ATSS criterion."""

from __future__ import annotations

import torch
from torch import Tensor, nn

from otx.algo.common.losses import CrossEntropyLoss, CrossSigmoidFocalLoss, QualityFocalLoss
from otx.algo.common.utils.bbox_overlaps import bbox_overlaps
from otx.algo.common.utils.utils import multi_apply, reduce_mean


class ATSSCriterion(nn.Module):
    """ATSSCriterion is a loss criterion used in the Adaptive Training Sample Selection (ATSS) algorithm.

    Args:
        num_classes (int): The number of object classes.
        bbox_coder (nn.Module): The module used for encoding and decoding bounding box coordinates.
        loss_cls (nn.Module): The module used for calculating the classification loss.
        loss_bbox (nn.Module): The module used for calculating the bounding box regression loss.
        loss_centerness (nn.Module | None, optional): The module used for calculating the centerness loss.
            Defaults to None, in which case ``CrossEntropyLoss(use_sigmoid=True, loss_weight=1.0)`` is used.
        use_qfl (bool, optional): Whether to use the Quality Focal Loss (QFL) as the classification loss.
            Defaults to False.
        qfl_cfg (dict | None, optional): Configuration for the QFL classification loss used when ``use_qfl``
            is True. Defaults to None, in which case
            ``QualityFocalLoss(use_sigmoid=True, beta=2.0, loss_weight=1.0)`` is used.
        reg_decoded_bbox (bool, optional): Whether to use the decoded bounding box coordinates
            for regression loss calculation. Defaults to True.
        bg_loss_weight (float, optional): The weight for the background loss. Defaults to -1.0.
    """

    def __init__(
        self,
        num_classes: int,
        bbox_coder: nn.Module,
        loss_cls: nn.Module,
        loss_bbox: nn.Module,
        loss_centerness: nn.Module | None = None,
        use_qfl: bool = False,
        qfl_cfg: dict | None = None,
        reg_decoded_bbox: bool = True,
        bg_loss_weight: float = -1.0,
    ) -> None:
        super().__init__()
        self.num_classes = num_classes
        self.bbox_coder = bbox_coder
        self.use_qfl = use_qfl
        self.reg_decoded_bbox = reg_decoded_bbox
        self.bg_loss_weight = bg_loss_weight
        self.loss_bbox = loss_bbox
        self.loss_centerness = loss_centerness or CrossEntropyLoss(use_sigmoid=True, loss_weight=1.0)

        if use_qfl:
            loss_cls = qfl_cfg or QualityFocalLoss(use_sigmoid=True, beta=2.0, loss_weight=1.0)
        self.loss_cls = loss_cls

        self.use_sigmoid_cls = loss_cls.use_sigmoid
        if self.use_sigmoid_cls:
            self.cls_out_channels = num_classes
        else:
            self.cls_out_channels = num_classes + 1

        if self.cls_out_channels <= 0:
            msg = f"num_classes={num_classes} is too small"
            raise ValueError(msg)

    def forward(
        self,
        anchors: Tensor,
        cls_score: Tensor,
        bbox_pred: Tensor,
        centerness: Tensor,
        labels: Tensor,
        label_weights: Tensor,
        bbox_targets: Tensor,
        valid_label_mask: Tensor,
        avg_factor: float,
    ) -> dict[str, Tensor]:
        """Compute losses over all scale levels.

        Args:
            anchors (Tensor): Box reference for scale levels with shape (N, num_total_anchors, 4).
            cls_score (Tensor): Box scores for scale levels with shape (N, num_anchors * num_classes, H, W).
            bbox_pred (Tensor): Box energies / deltas for scale levels with shape (N, num_anchors * 4, H, W).
            centerness (Tensor): Centerness scores for each scale level.
            labels (Tensor): Labels of anchors with shape (N, num_total_anchors).
            label_weights (Tensor): Label weights of anchors with shape (N, num_total_anchors).
            bbox_targets (Tensor): BBox regression targets of anchors with shape (N, num_total_anchors, 4).
            valid_label_mask (Tensor): Label mask for consideration of ignored label
                with shape (N, num_total_anchors, 1).
            avg_factor (float): Average factor that is used to average the loss. When using a sampling method,
                avg_factor is usually the sum of positive and negative priors. When using ``PseudoSampler``,
                ``avg_factor`` is usually equal to the number of positive priors.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        losses_cls, losses_bbox, loss_centerness, bbox_avg_factor = multi_apply(
            self._forward,
            anchors,
            cls_score,
            bbox_pred,
            centerness,
            labels,
            label_weights,
            bbox_targets,
            valid_label_mask,
            avg_factor=avg_factor,
        )

        bbox_avg_factor = sum(bbox_avg_factor)
        bbox_avg_factor = reduce_mean(bbox_avg_factor).clamp_(min=1).item()
        losses_bbox = [loss_bbox / bbox_avg_factor for loss_bbox in losses_bbox]
        return {"loss_cls": losses_cls, "loss_bbox": losses_bbox, "loss_centerness": loss_centerness}

    def _forward(
        self,
        anchors: Tensor,
        cls_score: Tensor,
        bbox_pred: Tensor,
        centerness: Tensor,
        labels: Tensor,
        label_weights: Tensor,
        bbox_targets: Tensor,
        valid_label_mask: Tensor,
        avg_factor: float,
    ) -> tuple:
        """Compute loss of a single scale level.

        Args:
            anchors (Tensor): Box reference for each scale level with shape (N, num_total_anchors, 4).
            cls_score (Tensor): Box scores for each scale level with shape (N, num_anchors * num_classes, H, W).
            bbox_pred (Tensor): Box energies / deltas for each scale level with shape (N, num_anchors * 4, H, W).
            centerness (Tensor): Centerness scores for each scale level.
            labels (Tensor): Labels of each anchor with shape (N, num_total_anchors).
            label_weights (Tensor): Label weights of each anchor with shape (N, num_total_anchors).
            bbox_targets (Tensor): BBox regression targets of each anchor with shape (N, num_total_anchors, 4).
            valid_label_mask (Tensor): Label mask for consideration of ignored label
                with shape (N, num_total_anchors, 1).
            avg_factor (float): Average factor that is used to average the loss. When using a sampling method,
                avg_factor is usually the sum of positive and negative priors. When using ``PseudoSampler``,
                ``avg_factor`` is usually equal to the number of positive priors.

        Returns:
            tuple[Tensor]: A tuple of loss components.
        """
        anchors = anchors.reshape(-1, 4)
        cls_score = cls_score.permute(0, 2, 3, 1).reshape(-1, self.cls_out_channels).contiguous()
        bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
        centerness = centerness.permute(0, 2, 3, 1).reshape(-1)
        bbox_targets = bbox_targets.reshape(-1, 4)
        labels = labels.reshape(-1)
        label_weights = label_weights.reshape(-1)
        valid_label_mask = valid_label_mask.reshape(-1, self.cls_out_channels)

        # FG cat_id: [0, num_classes - 1], BG cat_id: num_classes
        pos_inds = self._get_pos_inds(labels)

        if self.use_qfl:
            quality = label_weights.new_zeros(labels.shape)

        if len(pos_inds) > 0:
            pos_bbox_targets = bbox_targets[pos_inds]
            pos_bbox_pred = bbox_pred[pos_inds]
            pos_anchors = anchors[pos_inds]
            pos_centerness = centerness[pos_inds]

            centerness_targets = self.centerness_target(pos_anchors, pos_bbox_targets)

            if self.reg_decoded_bbox:
                pos_bbox_pred = self.bbox_coder.decode(pos_anchors, pos_bbox_pred)

            if self.use_qfl:
                quality[pos_inds] = bbox_overlaps(pos_bbox_pred.detach(), pos_bbox_targets, is_aligned=True).clamp(
                    min=1e-6,
                )

            # regression loss
            loss_bbox = self._get_loss_bbox(pos_bbox_targets, pos_bbox_pred, centerness_targets)

            # centerness loss
            loss_centerness = self._get_loss_centerness(avg_factor, pos_centerness, centerness_targets)
        else:
            loss_bbox = bbox_pred.sum() * 0
            loss_centerness = centerness.sum() * 0
            centerness_targets = bbox_targets.new_tensor(0.0)

        # Re-weighting BG loss
        if self.bg_loss_weight >= 0.0:
            neg_indices = labels == self.num_classes
            label_weights[neg_indices] = self.bg_loss_weight

        if self.use_qfl:
            labels = (labels, quality)  # For quality focal loss arg spec

        # classification loss
        loss_cls = self._get_loss_cls(cls_score, labels, label_weights, valid_label_mask, avg_factor)

        return loss_cls, loss_bbox, loss_centerness, centerness_targets.sum()

    def centerness_target(self, anchors: Tensor, gts: Tensor) -> Tensor:
        """Calculate the centerness between anchors and gts.

        Only calculate pos centerness targets, otherwise there may be nan.

        Args:
            anchors (Tensor): Anchors with shape (N, 4), "xyxy" format.
            gts (Tensor): Ground truth bboxes with shape (N, 4), "xyxy" format.

        Returns:
            Tensor: Centerness between anchors and gts.
        """
        anchors_cx = (anchors[:, 2] + anchors[:, 0]) / 2
        anchors_cy = (anchors[:, 3] + anchors[:, 1]) / 2
        l_ = anchors_cx - gts[:, 0]
        t_ = anchors_cy - gts[:, 1]
        r_ = gts[:, 2] - anchors_cx
        b_ = gts[:, 3] - anchors_cy

        left_right = torch.stack([l_, r_], dim=1)
        top_bottom = torch.stack([t_, b_], dim=1)
        return torch.sqrt(
            (left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0])
            * (top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0]),
        )

    def _get_pos_inds(self, labels: Tensor) -> Tensor:
        bg_class_ind = self.num_classes
        return ((labels >= 0) & (labels < bg_class_ind)).nonzero().squeeze(1)

    def _get_loss_bbox(
        self,
        pos_bbox_targets: Tensor,
        pos_bbox_pred: Tensor,
        centerness_targets: Tensor,
    ) -> Tensor:
        return self.loss_bbox(pos_bbox_pred, pos_bbox_targets, weight=centerness_targets, avg_factor=1.0)

    def _get_loss_centerness(
        self,
        avg_factor: Tensor,
        pos_centerness: Tensor,
        centerness_targets: Tensor,
    ) -> Tensor:
        return self.loss_centerness(pos_centerness, centerness_targets, avg_factor=avg_factor)

    def _get_loss_cls(
        self,
        cls_score: Tensor,
        labels: Tensor,
        label_weights: Tensor,
        valid_label_mask: Tensor,
        avg_factor: Tensor,
    ) -> Tensor:
        if isinstance(self.loss_cls, CrossSigmoidFocalLoss):
            loss_cls = self.loss_cls(
                cls_score,
                labels,
                label_weights,
                avg_factor=avg_factor,
                valid_label_mask=valid_label_mask,
            )
        else:
            loss_cls = self.loss_cls(cls_score, labels, label_weights, avg_factor=avg_factor)
        return loss_cls
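
For orientation, the sketch below shows one way the criterion could be instantiated. It is a minimal, hypothetical example rather than the configuration OTX ships for its ATSS detector: it reuses only the loss classes imported at the top of this module, and the box coder and regression loss are stand-in placeholders.

    # Minimal construction sketch (hypothetical; not the stock OTX ATSS configuration).
    # `my_bbox_coder` and `my_loss_bbox` are placeholders for the box coder and
    # regression-loss modules that the detector normally supplies.
    criterion = ATSSCriterion(
        num_classes=80,
        bbox_coder=my_bbox_coder,
        loss_cls=CrossEntropyLoss(use_sigmoid=True, loss_weight=1.0),
        loss_bbox=my_loss_bbox,
    )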
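
The foreground/background convention noted in `_forward` (foreground category ids in [0, num_classes - 1], background id equal to num_classes) is what `_get_pos_inds` relies on. A small self-contained sketch with made-up labels, assuming only PyTorch:

    import torch

    num_classes = 3
    labels = torch.tensor([0, 2, 3, 1, 3])  # label 3 marks background anchors here
    # Same selection as `_get_pos_inds`: keep anchors whose label is a foreground class.
    pos_inds = ((labels >= 0) & (labels < num_classes)).nonzero().squeeze(1)
    print(pos_inds)  # tensor([0, 1, 3])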
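
`centerness_target` mirrors the FCOS-style centerness, sqrt((min(l, r) / max(l, r)) * (min(t, b) / max(t, b))), computed from the distances between the anchor center and the four sides of the ground-truth box. A worked numeric sketch with illustrative boxes, assuming only PyTorch:

    import torch

    # One anchor and one ground-truth box in "xyxy" format (illustrative values).
    anchors = torch.tensor([[40.0, 40.0, 80.0, 80.0]])  # anchor center at (60, 60)
    gts = torch.tensor([[30.0, 50.0, 100.0, 90.0]])

    cx = (anchors[:, 2] + anchors[:, 0]) / 2
    cy = (anchors[:, 3] + anchors[:, 1]) / 2
    l_, t_ = cx - gts[:, 0], cy - gts[:, 1]  # 30, 10
    r_, b_ = gts[:, 2] - cx, gts[:, 3] - cy  # 40, 30

    left_right = torch.stack([l_, r_], dim=1)
    top_bottom = torch.stack([t_, b_], dim=1)
    centerness = torch.sqrt(
        (left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0])
        * (top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0]),
    )
    print(centerness)  # sqrt((30 / 40) * (10 / 30)) = 0.5 -> tensor([0.5000])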