Source code for otx.algo.detection.losses.yolox_loss
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) OpenMMLab. All rights reserved.
#
"""YOLOX criterion."""
from __future__ import annotations
from torch import Tensor, nn
from otx.algo.common.losses import CrossEntropyLoss, IoULoss, L1Loss
[docs]
class YOLOXCriterion(nn.Module):
    """YOLOX criterion module.

    Combines an objectness loss, a classification loss, a bounding box loss
    and, optionally, an L1 loss into a single dictionary of loss terms.
    Every term is divided by ``num_total_samples``.

    Args:
        num_classes (int): The number of classes. Used to reshape the class
            predictions to ``(-1, num_classes)``.
        loss_cls (nn.Module | None): The classification loss module.
            Defaults to None, in which case
            ``CrossEntropyLoss(use_sigmoid=True, reduction="sum", loss_weight=1.0)`` is used.
        loss_bbox (nn.Module | None): The bounding box regression loss module.
            Defaults to None, in which case
            ``IoULoss(mode="square", eps=1e-16, reduction="sum", loss_weight=5.0)`` is used.
        loss_obj (nn.Module | None): The objectness loss module.
            Defaults to None, in which case
            ``CrossEntropyLoss(use_sigmoid=True, reduction="sum", loss_weight=1.0)`` is used.
        loss_l1 (nn.Module | None): The L1 loss module.
            Defaults to None, in which case
            ``L1Loss(reduction="sum", loss_weight=1.0)`` is used.
        use_l1 (bool): Whether ``forward`` also computes and returns
            ``loss_l1``. Defaults to False.
    """

    def __init__(
        self,
        num_classes: int,
        loss_cls: nn.Module | None = None,
        loss_bbox: nn.Module | None = None,
        loss_obj: nn.Module | None = None,
        loss_l1: nn.Module | None = None,
        use_l1: bool = False,
    ) -> None:
        super().__init__()
        self.num_classes = num_classes

        # Compare against None explicitly: a user-supplied module that happens
        # to be falsy (e.g. an empty container defining ``__len__``) must not
        # be silently swapped for the default.
        if loss_cls is None:
            loss_cls = CrossEntropyLoss(use_sigmoid=True, reduction="sum", loss_weight=1.0)
        if loss_bbox is None:
            loss_bbox = IoULoss(mode="square", eps=1e-16, reduction="sum", loss_weight=5.0)
        if loss_obj is None:
            loss_obj = CrossEntropyLoss(use_sigmoid=True, reduction="sum", loss_weight=1.0)
        if loss_l1 is None:
            loss_l1 = L1Loss(reduction="sum", loss_weight=1.0)

        # Registration order is kept as cls, bbox, obj, l1.
        self.loss_cls = loss_cls
        self.loss_bbox = loss_bbox
        self.loss_obj = loss_obj
        self.loss_l1 = loss_l1
        self.use_l1 = use_l1

    def forward(
        self,
        flatten_objectness: Tensor,
        flatten_cls_preds: Tensor,
        flatten_bbox_preds: Tensor,
        flatten_bboxes: Tensor,
        obj_targets: Tensor,
        cls_targets: Tensor,
        bbox_targets: Tensor,
        l1_targets: Tensor,
        num_total_samples: Tensor,
        num_pos: Tensor,
        pos_masks: Tensor,
    ) -> dict[str, Tensor]:
        """Forward pass of the YOLOX criterion module.

        Args:
            flatten_objectness (Tensor): Flattened objectness predictions;
                reshaped to ``(-1, 1)`` before the objectness loss.
            flatten_cls_preds (Tensor): Flattened class predictions;
                reshaped to ``(-1, num_classes)`` and indexed by ``pos_masks``.
            flatten_bbox_preds (Tensor): Flattened bounding box predictions;
                reshaped to ``(-1, 4)`` and indexed by ``pos_masks`` for the
                L1 loss.
            flatten_bboxes (Tensor): Flattened bounding boxes; reshaped to
                ``(-1, 4)``, indexed by ``pos_masks`` and compared against
                ``bbox_targets`` by the bounding box loss.
                NOTE(review): presumably the decoded predicted boxes rather
                than ground truth -- confirm against the caller.
            obj_targets (Tensor): Objectness targets.
            cls_targets (Tensor): Class targets for the positive samples.
            bbox_targets (Tensor): Bounding box targets for the positive samples.
            l1_targets (Tensor): L1 targets for the positive samples.
            num_total_samples (Tensor): Normalizer every loss term is divided by.
            num_pos (Tensor): Number of positive samples.
            pos_masks (Tensor): Mask selecting the positive samples.

        Returns:
            dict[str, Tensor]: ``loss_cls``, ``loss_bbox`` and ``loss_obj``,
            plus ``loss_l1`` when ``use_l1`` is enabled.
        """
        loss_obj = self.loss_obj(flatten_objectness.view(-1, 1), obj_targets) / num_total_samples

        # Evaluate the positive-sample check once and reuse it for every
        # branch instead of converting the comparison to bool repeatedly.
        has_pos = bool(num_pos > 0)

        if has_pos:
            pos_cls_preds = flatten_cls_preds.view(-1, self.num_classes)[pos_masks]
            pos_bboxes = flatten_bboxes.view(-1, 4)[pos_masks]
            loss_cls = self.loss_cls(pos_cls_preds, cls_targets) / num_total_samples
            loss_bbox = self.loss_bbox(pos_bboxes, bbox_targets) / num_total_samples
        else:
            # Avoid cls and reg branch not participating in the gradient
            # propagation when there is no ground-truth in the images.
            # For more details, please refer to
            # https://github.com/open-mmlab/mmdetection/issues/7298
            loss_cls = flatten_cls_preds.sum() * 0
            loss_bbox = flatten_bboxes.sum() * 0

        loss_dict = {"loss_cls": loss_cls, "loss_bbox": loss_bbox, "loss_obj": loss_obj}

        if self.use_l1:
            if has_pos:
                pos_bbox_preds = flatten_bbox_preds.view(-1, 4)[pos_masks]
                loss_l1 = self.loss_l1(pos_bbox_preds, l1_targets) / num_total_samples
            else:
                # Same zero-valued, graph-connected placeholder as above so the
                # regression branch still takes part in gradient propagation.
                loss_l1 = flatten_bbox_preds.sum() * 0
            loss_dict["loss_l1"] = loss_l1

        return loss_dict