# Copyright (C) 2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
"""D-FINE criterion implementations. Modified from https://github.com/Peterande/D-FINE."""
from __future__ import annotations
from typing import Callable
import torch
import torch.distributed
import torch.nn.functional as f
from torch import Tensor, nn
from torchvision.ops import box_convert
from otx.algo.common.utils.assigners.hungarian_matcher import HungarianMatcher
from otx.algo.common.utils.bbox_overlaps import bbox_overlaps
from otx.algo.detection.utils.utils import dfine_bbox2distance
class DFINECriterion(nn.Module):
    """D-FINE criterion with FGL and DDF losses.

    TODO(Eugene): Consider merging with RTDETRCriterion in the next PR.

    The process happens in two steps:
        1) we compute the Hungarian assignment between ground truth boxes and the outputs of the model;
        2) we supervise each pair of matched ground-truth / prediction (supervise class and box).

    Args:
        weight_dict (dict[str, int | float]): A dictionary containing the weights for different loss components.
        alpha (float, optional): The alpha parameter for the loss calculation. Defaults to 0.2.
        gamma (float, optional): The gamma parameter for the loss calculation. Defaults to 2.0.
        num_classes (int, optional): The number of classes. Defaults to 80.
        reg_max (int, optional): The maximum number of bin targets. Defaults to 32.
    """
    def __init__(
        self,
        weight_dict: dict[str, int | float],
        alpha: float = 0.2,
        gamma: float = 2.0,
        num_classes: int = 80,
        reg_max: int = 32,
    ):
        super().__init__()
        self.num_classes = num_classes
        self.matcher = HungarianMatcher(
            cost_dict={
                "cost_class": 2.0,
                "cost_bbox": 5.0,
                "cost_giou": 2.0,
            },
        )
        self.weight_dict = weight_dict
        self.alpha = alpha
        self.gamma = gamma
        self.reg_max = reg_max
        self.num_pos, self.num_neg = 0.0, 0.0
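
    # Note: the Hungarian matching costs above are fixed here, while the final
    # loss weighting is controlled separately through `weight_dict`.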
    def loss_labels_vfl(
        self,
        outputs: dict[str, Tensor],
        targets: list[dict[str, Tensor]],
        indices: list[tuple[Tensor, Tensor]],
        num_boxes: int,
    ) -> dict[str, Tensor]:
        """Varifocal Loss (VFL) for label prediction.

        Args:
            outputs (dict[str, Tensor]): Model outputs.
            targets (list[dict[str, Tensor]]): List of target dictionaries.
            indices (list[tuple[Tensor, Tensor]]): List of tuples of matched indices.
            num_boxes (int): Number of predicted boxes.

        Returns:
            dict[str, Tensor]: The loss dictionary.
        """
        idx = self._get_src_permutation_idx(indices)
        src_boxes = outputs["pred_boxes"][idx]
        target_boxes = torch.cat([t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0)
        # IoU between each matched prediction and its target serves as the soft classification target.
        ious = bbox_overlaps(
            box_convert(src_boxes, in_fmt="cxcywh", out_fmt="xyxy"),
            box_convert(target_boxes, in_fmt="cxcywh", out_fmt="xyxy"),
        )
        ious = torch.diag(ious).detach()

        src_logits = outputs["pred_logits"]
        target_classes_o = torch.cat([t["labels"][j] for t, (_, j) in zip(targets, indices)])
        target_classes = torch.full(src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device)
        target_classes[idx] = target_classes_o
        target = f.one_hot(target_classes, num_classes=self.num_classes + 1)[..., :-1]

        target_score_o = torch.zeros_like(target_classes, dtype=src_logits.dtype)
        target_score_o[idx] = ious.to(target_score_o.dtype)
        target_score = target_score_o.unsqueeze(-1) * target

        # Focal-style down-weighting for negatives; IoU-based target score for positives.
        pred_score = f.sigmoid(src_logits).detach()
        weight = self.alpha * pred_score.pow(self.gamma) * (1 - target) + target_score

        loss = f.binary_cross_entropy_with_logits(src_logits, target_score, weight=weight, reduction="none")
        loss = loss.mean(1).sum() * src_logits.shape[1] / num_boxes
        return {"loss_vfl": loss}
    def loss_boxes(
        self,
        outputs: dict[str, Tensor],
        targets: list[dict[str, Tensor]],
        indices: list[tuple[Tensor, Tensor]],
        num_boxes: int,
    ) -> dict[str, Tensor]:
        """Compute the losses related to the bounding boxes: the L1 regression loss and the GIoU loss.

        Targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4].
        The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size.

        Args:
            outputs (dict[str, Tensor]): The outputs of the model.
            targets (list[dict[str, Tensor]]): The targets.
            indices (list[tuple[Tensor, Tensor]]): The indices of the matched boxes.
            num_boxes (int): The number of boxes.

        Returns:
            dict[str, Tensor]: The losses.
        """
        idx = self._get_src_permutation_idx(indices)
        src_boxes = outputs["pred_boxes"][idx]
        target_boxes = torch.cat([t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0)

        losses = {}
        loss_bbox = f.l1_loss(src_boxes, target_boxes, reduction="none")
        losses["loss_bbox"] = loss_bbox.sum() / num_boxes

        loss_giou = 1 - torch.diag(
            bbox_overlaps(
                box_convert(src_boxes, in_fmt="cxcywh", out_fmt="xyxy"),
                box_convert(target_boxes, in_fmt="cxcywh", out_fmt="xyxy"),
                mode="giou",
            ),
        )
        losses["loss_giou"] = loss_giou.sum() / num_boxes
        return losses
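
    # Note: the L1 term is computed directly on normalized (cx, cy, w, h)
    # coordinates, while the GIoU term is computed after converting to
    # (x1, y1, x2, y2); both are averaged over `num_boxes`.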
    def loss_local(
        self,
        outputs: dict[str, Tensor],
        targets: list[dict[str, Tensor]],
        indices: list[tuple[Tensor, Tensor]],
        num_boxes: int,
        temperature: int = 5,
    ) -> dict[str, Tensor]:
        """Compute Fine-Grained Localization (FGL) Loss and Decoupled Distillation Focal (DDF) Loss.

        Args:
            outputs (dict[str, Tensor]): The outputs of the model.
            targets (list[dict[str, Tensor]]): The targets.
            indices (list[tuple[Tensor, Tensor]]): The indices of the matched boxes.
            num_boxes (int): The number of boxes.
            temperature (int, optional): Temperature for distillation. Defaults to 5.

        Returns:
            dict[str, Tensor]: FGL and DDF losses.
        """
        losses = {}
        if "pred_corners" in outputs:
            idx = self._get_src_permutation_idx(indices)
            target_boxes = torch.cat(
                [t["boxes"][i] for t, (_, i) in zip(targets, indices)],
                dim=0,
            )

            pred_corners = outputs["pred_corners"][idx].reshape(-1, (self.reg_max + 1))
            ref_points = outputs["ref_points"][idx].detach()
            with torch.no_grad():
                # Discretize each target edge distance into bin targets plus interpolation weights.
                target_corners, weight_right, weight_left = dfine_bbox2distance(
                    ref_points,
                    box_convert(target_boxes, in_fmt="cxcywh", out_fmt="xyxy"),
                    self.reg_max,
                    outputs["reg_scale"],
                    outputs["up"],
                )

            ious = torch.diag(
                bbox_overlaps(
                    box_convert(outputs["pred_boxes"][idx], in_fmt="cxcywh", out_fmt="xyxy"),
                    box_convert(target_boxes, in_fmt="cxcywh", out_fmt="xyxy"),
                ),
            )
            weight_targets = ious.unsqueeze(-1).repeat(1, 1, 4).reshape(-1).detach()

            losses["loss_fgl"] = DFINECriterion.fgl_loss(
                pred_corners,
                target_corners,
                weight_right,
                weight_left,
                weight_targets,
                avg_factor=num_boxes,
            )

            # Compute Decoupled Distillation Focal (DDF) Loss
            if "teacher_corners" in outputs and outputs["teacher_corners"] is not None:
                pred_corners = outputs["pred_corners"].reshape(-1, (self.reg_max + 1))
                target_corners = outputs["teacher_corners"].reshape(-1, (self.reg_max + 1))
                if torch.equal(pred_corners, target_corners):
                    # Student and teacher coincide: emit a zero loss that keeps the graph connected.
                    losses["loss_ddf"] = pred_corners.sum() * 0
                else:
                    weight_targets_local = outputs["teacher_logits"].sigmoid().max(dim=-1)[0]

                    mask = torch.zeros_like(weight_targets_local, dtype=torch.bool)
                    mask[idx] = True
                    mask = mask.unsqueeze(-1).repeat(1, 1, 4).reshape(-1)

                    weight_targets_local[idx] = ious.reshape_as(weight_targets_local[idx]).to(
                        weight_targets_local.dtype,
                    )
                    weight_targets_local = weight_targets_local.unsqueeze(-1).repeat(1, 1, 4).reshape(-1).detach()

                    # Temperature-scaled KL divergence between teacher and student corner distributions.
                    loss_match_local = (
                        weight_targets_local
                        * (temperature**2)
                        * (
                            nn.KLDivLoss(reduction="none")(
                                f.log_softmax(pred_corners / temperature, dim=1),
                                f.softmax(target_corners.detach() / temperature, dim=1),
                            )
                        ).sum(-1)
                    )

                    if "is_dn" not in outputs:
                        batch_scale = 8 / outputs["pred_boxes"].shape[0]  # Avoid the influence of batch size per GPU
                        self.num_pos, self.num_neg = (
                            (mask.sum() * batch_scale) ** 0.5,
                            ((~mask).sum() * batch_scale) ** 0.5,
                        )
                    loss_match_local1 = loss_match_local[mask].mean() if mask.any() else 0
                    loss_match_local2 = loss_match_local[~mask].mean() if (~mask).any() else 0
                    losses["loss_ddf"] = (loss_match_local1 * self.num_pos + loss_match_local2 * self.num_neg) / (
                        self.num_pos + self.num_neg
                    )
        return losses
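
    # A worked sketch of the two losses above (restating the code; bin
    # semantics are assumed from `dfine_bbox2distance`):
    #   FGL: each target edge distance t falls between two adjacent bins
    #        (e.g. t = 3.4 between bins 3 and 4); cross-entropy is taken against
    #        both bins, weighted by their proximity to t and by the matched IoU.
    #   DDF: KL divergence between teacher and student (reg_max + 1)-bin corner
    #        distributions at temperature T, with matched locations re-weighted
    #        by IoU and the positive/negative parts balanced by num_pos/num_neg.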
    def _get_src_permutation_idx(
        self,
        indices: list[tuple[Tensor, Tensor]],
    ) -> tuple[Tensor, Tensor]:
        # Permute predictions following indices.
        batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
        src_idx = torch.cat([src for (src, _) in indices])
        return batch_idx, src_idx
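
    # Example (illustrative): with indices = [(tensor([2, 5]), tensor([0, 1]))]
    # for a single image, this returns (tensor([0, 0]), tensor([2, 5])), i.e.
    # (batch index, query index) pairs suitable for advanced indexing such as
    # outputs["pred_boxes"][idx].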
    def _get_go_indices(
        self,
        indices: list[tuple[Tensor, Tensor]],
        indices_aux_list: list[list[tuple[Tensor, Tensor]]],
    ) -> list[tuple[Tensor, Tensor]]:
        """Get a matching union set across all decoder layers.

        Args:
            indices: matching indices of the last decoder layer.
            indices_aux_list: matching indices of all decoder layers.
        """
        results = []
        # Concatenate the matched (query, gt) pairs of every decoder layer per image.
        for indices_aux in indices_aux_list:
            indices = [
                (torch.cat([idx1[0], idx2[0]]), torch.cat([idx1[1], idx2[1]]))
                for idx1, idx2 in zip(indices.copy(), indices_aux.copy())
            ]

        for ind in [torch.cat([idx[0][:, None], idx[1][:, None]], 1) for idx in indices]:
            unique, counts = torch.unique(ind, return_counts=True, dim=0)
            count_sort_indices = torch.argsort(counts, descending=True)
            unique_sorted = unique[count_sort_indices]
            # Keep one gt per query, preferring the pair matched most often across layers.
            column_to_row = {}
            for idx in unique_sorted:
                row_idx, col_idx = idx[0].item(), idx[1].item()
                if row_idx not in column_to_row:
                    column_to_row[row_idx] = col_idx
            final_rows = torch.tensor(list(column_to_row.keys()), device=ind.device)
            final_cols = torch.tensor(list(column_to_row.values()), device=ind.device)
            results.append((final_rows.long(), final_cols.long()))
        return results
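
    # Example (illustrative): if query 7 matched gt 1 in four decoder layers and
    # gt 2 in one layer, the pair (7, 1) has the higher count and is the one
    # kept, so the union set assigns each query its most frequent ground truth.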
    @property
    def _available_losses(self) -> tuple[Callable, ...]:
        return (self.loss_boxes, self.loss_labels_vfl, self.loss_local)  # type: ignore[return-value]
    def forward(
        self,
        outputs: dict[str, Tensor],
        targets: list[dict[str, Tensor]],
    ) -> dict[str, Tensor]:
        """Perform the loss computation.

        Args:
            outputs (dict[str, torch.Tensor]): dict of tensors, see the output
                specification of the model for the format.
            targets (list[dict[str, torch.Tensor]]): list of dicts, such that len(targets) == batch_size.
                The expected keys in each dict depend on the losses applied, see each loss' doc.

        Returns:
            dict[str, torch.Tensor]: dict of losses.
        """
        outputs_without_aux = {k: v for k, v in outputs.items() if "aux" not in k}

        # Retrieve the matching between the outputs of the last layer and the targets.
        indices = self.matcher(outputs_without_aux, targets)

        # Get the matching union set across all decoder layers.
        indices_aux_list, cached_indices, cached_indices_enc = [], [], []
        for aux_outputs in outputs["aux_outputs"] + [outputs["pre_outputs"]]:
            indices_aux = self.matcher(aux_outputs, targets)
            cached_indices.append(indices_aux)
            indices_aux_list.append(indices_aux)
        for aux_outputs in outputs["enc_aux_outputs"]:
            indices_enc = self.matcher(aux_outputs, targets)
            cached_indices_enc.append(indices_enc)
            indices_aux_list.append(indices_enc)
        indices_go = self._get_go_indices(indices, indices_aux_list)

        num_boxes_go = sum(len(x[0]) for x in indices_go)
        num_boxes_go = torch.as_tensor(
            [num_boxes_go],
            dtype=torch.float,
            device=next(iter(outputs.values())).device,
        )
        num_boxes_go = torch.clamp(num_boxes_go, min=1).item()

        # Compute the number of target boxes, for normalization purposes.
        num_boxes = sum(len(t["labels"]) for t in targets)
        num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
        num_boxes = torch.clamp(num_boxes, min=1).item()

        # Compute all the requested losses.
        losses = {}
        for loss in self._available_losses:
            indices_in = indices_go if loss in [self.loss_boxes, self.loss_local] else indices
            num_boxes_in = num_boxes_go if loss in [self.loss_boxes, self.loss_local] else num_boxes
            l_dict = loss(outputs, targets, indices_in, num_boxes_in)
            l_dict = {k: l_dict[k] * self.weight_dict[k] for k in l_dict if k in self.weight_dict}
            losses.update(l_dict)

        # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
        if "aux_outputs" in outputs:
            for i, aux_outputs in enumerate(outputs["aux_outputs"]):
                aux_outputs["up"], aux_outputs["reg_scale"] = outputs["up"], outputs["reg_scale"]
                for loss in self._available_losses:
                    if loss in [self.loss_boxes, self.loss_local]:
                        indices_in = indices_go
                        num_boxes_in = num_boxes_go
                    else:
                        indices_in = cached_indices[i]
                        num_boxes_in = num_boxes
                    l_dict = loss(aux_outputs, targets, indices_in, num_boxes_in)
                    l_dict = {k: l_dict[k] * self.weight_dict[k] for k in l_dict if k in self.weight_dict}
                    l_dict = {k + f"_aux_{i}": v for k, v in l_dict.items()}
                    losses.update(l_dict)

        # In case of the auxiliary traditional head output at the first decoder layer.
        if "pre_outputs" in outputs:
            aux_outputs = outputs["pre_outputs"]
            for loss in self._available_losses:
                if loss in [self.loss_boxes, self.loss_local]:
                    indices_in = indices_go
                    num_boxes_in = num_boxes_go
                else:
                    indices_in = cached_indices[-1]
                    num_boxes_in = num_boxes
                l_dict = loss(aux_outputs, targets, indices_in, num_boxes_in)
                l_dict = {k: l_dict[k] * self.weight_dict[k] for k in l_dict if k in self.weight_dict}
                l_dict = {k + "_pre": v for k, v in l_dict.items()}
                losses.update(l_dict)

        # In case of encoder auxiliary losses.
        if "enc_aux_outputs" in outputs:
            enc_targets = targets
            for i, aux_outputs in enumerate(outputs["enc_aux_outputs"]):
                for loss in self._available_losses:
                    if loss == self.loss_boxes:
                        indices_in = indices_go
                        num_boxes_in = num_boxes_go
                    else:
                        indices_in = cached_indices_enc[i]
                        num_boxes_in = num_boxes
                    l_dict = loss(aux_outputs, enc_targets, indices_in, num_boxes_in)
                    l_dict = {k: l_dict[k] * self.weight_dict[k] for k in l_dict if k in self.weight_dict}
                    l_dict = {k + f"_enc_{i}": v for k, v in l_dict.items()}
                    losses.update(l_dict)

        # In case of CDN auxiliary losses (contrastive denoising, D-FINE).
        if "dn_outputs" in outputs:
            indices_dn = self.get_cdn_matched_indices(outputs["dn_meta"], targets)
            dn_num_boxes = num_boxes * outputs["dn_meta"]["dn_num_group"]
            dn_num_boxes = dn_num_boxes if dn_num_boxes > 0 else 1
            for i, aux_outputs in enumerate(outputs["dn_outputs"]):
                aux_outputs["is_dn"] = True
                aux_outputs["up"], aux_outputs["reg_scale"] = outputs["up"], outputs["reg_scale"]
                for loss in self._available_losses:
                    l_dict = loss(aux_outputs, targets, indices_dn, dn_num_boxes)
                    l_dict = {k: l_dict[k] * self.weight_dict[k] for k in l_dict if k in self.weight_dict}
                    l_dict = {k + f"_dn_{i}": v for k, v in l_dict.items()}
                    losses.update(l_dict)

            # In case of the auxiliary traditional head output at the first decoder layer.
            if "dn_pre_outputs" in outputs:
                aux_outputs = outputs["dn_pre_outputs"]
                for loss in self._available_losses:
                    l_dict = loss(aux_outputs, targets, indices_dn, dn_num_boxes)
                    l_dict = {k: l_dict[k] * self.weight_dict[k] for k in l_dict if k in self.weight_dict}
                    l_dict = {k + "_dn_pre": v for k, v in l_dict.items()}
                    losses.update(l_dict)
        return losses
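
    # Key naming in the returned dict (derived from the suffixes above): base
    # losses keep their plain names ("loss_vfl", "loss_bbox", ...), while
    # auxiliary heads append "_aux_{i}", "_pre", "_enc_{i}", "_dn_{i}", or
    # "_dn_pre"; every value is pre-multiplied by its `weight_dict` entry.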
    @staticmethod
    def get_cdn_matched_indices(
        dn_meta: dict[str, list[Tensor]],
        targets: list[dict[str, Tensor]],
    ) -> list[tuple[torch.Tensor, torch.Tensor]]:
        """Get the matched indices for contrastive denoising (CDN) queries.

        Args:
            dn_meta (dict[str, list[torch.Tensor]]): meta data for CDN.
            targets (list[dict[str, torch.Tensor]]): targets.
        """
        dn_positive_idx, dn_num_group = dn_meta["dn_positive_idx"], dn_meta["dn_num_group"]
        num_gts = [len(t["labels"]) for t in targets]
        device = targets[0]["labels"].device

        dn_match_indices = []
        for i, num_gt in enumerate(num_gts):
            if num_gt > 0:
                gt_idx = torch.arange(num_gt, dtype=torch.int64, device=device)
                gt_idx = gt_idx.tile(dn_num_group)
                if len(dn_positive_idx[i]) != len(gt_idx):
                    msg = "The number of positive indices should be equal to the number of ground truths."
                    raise ValueError(msg)
                dn_match_indices.append((dn_positive_idx[i], gt_idx))
            else:
                dn_match_indices.append(
                    (
                        torch.zeros(0, dtype=torch.int64, device=device),
                        torch.zeros(0, dtype=torch.int64, device=device),
                    ),
                )
        return dn_match_indices
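
    # Example (illustrative): with num_gt = 2 and dn_num_group = 3, gt_idx
    # becomes tensor([0, 1, 0, 1, 0, 1]), i.e. every denoising group repeats
    # the full set of ground-truth indices once.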
    @staticmethod
    def fgl_loss(
        preds: Tensor,
        targets: Tensor,
        weight_right: Tensor,
        weight_left: Tensor,
        iou_weight: Tensor | None = None,
        reduction: str = "sum",
        avg_factor: float | None = None,
    ) -> Tensor:
        """Fine-Grained Localization (FGL) Loss.

        Args:
            preds (Tensor): predicted distances.
            targets (Tensor): target distances.
            weight_right (Tensor): weight for the right distance bin.
            weight_left (Tensor): weight for the left distance bin.
            iou_weight (Tensor, optional): IoU weight. Defaults to None.
            reduction (str, optional): reduction method. Defaults to "sum".
            avg_factor (float, optional): average factor. Defaults to None.

        Returns:
            Tensor: FGL loss.
        """
        # Each continuous target distance is supervised through its two adjacent integer bins.
        dis_left = targets.long()
        dis_right = dis_left + 1

        loss_left = f.cross_entropy(
            preds,
            dis_left,
            reduction="none",
        ) * weight_left.reshape(-1)
        loss_right = f.cross_entropy(
            preds,
            dis_right,
            reduction="none",
        ) * weight_right.reshape(-1)

        loss = loss_left + loss_right
        if iou_weight is not None:
            iou_weight = iou_weight.float()
            loss = loss * iou_weight

        if avg_factor is not None:
            loss = loss.sum() / avg_factor
        elif reduction == "mean":
            loss = loss.mean()
        elif reduction == "sum":
            loss = loss.sum()
        return loss
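
    # For reference, this is the distribution-focal-loss form (assuming
    # `weight_left`/`weight_right` carry the usual linear interpolation
    # weights): for a target t with l = floor(t),
    #   loss = CE(pred, l) * w_left + CE(pred, l + 1) * w_right,
    # optionally scaled by the matched IoU and averaged over `avg_factor`.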