# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
"""main loss for MonoDETR model."""
from __future__ import annotations

from typing import TYPE_CHECKING, Callable

import torch
from torch import nn
from torch.nn import functional
from torchvision.ops import box_convert

from otx.algo.common.losses.focal_loss import py_sigmoid_focal_loss
from otx.algo.common.losses.iou_loss import giou_loss
from otx.algo.object_detection_3d.matchers.matcher_3d import HungarianMatcher3D
from otx.algo.object_detection_3d.utils.utils import box_cxcylrtb_to_xyxy

from .ddn_loss import DDNLoss

if TYPE_CHECKING:
    from torch import Tensor


class MonoDETRCriterion(nn.Module):
    """This class computes the loss for MonoDETR."""

    def __init__(self, num_classes: int, weight_dict: dict, focal_alpha: float, group_num: int = 11) -> None:
        """Initialize MonoDETRCriterion.

        Args:
            num_classes (int): number of object categories, omitting the special no-object category.
            weight_dict (dict): dict containing as key the names of the losses and as values their relative weight.
            focal_alpha (float): alpha in Focal Loss.
            group_num (int): number of groups for data parallelism.
        """
        super().__init__()
        self.num_classes = num_classes
        self.matcher = HungarianMatcher3D(cost_class=2, cost_3dcenter=10, cost_bbox=5, cost_giou=2)
        self.weight_dict = weight_dict
        # Default every known loss to a weight of 1 unless the caller overrides it.
        for name in self.loss_map:
            if name not in self.weight_dict:
                self.weight_dict[name] = 1
        self.focal_alpha = focal_alpha
        self.ddn_loss = DDNLoss()  # for depth map
        self.group_num = group_num

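    # Hypothetical construction sketch: any loss name missing from ``weight_dict``
    # gets a default weight of 1 in the loop above, so partial dicts are fine:
    #
    #   criterion = MonoDETRCriterion(num_classes=3, weight_dict={"loss_ce": 2}, focal_alpha=0.25)
    #   # -> loss_ce weighted 2, every other loss weighted 1
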
    def loss_labels(self, outputs: dict, targets: list, indices: list, num_boxes: int) -> dict[str, Tensor]:
        """Classification loss.

        Args:
            outputs (dict): dict of tensors, see the output specification of the model for the format.
            targets (list): list of dicts, such that len(targets) == batch_size.
                The expected keys in each dict depend on the losses applied; see each loss' doc.
            indices (list): list of tuples, such that len(indices) == batch_size.
            num_boxes (int): number of boxes in the batch.
        """
        src_logits = outputs["scores"]

        idx = self._get_src_permutation_idx(indices)
        target_classes_o = torch.cat([t["labels"][j] for t, (_, j) in zip(targets, indices)])
        target_classes = torch.full(src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device)
        target_classes[idx] = target_classes_o.squeeze().long()

        target_classes_onehot = torch.zeros(
            [src_logits.shape[0], src_logits.shape[1], src_logits.shape[2] + 1],
            dtype=src_logits.dtype,
            layout=src_logits.layout,
            device=src_logits.device,
        )
        target_classes_onehot.scatter_(2, target_classes.unsqueeze(-1), 1)
        target_classes_onehot = target_classes_onehot[:, :, :-1]

        loss_ce = py_sigmoid_focal_loss(
            pred=src_logits,
            target=target_classes_onehot,
            avg_factor=num_boxes,
            alpha=self.focal_alpha,
            reduction="mean",
        )
        return {"loss_ce": loss_ce}

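    # The one-hot construction above scatters into ``num_classes + 1`` channels so
    # that unmatched queries (filled with label ``num_classes``) land in the extra
    # slot, which is then sliced off to leave them an all-zero focal-loss target.
    # A small sketch (hypothetical sizes: 1 image, 2 queries, 3 classes):
    #
    #   >>> import torch
    #   >>> tc = torch.tensor([[1, 3]])  # query 0 -> class 1, query 1 -> no object
    #   >>> torch.zeros(1, 2, 4).scatter_(2, tc.unsqueeze(-1), 1)[:, :, :-1]
    #   tensor([[[0., 1., 0.],
    #            [0., 0., 0.]]])
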
    def loss_3dcenter(self, outputs: dict, targets: list, indices: list, num_boxes: int) -> dict[str, Tensor]:
        """Compute the loss for the 3D center prediction.

        Args:
            outputs (dict): dict of tensors, see the output specification of the model for the format.
            targets (list): list of dicts, such that len(targets) == batch_size.
                The expected keys in each dict depend on the losses applied; see each loss' doc.
            indices (list): list of tuples, such that len(indices) == batch_size.
            num_boxes (int): number of boxes in the batch.
        """
        idx = self._get_src_permutation_idx(indices)
        src_3dcenter = outputs["boxes_3d"][:, :, 0:2][idx]
        target_3dcenter = torch.cat([t["boxes_3d"][:, 0:2][i] for t, (_, i) in zip(targets, indices)], dim=0)

        loss_3dcenter = functional.l1_loss(src_3dcenter, target_3dcenter, reduction="none")
        return {"loss_center": loss_3dcenter.sum() / num_boxes}

    def loss_boxes(self, outputs: dict, targets: list, indices: list, num_boxes: int) -> dict[str, Tensor]:
        """Compute the L1 box regression loss.

        Args:
            outputs (dict): dict of tensors, see the output specification of the model for the format.
            targets (list): list of dicts, such that len(targets) == batch_size.
                The expected keys in each dict depend on the losses applied; see each loss' doc.
            indices (list): list of tuples, such that len(indices) == batch_size.
            num_boxes (int): number of boxes in the batch.
        """
        idx = self._get_src_permutation_idx(indices)
        src_2dboxes = outputs["boxes_3d"][:, :, 2:6][idx]
        target_2dboxes = torch.cat([t["boxes_3d"][:, 2:6][i] for t, (_, i) in zip(targets, indices)], dim=0)

        # L1 over the box edge components (entries 2:6 of the cxcylrtb parameterization)
        loss_bbox = functional.l1_loss(src_2dboxes, target_2dboxes, reduction="none")
        return {"loss_bbox": loss_bbox.sum() / num_boxes}

    def loss_giou(self, outputs: dict, targets: list, indices: list, num_boxes: int) -> dict[str, Tensor]:
        """Compute the GIoU loss.

        Args:
            outputs (dict): dict of tensors, see the output specification of the model for the format.
            targets (list): list of dicts, such that len(targets) == batch_size.
                The expected keys in each dict depend on the losses applied; see each loss' doc.
            indices (list): list of tuples, such that len(indices) == batch_size.
            num_boxes (int): number of boxes in the batch.
        """
        idx = self._get_src_permutation_idx(indices)
        src_boxes = outputs["boxes_3d"][idx]
        target_boxes = torch.cat([t["boxes_3d"][i] for t, (_, i) in zip(targets, indices)], dim=0)
        # Convert cxcylrtb boxes to xyxy before computing GIoU.
        loss_giou = giou_loss(box_cxcylrtb_to_xyxy(src_boxes), box_cxcylrtb_to_xyxy(target_boxes))
        return {"loss_giou": loss_giou}

    def loss_depths(self, outputs: dict, targets: list, indices: list, num_boxes: int) -> dict[str, Tensor]:
        """Compute the loss for the depth prediction.

        Args:
            outputs (dict): dict of tensors, see the output specification of the model for the format.
            targets (list): list of dicts, such that len(targets) == batch_size.
                The expected keys in each dict depend on the losses applied; see each loss' doc.
            indices (list): list of tuples, such that len(indices) == batch_size.
            num_boxes (int): number of boxes in the batch.
        """
        idx = self._get_src_permutation_idx(indices)

        src_depths = outputs["depth"][idx]
        target_depths = torch.cat([t["depth"][i] for t, (_, i) in zip(targets, indices)], dim=0).squeeze()

        depth_input, depth_log_variance = src_depths[:, 0], src_depths[:, 1]
        # Laplacian aleatoric uncertainty loss; 1.4142 ~= sqrt(2).
        depth_loss = 1.4142 * torch.exp(-depth_log_variance) * torch.abs(depth_input - target_depths) + torch.abs(
            depth_log_variance,
        )
        return {"loss_depth": depth_loss.sum() / num_boxes}

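    # The term above is a Laplacian aleatoric-uncertainty loss: the network
    # predicts a log-variance s alongside the depth, sqrt(2) * exp(-s) attenuates
    # the L1 error where the model is uncertain, and |s| keeps the predicted
    # uncertainty from growing without bound. A worked scalar example
    # (hypothetical values):
    #
    #   1.4142 * exp(-0.5) * |9.0 - 10.0| + |0.5| ~= 0.858 + 0.5 = 1.358
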
    def loss_dims(self, outputs: dict, targets: list, indices: list, num_boxes: int) -> dict[str, Tensor]:
        """Compute the loss for the dimension prediction.

        Args:
            outputs (dict): dict of tensors, see the output specification of the model for the format.
            targets (list): list of dicts, such that len(targets) == batch_size.
                The expected keys in each dict depend on the losses applied; see each loss' doc.
            indices (list): list of tuples, such that len(indices) == batch_size.
            num_boxes (int): number of boxes in the batch.
        """
        idx = self._get_src_permutation_idx(indices)
        src_dims = outputs["size_3d"][idx]
        target_dims = torch.cat([t["size_3d"][i] for t, (_, i) in zip(targets, indices)], dim=0)

        dimension = target_dims.clone().detach()
        dim_loss = torch.abs(src_dims - target_dims)
        dim_loss /= dimension
        with torch.no_grad():
            compensation_weight = functional.l1_loss(src_dims, target_dims) / dim_loss.mean()
        dim_loss *= compensation_weight
        return {"loss_dim": dim_loss.sum() / num_boxes}

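    # The division by the detached ground-truth size makes this a relative error,
    # so the same absolute error costs more on a small object than on a large one.
    # The compensation weight (computed without gradients) then rescales the
    # relative loss so its mean matches the plain L1 loss, keeping its magnitude
    # comparable to the other loss terms.
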
    def loss_angles(self, outputs: dict, targets: list, indices: list, num_boxes: int) -> dict[str, Tensor]:
        """Compute the loss for the angle prediction.

        Args:
            outputs (dict): dict of tensors, see the output specification of the model for the format.
            targets (list): list of dicts, such that len(targets) == batch_size.
                The expected keys in each dict depend on the losses applied; see each loss' doc.
            indices (list): list of tuples, such that len(indices) == batch_size.
            num_boxes (int): number of boxes in the batch.
        """
        idx = self._get_src_permutation_idx(indices)
        heading_input = outputs["heading_angle"][idx]
        target_heading_angle = torch.cat([t["heading_angle"][i] for t, (_, i) in zip(targets, indices)], dim=0)
        heading_target_cls = target_heading_angle[:, 0].view(-1).long()
        heading_target_res = target_heading_angle[:, 1].view(-1)

        heading_input = heading_input.view(-1, 24)

        # classification loss over the 12 heading bins
        heading_input_cls = heading_input[:, 0:12]
        cls_loss = functional.cross_entropy(heading_input_cls, heading_target_cls, reduction="none")

        # regression loss on the residual of the ground-truth bin
        heading_input_res = heading_input[:, 12:24]
        cls_onehot = (
            torch.zeros(heading_target_cls.shape[0], 12)
            .to(device=heading_input.device)
            .scatter_(dim=1, index=heading_target_cls.view(-1, 1), value=1)
        )
        heading_input_res = torch.sum(heading_input_res * cls_onehot, 1)
        reg_loss = functional.l1_loss(heading_input_res, heading_target_res, reduction="none")

        angle_loss = cls_loss + reg_loss
        return {"loss_angle": angle_loss.sum() / num_boxes}

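    # The heading uses a 12-bin multi-bin encoding: channels 0:12 classify the
    # bin, channels 12:24 regress a residual within each bin, and the one-hot
    # mask selects exactly one residual per sample. Selection sketch
    # (hypothetical values):
    #
    #   >>> import torch
    #   >>> res = torch.arange(12.0).unsqueeze(0)  # [1, 12] residual predictions
    #   >>> onehot = torch.zeros(1, 12).scatter_(1, torch.tensor([[3]]), 1)
    #   >>> torch.sum(res * onehot, 1)  # residual of bin 3
    #   tensor([3.])
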
    def loss_depth_map(self, outputs: dict, targets: list, indices: list, num_boxes: int) -> dict[str, Tensor]:
        """Depth map loss.

        Args:
            outputs (dict): dict of tensors, see the output specification of the model for the format.
            targets (list): list of dicts, such that len(targets) == batch_size.
                The expected keys in each dict depend on the losses applied; see each loss' doc.
            indices (list): list of tuples, such that len(indices) == batch_size.
            num_boxes (int): number of boxes in the batch.
        """
        depth_map_logits = outputs["pred_depth_map_logits"]

        num_gt_per_img = [len(t["boxes"]) for t in targets]
        # Scale normalized cxcywh boxes to the depth-map grid, then convert to xyxy.
        gt_boxes2d = torch.cat([t["boxes"] for t in targets], dim=0) * torch.tensor(
            [80, 24, 80, 24],
            device=depth_map_logits.device,
        )
        gt_boxes2d = box_convert(gt_boxes2d, "cxcywh", "xyxy")
        gt_center_depth = torch.cat([t["depth"] for t in targets], dim=0).squeeze(dim=1)
        return {"loss_depth_map": self.ddn_loss(depth_map_logits, gt_boxes2d, num_gt_per_img, gt_center_depth)}

    def _get_src_permutation_idx(
        self,
        indices: list[tuple[torch.Tensor, torch.Tensor]],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Get the indices necessary to compute the loss."""
        # permute predictions following indices
        batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
        src_idx = torch.cat([src for (src, _) in indices])
        return batch_idx, src_idx

    def _get_tgt_permutation_idx(
        self,
        indices: list[tuple[torch.Tensor, torch.Tensor]],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Get the indices necessary to compute the loss."""
        # permute targets following indices
        batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
        tgt_idx = torch.cat([tgt for (_, tgt) in indices])
        return batch_idx, tgt_idx

    @property
    def loss_map(self) -> dict[str, Callable]:
        """Return the loss map."""
        return {
            "loss_ce": self.loss_labels,
            "loss_bbox": self.loss_boxes,
            "loss_giou": self.loss_giou,
            "loss_depth": self.loss_depths,
            "loss_dim": self.loss_dims,
            "loss_angle": self.loss_angles,
            "loss_center": self.loss_3dcenter,
            "loss_depth_map": self.loss_depth_map,
        }

    def forward(
        self,
        outputs: dict[str, torch.Tensor],
        targets: list[dict[str, torch.Tensor]],
    ) -> dict[str, torch.Tensor]:
        """Perform the loss computation.

        Args:
            outputs (dict): dict of tensors, see the output specification of the model for the format.
            targets (list): list of dicts, such that len(targets) == batch_size.
                The expected keys in each dict depend on the losses applied; see each loss' doc.
        """
        outputs_without_aux = {k: v for k, v in outputs.items() if k != "aux_outputs"}
        group_num = self.group_num if self.training else 1

        # Retrieve the matching between the outputs of the last layer and the targets
        indices = self.matcher(outputs_without_aux, targets, group_num=group_num)

        # Compute the average number of target boxes across all nodes, for normalization purposes
        num_boxes_int = sum([len(t["labels"]) for t in targets]) * group_num
        num_boxes = torch.as_tensor([num_boxes_int], dtype=torch.float, device=next(iter(outputs.values())).device)
        num_boxes = torch.clamp(num_boxes, min=1)

        # Compute all the requested losses
        losses = {}
        for loss in self.loss_map.values():
            losses.update(loss(outputs, targets, indices, num_boxes))
        losses = {k: losses[k] * self.weight_dict[k] for k in losses}

        # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
        if "aux_outputs" in outputs:
            for i, aux_outputs in enumerate(outputs["aux_outputs"]):
                indices = self.matcher(aux_outputs, targets, group_num=group_num)
                for name, loss in self.loss_map.items():
                    if name == "loss_depth_map":
                        # The intermediate depth-map loss is too costly to compute, so we skip it.
                        continue
                    l_dict = loss(aux_outputs, targets, indices, num_boxes.item())
                    l_dict = {k + f"_aux_{i}": v * self.weight_dict[k] for k, v in l_dict.items()}
                    losses.update(l_dict)

        return losses
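
# A minimal usage sketch (hypothetical model and data; the real tensors come from
# the MonoDETR head and the 3D detection data pipeline):
#
#   criterion = MonoDETRCriterion(num_classes=3, weight_dict={}, focal_alpha=0.25)
#   outputs = model(images)  # dict with "scores", "boxes_3d", "depth", "size_3d", ...
#   losses = criterion(outputs, targets)
#   total = sum(losses.values())  # per-loss weights are already applied
#   total.backward()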