# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
"""Criterion module for YOLOv7 and v9.
Reference : https://github.com/WongKinYiu/YOLO
"""
from __future__ import annotations
import math
from typing import Literal
import torch
from torch import Tensor, nn
from torch.nn import BCEWithLogitsLoss
from otx.algo.detection.utils.utils import Vec2Box
def calculate_iou(bbox1: Tensor, bbox2: Tensor, metrics: Literal["iou", "diou", "ciou"] = "iou") -> Tensor:
"""Calculate the Intersection over Union (IoU) between two sets of bounding boxes.
Args:
bbox1 (Tensor): The first set of bounding boxes.
bbox2 (Tensor): The second set of bounding boxes.
metrics (Literal["iou", "diou", "ciou"], optional): The metrics to calculate. Defaults to "iou".
Returns:
Tensor: The IoU between the two sets of bounding boxes.
"""
eps = 1e-9
dtype = bbox1.dtype
bbox1 = bbox1.to(torch.float32)
bbox2 = bbox2.to(torch.float32)
# Expand dimensions if necessary
if bbox1.ndim == 2 and bbox2.ndim == 2:
bbox1 = bbox1.unsqueeze(1) # (Ax4) -> (Ax1x4)
bbox2 = bbox2.unsqueeze(0) # (Bx4) -> (1xBx4)
elif bbox1.ndim == 3 and bbox2.ndim == 3:
bbox1 = bbox1.unsqueeze(2) # (BZxAx4) -> (BZxAx1x4)
bbox2 = bbox2.unsqueeze(1) # (BZxBx4) -> (BZx1xBx4)
# Calculate intersection coordinates
xmin_inter = torch.max(bbox1[..., 0], bbox2[..., 0])
ymin_inter = torch.max(bbox1[..., 1], bbox2[..., 1])
xmax_inter = torch.min(bbox1[..., 2], bbox2[..., 2])
ymax_inter = torch.min(bbox1[..., 3], bbox2[..., 3])
# Calculate intersection area
intersection_area = torch.clamp(xmax_inter - xmin_inter, min=0) * torch.clamp(ymax_inter - ymin_inter, min=0)
# Calculate area of each bbox
area_bbox1 = (bbox1[..., 2] - bbox1[..., 0]) * (bbox1[..., 3] - bbox1[..., 1])
area_bbox2 = (bbox2[..., 2] - bbox2[..., 0]) * (bbox2[..., 3] - bbox2[..., 1])
# Calculate union area
union_area = area_bbox1 + area_bbox2 - intersection_area
# Calculate IoU
iou = intersection_area / (union_area + eps)
if metrics == "iou":
return iou.to(dtype)
# Calculate centroid distance
cx1 = (bbox1[..., 2] + bbox1[..., 0]) / 2
cy1 = (bbox1[..., 3] + bbox1[..., 1]) / 2
cx2 = (bbox2[..., 2] + bbox2[..., 0]) / 2
cy2 = (bbox2[..., 3] + bbox2[..., 1]) / 2
cent_dis = (cx1 - cx2) ** 2 + (cy1 - cy2) ** 2
# Calculate diagonal length of the smallest enclosing box
c_x = torch.max(bbox1[..., 2], bbox2[..., 2]) - torch.min(bbox1[..., 0], bbox2[..., 0])
c_y = torch.max(bbox1[..., 3], bbox2[..., 3]) - torch.min(bbox1[..., 1], bbox2[..., 1])
diag_dis = c_x**2 + c_y**2 + eps
diou = iou - (cent_dis / diag_dis)
if metrics == "diou":
return diou.to(dtype)
# Compute aspect ratio penalty term
arctan = torch.atan((bbox1[..., 2] - bbox1[..., 0]) / (bbox1[..., 3] - bbox1[..., 1] + eps)) - torch.atan(
(bbox2[..., 2] - bbox2[..., 0]) / (bbox2[..., 3] - bbox2[..., 1] + eps),
)
v = (4 / (math.pi**2)) * (arctan**2)
alpha = v / (v - iou + 1 + eps)
# Compute CIoU
ciou = diou - alpha * v
return ciou.to(dtype)
class BCELoss(nn.Module):
"""Binary Cross Entropy Loss.
TODO (author): Refactor the device, should be assign by config
TODO (author): origin v9 assign pos_weight == 1?
TODO (sungchul): check if it can be replaced with otx.algo.common.losses.cross_entropy_loss.CrossEntropyLoss
"""
def __init__(self) -> None:
super().__init__()
self.bce = BCEWithLogitsLoss(reduction="none")
def forward(self, predicts_cls: Tensor, targets_cls: Tensor, cls_norm: Tensor) -> Tensor:
"""Calculate the BCE loss for the classification."""
return self.bce(predicts_cls, targets_cls).sum() / cls_norm
class BoxLoss(nn.Module):
"""Box Loss.
TODO (sungchul): check if it can be replaced with otx.algo.common.losses.iou_loss.IoULoss
"""
def forward(
self,
predicts_bbox: Tensor,
targets_bbox: Tensor,
valid_masks: Tensor,
box_norm: Tensor,
cls_norm: Tensor,
) -> Tensor:
"""Calculate the IoU loss for the bounding box.
Args:
predicts_bbox (Tensor): The predicted bounding box.
targets_bbox (Tensor): The target bounding box.
valid_masks (Tensor): The mask for valid bounding box.
box_norm (Tensor): The normalization factor for the bounding box.
cls_norm (Tensor): The normalization factor for the class.
Returns:
Tensor: The IoU loss for the bounding box.
"""
valid_bbox = valid_masks[..., None].expand(-1, -1, 4)
picked_predict = predicts_bbox[valid_bbox].view(-1, 4)
picked_targets = targets_bbox[valid_bbox].view(-1, 4)
iou = calculate_iou(picked_predict, picked_targets, "ciou").diag()
loss_iou = 1.0 - iou
return (loss_iou * box_norm).sum() / cls_norm
class DFLoss(nn.Module):
"""Distribution Focal Loss (DFL).
Args:
vec2box (Vec2Box): The Vec2Box object.
reg_max (int, optional): Maximum number of anchor regions. Defaults to 16.
"""
def __init__(self, vec2box: Vec2Box, reg_max: int = 16) -> None:
super().__init__()
self.anchors_norm = (vec2box.anchor_grid / vec2box.scaler[:, None])[None]
self.reg_max = reg_max
def forward(
self,
predicts_anc: Tensor,
targets_bbox: Tensor,
valid_masks: Tensor,
box_norm: Tensor,
cls_norm: Tensor,
) -> Tensor:
"""Calculate the DFLoss for the bounding box.
Args:
predicts_anc (Tensor): The predicted anchor.
targets_bbox (Tensor): The target bounding box.
valid_masks (Tensor): The mask for valid bounding box.
box_norm (Tensor): The normalization factor for the bounding box.
cls_norm (Tensor): The normalization factor for the class.
Returns:
Tensor: The DFLoss for the bounding box.
"""
valid_bbox = valid_masks[..., None].expand(-1, -1, 4)
bbox_lt, bbox_rb = targets_bbox.chunk(2, -1)
targets_dist = torch.cat(((self.anchors_norm - bbox_lt), (bbox_rb - self.anchors_norm)), -1).clamp(
0,
self.reg_max - 1.01,
)
picked_targets = targets_dist[valid_bbox].view(-1)
picked_predict = predicts_anc[valid_bbox].view(-1, self.reg_max)
label_left, label_right = picked_targets.floor(), picked_targets.floor() + 1
weight_left, weight_right = label_right - picked_targets, picked_targets - label_left
loss_left = nn.functional.cross_entropy(picked_predict, label_left.to(torch.long), reduction="none")
loss_right = nn.functional.cross_entropy(picked_predict, label_right.to(torch.long), reduction="none")
loss_dfl = loss_left * weight_left + loss_right * weight_right
loss_dfl = loss_dfl.view(-1, 4).mean(-1)
return (loss_dfl * box_norm).sum() / cls_norm
class BoxMatcher:
"""Box Matcher.
Args:
class_num (int): The number of classes.
anchors (Tensor): The anchor tensor.
iou (str, optional): The IoU method. Defaults to "CIoU".
topk (int, optional): The number of top scores to retain per anchor. Defaults to 10.
factor (dict[str, float] | None, optional): The factor for IoU and class. Defaults to {"iou": 6.0, "cls": 0.5}.
"""
def __init__(
self,
class_num: int,
anchors: Tensor,
iou: Literal["iou", "diou", "ciou"] = "ciou",
topk: int = 10,
factor: dict[str, float] | None = None,
) -> None:
self.class_num = class_num
self.anchors = anchors
self.iou = iou
self.topk = topk
self.factor = factor or {"iou": 6.0, "cls": 0.5}
def get_valid_matrix(self, target_bbox: Tensor) -> Tensor:
"""Get a boolean mask that indicates whether each target bounding box overlaps with each anchor.
Args:
target_bbox (Tensor): The bounding box of each targets with (batch, targets, 4).
Returns:
Tensor: A boolean tensor indicates if target bounding box overlaps with anchors
with (batch, targets, anchors).
"""
xmin, ymin, xmax, ymax = target_bbox[:, :, None].unbind(3)
anchors = self.anchors[None, None] # add a axis at first, second dimension
anchors_x, anchors_y = anchors.unbind(dim=3)
target_in_x = (xmin < anchors_x) & (anchors_x < xmax)
target_in_y = (ymin < anchors_y) & (anchors_y < ymax)
return target_in_x & target_in_y
def get_cls_matrix(self, predict_cls: Tensor, target_cls: Tensor) -> Tensor:
"""Get the (predicted class' probabilities) corresponding to the target classes across all anchors.
Args:
predict_cls (Tensor): The predicted probabilities for each class across each anchor
with (batch, anchors, class).
target_cls (Tensor): The class index for each target with (batch, targets, 1).
Returns:
Tensor: The probabilities from `pred_cls` corresponding to the class indices
specified in `target_cls` with (batch, targets, anchors).
"""
predict_cls = predict_cls.transpose(1, 2)
target_cls = target_cls.expand(-1, -1, predict_cls.size(2))
return torch.gather(predict_cls, 1, target_cls)
def get_iou_matrix(self, predict_bbox: Tensor, target_bbox: Tensor) -> Tensor:
"""Get the IoU between each target bounding box and each predicted bounding box.
Args:
predict_bbox (Tensor): Bounding box with [x1, y1, x2, y2] with (batch, predicts, 4).
target_bbox (Tensor): Bounding box with [x1, y1, x2, y2] with (batch, targets, 4).
Returns:
Tensor: The IoU scores between each target and predicted with (batch, targets, predicts).
"""
return calculate_iou(target_bbox, predict_bbox, self.iou).clamp(0, 1)
def filter_topk(self, target_matrix: Tensor, topk: int = 10) -> tuple[Tensor, Tensor]:
"""Filter the top-k suitability of targets for each anchor.
Args:
target_matrix (Tensor): The suitability for each targets-anchors with (batch, targets, anchors).
topk (int, optional): Number of top scores to retain per anchor.
Returns:
tuple[Tensor, Tensor]: The top-k suitability for each targets-anchors with (batch, targets, anchors)
and a boolean mask indicating the top-k scores' positions with (batch, targets, anchors).
"""
values, indices = target_matrix.topk(topk, dim=-1)
topk_targets = torch.zeros_like(target_matrix, device=target_matrix.device)
topk_targets.scatter_(dim=-1, index=indices, src=values)
topk_masks = topk_targets > 0
return topk_targets, topk_masks
def filter_duplicates(self, target_matrix: Tensor) -> Tensor:
"""Filter the maximum suitability target index of each anchor.
Args:
target_matrix (Tensor): The suitability for each targets-anchors with (batch, targets, anchors).
Returns:
unique_indices (Tensor): The index of the best targets for each anchors with (batch, anchors, 1).
"""
# TODO (author): add a assert for no target on the image
unique_indices = target_matrix.argmax(dim=1)
return unique_indices[..., None]
def __call__(self, target: Tensor, predict: tuple[Tensor, Tensor]) -> tuple[Tensor, Tensor]:
"""Assign the best suitable ground truth box for each predicted anchor.
1. For each anchor prediction, find the highest suitability targets
2. Select the targets
3. Normalize the class probabilities of targets.
Args:
target (Tensor): The target tensor with class and bounding box with (batch, targets, (class + 4)).
predict (tuple[Tensor, Tensor]): The predicted class and bounding box.
Returns:
tuple[Tensor, Tensor]: The aligned target tensors with (batch, targets, (class + 4)) and (batch, targets).
"""
predict_cls, predict_bbox = predict
# return if target has no gt information.
n_targets = target.shape[1]
if n_targets == 0:
device = predict_bbox.device
align_cls = torch.zeros_like(predict_cls, device=device)
align_bbox = torch.zeros_like(predict_bbox, device=device)
valid_mask = torch.zeros(predict_cls.shape[:2], dtype=bool, device=device)
anchor_matched_targets = torch.cat([align_cls, align_bbox], dim=-1)
return anchor_matched_targets, valid_mask
target_cls, target_bbox = target.split([1, 4], dim=-1) # B x N x (C B) -> B x N x C, B x N x B
target_cls = target_cls.long().clamp(0)
# get valid matrix (each gt appear in which anchor grid)
grid_mask = self.get_valid_matrix(target_bbox)
# get iou matrix (iou with each gt bbox and each predict anchor)
iou_mat = self.get_iou_matrix(predict_bbox, target_bbox)
# get cls matrix (cls prob with each gt class and each predict class)
cls_mat = self.get_cls_matrix(predict_cls.sigmoid(), target_cls)
target_matrix = grid_mask * (iou_mat ** self.factor["iou"]) * (cls_mat ** self.factor["cls"])
# choose topk
topk_targets, topk_mask = self.filter_topk(target_matrix, topk=self.topk)
# delete one anchor pred assign to multiple gts
unique_indices = self.filter_duplicates(topk_targets)
# TODO (author): do we need grid_mask? Filter the valid ground truth
valid_mask = (grid_mask.sum(dim=-2) * topk_mask.sum(dim=-2)).bool()
align_bbox = torch.gather(target_bbox, 1, unique_indices.repeat(1, 1, 4))
align_cls = torch.gather(target_cls, 1, unique_indices).squeeze(-1)
align_cls = nn.functional.one_hot(align_cls, self.class_num)
# normalize class distribution
max_target = target_matrix.amax(dim=-1, keepdim=True)
max_iou = iou_mat.amax(dim=-1, keepdim=True)
normalize_term = (target_matrix / (max_target + 1e-9)) * max_iou
normalize_term = normalize_term.permute(0, 2, 1).gather(2, unique_indices)
align_cls = align_cls * normalize_term * valid_mask[:, :, None]
anchor_matched_targets = torch.cat([align_cls, align_bbox], dim=-1)
return anchor_matched_targets, valid_mask
[docs]
class YOLOv9Criterion(nn.Module):
"""YOLOv9 criterion module.
This module calculates the loss for YOLOv9 object detection model.
Args:
num_classes (int): The number of classes.
vec2box (Vec2Box): The Vec2Box object.
loss_cls (nn.Module | None): The classification loss module. Defaults to None.
loss_dfl (nn.Module | None): The DFLoss loss module. Defaults to None.
loss_iou (nn.Module | None): The IoULoss loss module. Defaults to None.
reg_max (int, optional): Maximum number of anchor regions. Defaults to 16.
cls_rate (float, optional): The classification loss rate. Defaults to 1.5.
dfl_rate (float, optional): The DFLoss loss rate. Defaults to 7.5.
iou_rate (float, optional): The IoU loss rate. Defaults to 0.5.
aux_rate (float, optional): The auxiliary loss rate. Defaults to 0.25.
"""
def __init__(
self,
num_classes: int,
vec2box: Vec2Box,
loss_cls: nn.Module | None = None,
loss_dfl: nn.Module | None = None,
loss_iou: nn.Module | None = None,
reg_max: int = 16,
cls_rate: float = 0.5,
dfl_rate: float = 1.5,
iou_rate: float = 7.5,
aux_rate: float = 0.25,
) -> None:
super().__init__()
self.num_classes = num_classes
self.loss_cls = loss_cls or BCELoss()
self.loss_dfl = loss_dfl or DFLoss(vec2box, reg_max)
self.loss_iou = loss_iou or BoxLoss()
self.vec2box = vec2box
self.matcher = BoxMatcher(num_classes, vec2box.anchor_grid)
self.cls_rate = cls_rate
self.dfl_rate = dfl_rate
self.iou_rate = iou_rate
self.aux_rate = aux_rate
[docs]
def forward(
self,
main_preds: tuple[Tensor, Tensor, Tensor],
targets: Tensor,
aux_preds: tuple[Tensor, Tensor, Tensor] | None = None,
) -> dict[str, Tensor] | None:
"""Forward pass of the YOLOv9 criterion module.
Args:
main_preds (tuple[Tensor, Tensor, Tensor]): The main predictions.
targets (Tensor): The learning target of the prediction.
aux_preds (tuple[Tensor, Tensor, Tensor], optional): The auxiliary predictions. Defaults to None.
Returns:
dict[str, Tensor]: The loss dictionary.
"""
if targets.shape[1] == 0:
# TODO (sungchul): should this step be done here?
return None
main_preds = self.vec2box(main_preds)
main_iou, main_dfl, main_cls = self._forward(main_preds, targets)
aux_iou, aux_dfl, aux_cls = 0.0, 0.0, 0.0
if aux_preds:
aux_preds = self.vec2box(aux_preds)
aux_iou, aux_dfl, aux_cls = self._forward(aux_preds, targets)
loss_dict = {
"loss_cls": self.cls_rate * (aux_cls * self.aux_rate + main_cls),
"loss_df": self.dfl_rate * (aux_dfl * self.aux_rate + main_dfl),
"loss_iou": self.iou_rate * (aux_iou * self.aux_rate + main_iou),
}
loss_dict.update(total_loss=sum(list(loss_dict.values())) / len(loss_dict))
return loss_dict
def _forward(self, predicts: tuple[Tensor, Tensor, Tensor], targets: Tensor) -> tuple[Tensor, Tensor, Tensor]:
predicts_cls, predicts_anc, predicts_box = predicts
# For each predicted targets, assign a best suitable ground truth box.
align_targets, valid_masks = self.matcher(targets, (predicts_cls.detach(), predicts_box.detach()))
targets_cls, targets_bbox = self.separate_anchor(align_targets)
predicts_box = predicts_box / self.vec2box.scaler[None, :, None]
cls_norm = targets_cls.sum()
box_norm = targets_cls.sum(-1)[valid_masks]
## -- CLS -- ##
loss_cls = self.loss_cls(predicts_cls, targets_cls, cls_norm)
## -- IOU -- ##
loss_iou = self.loss_iou(predicts_box, targets_bbox, valid_masks, box_norm, cls_norm)
## -- DFL -- ##
loss_dfl = self.loss_dfl(predicts_anc, targets_bbox, valid_masks, box_norm, cls_norm)
return loss_iou, loss_dfl, loss_cls
[docs]
def separate_anchor(self, anchors: Tensor) -> tuple[Tensor, Tensor]:
"""Separate anchor and bounding box.
Args:
anchors (Tensor): The anchor tensor.
"""
anchors_cls, anchors_box = torch.split(anchors, (self.num_classes, 4), dim=-1)
anchors_box = anchors_box / self.vec2box.scaler[None, :, None]
return anchors_cls, anchors_box