Source code for otx.algo.object_detection_3d.detectors.monodetr

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
"""MonoDetr core Pytorch detector."""
from __future__ import annotations

import math
from typing import Callable

import torch
from torch import Tensor, nn
from torch.nn import functional

from otx.algo.common.layers.transformer_layers import MLP
from otx.algo.common.utils.utils import get_clones, inverse_sigmoid
from otx.algo.object_detection_3d.utils.utils import NestedTensor


# TODO (Kirill): make MonoDETR as a more general class
class MonoDETR(nn.Module):
    """MonoDETR module that performs monocular 3D object detection."""

    def __init__(
        self,
        backbone: nn.Module,
        depthaware_transformer: nn.Module,
        depth_predictor: nn.Module,
        num_classes: int,
        num_queries: int,
        num_feature_levels: int,
        criterion: nn.Module | None = None,
        aux_loss: bool = True,
        with_box_refine: bool = False,
        init_box: bool = False,
        group_num: int = 11,
        activation: Callable[..., nn.Module] = nn.ReLU,
    ):
        """Initializes the model.

        Args:
            backbone (nn.Module): torch module of the backbone to be used. See backbone.py.
            depthaware_transformer (nn.Module): depth-aware transformer architecture. See depth_aware_transformer.py.
            depth_predictor (nn.Module): depth predictor module.
            criterion (nn.Module | None): loss criterion module. Defaults to None.
            num_classes (int): number of object classes.
            num_queries (int): number of object queries, i.e. detection slots. This is the maximal number of objects
                DETR can detect in a single image. For KITTI, we recommend 50 queries.
            num_feature_levels (int): number of feature levels.
            aux_loss (bool): True if auxiliary decoding losses (loss at each decoder layer) are to be used.
            with_box_refine (bool): iterative bounding box refinement.
            init_box (bool): True if the bounding box embedding layers should be initialized to zero.
            group_num (int): number of groups for depth-aware bounding box embedding.
            activation (Callable[..., nn.Module]): activation function to be applied to the output of the transformer.
        """
        super().__init__()

        self.num_queries = num_queries
        self.depthaware_transformer = depthaware_transformer
        self.depth_predictor = depth_predictor
        hidden_dim = depthaware_transformer.d_model
        self.hidden_dim = hidden_dim
        self.num_feature_levels = num_feature_levels
        self.criterion = criterion
        self.label_enc = nn.Embedding(num_classes + 1, hidden_dim - 1)  # for indicator

        # prediction heads
        self.class_embed = nn.Linear(hidden_dim, num_classes)
        prior_prob = 0.01
        bias_value = -math.log((1 - prior_prob) / prior_prob)
        self.class_embed.bias.data = torch.ones(num_classes) * bias_value

        self.bbox_embed = MLP(hidden_dim, hidden_dim, 6, 3, activation=activation)
        self.dim_embed_3d = MLP(hidden_dim, hidden_dim, 3, 2, activation=activation)
        self.angle_embed = MLP(hidden_dim, hidden_dim, 24, 2, activation=activation)
        self.depth_embed = MLP(hidden_dim, hidden_dim, 2, 2, activation=activation)  # depth and deviation

        if init_box:
            nn.init.constant_(self.bbox_embed.layers[-1].weight.data, 0)
            nn.init.constant_(self.bbox_embed.layers[-1].bias.data, 0)

        self.query_embed = nn.Embedding(num_queries * group_num, hidden_dim * 2)

        if num_feature_levels > 1:
            num_backbone_outs = len(backbone.strides)
            input_proj_list = []
            for _ in range(num_backbone_outs):
                in_channels = backbone.num_channels[_]
                input_proj_list.append(
                    nn.Sequential(
                        nn.Conv2d(in_channels, hidden_dim, kernel_size=1),
                        nn.GroupNorm(32, hidden_dim),
                    ),
                )
            for _ in range(num_feature_levels - num_backbone_outs):
                input_proj_list.append(
                    nn.Sequential(
                        nn.Conv2d(in_channels, hidden_dim, kernel_size=3, stride=2, padding=1),
                        nn.GroupNorm(32, hidden_dim),
                    ),
                )
                in_channels = hidden_dim
            self.input_proj = nn.ModuleList(input_proj_list)
        else:
            self.input_proj = nn.ModuleList(
                [
                    nn.Sequential(
                        nn.Conv2d(backbone.num_channels[0], hidden_dim, kernel_size=1),
                        nn.GroupNorm(32, hidden_dim),
                    ),
                ],
            )

        self.backbone = backbone
        self.aux_loss = aux_loss
        self.with_box_refine = with_box_refine
        self.num_classes = num_classes

        for proj in self.input_proj:
            nn.init.xavier_uniform_(proj[0].weight, gain=1)
            nn.init.constant_(proj[0].bias, 0)
        # if two-stage, the last class_embed and bbox_embed are for region proposal generation
        num_pred = depthaware_transformer.decoder.num_layers
        if with_box_refine:
            self.class_embed = get_clones(self.class_embed, num_pred)
            self.bbox_embed = get_clones(self.bbox_embed, num_pred)
            nn.init.constant_(self.bbox_embed[0].layers[-1].bias.data[2:], -2.0)
            # implementation for iterative bounding box refinement
            self.depthaware_transformer.decoder.bbox_embed = self.bbox_embed
            self.dim_embed_3d = get_clones(self.dim_embed_3d, num_pred)
            self.depthaware_transformer.decoder.dim_embed = self.dim_embed_3d
            self.angle_embed = get_clones(self.angle_embed, num_pred)
            self.depth_embed = get_clones(self.depth_embed, num_pred)
        else:
            nn.init.constant_(self.bbox_embed.layers[-1].bias.data[2:], -2.0)
            self.class_embed = nn.ModuleList([self.class_embed for _ in range(num_pred)])
            self.bbox_embed = nn.ModuleList([self.bbox_embed for _ in range(num_pred)])
            self.dim_embed_3d = nn.ModuleList([self.dim_embed_3d for _ in range(num_pred)])
            self.angle_embed = nn.ModuleList([self.angle_embed for _ in range(num_pred)])
            self.depth_embed = nn.ModuleList([self.depth_embed for _ in range(num_pred)])
            self.depthaware_transformer.decoder.bbox_embed = None
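
    # Note on the head lists built above: `num_pred` equals the number of decoder layers,
    # so `self.class_embed[lvl]`, `self.bbox_embed[lvl]`, etc. in `forward()` select one
    # prediction head per decoder layer. With `with_box_refine=True` each layer gets its
    # own cloned head; otherwise every ModuleList entry refers to the same shared module.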
    def forward(
        self,
        images: Tensor,
        calibs: Tensor,
        img_sizes: Tensor,
        targets: list[dict[str, Tensor]] | None = None,
        mode: str = "predict",
    ) -> dict[str, Tensor]:
        """Forward method of the MonoDETR model.

        Args:
            images (Tensor): images for each sample.
            calibs (Tensor): camera matrices for each sample.
            img_sizes (Tensor): image sizes for each sample.
            targets (list[dict[str, Tensor]] | None): ground truth boxes and labels for each sample.
                Defaults to None.
            mode (str): The mode of operation. Defaults to "predict".

        Returns:
            dict[str, Tensor]: A dictionary of tensors. If mode is "loss", the tensors are the loss values.
                If mode is "predict", the tensors are the logits. If mode is "export", the scores are
                additionally passed through a sigmoid.
        """
        features, pos = self.backbone(images)

        srcs = []
        masks = []
        for i, feat in enumerate(features):
            src, mask = feat.decompose()
            srcs.append(self.input_proj[i](src))
            masks.append(mask)

        if self.num_feature_levels > len(srcs):
            _len_srcs = len(srcs)
            for i in range(_len_srcs, self.num_feature_levels):
                src = self.input_proj[i](features[-1].tensors) if i == _len_srcs else self.input_proj[i](srcs[-1])
                m = torch.zeros(src.shape[0], src.shape[2], src.shape[3]).to(torch.bool).to(src.device)
                mask = functional.interpolate(m[None].float(), size=src.shape[-2:]).to(torch.bool)[0]
                pos_l = self.backbone[1](NestedTensor(src, mask)).to(src.dtype)
                srcs.append(src)
                masks.append(mask)
                pos.append(pos_l)

        # at inference time only the first group of queries is used
        query_embeds = self.query_embed.weight if self.training else self.query_embed.weight[: self.num_queries]

        pred_depth_map_logits, depth_pos_embed, weighted_depth, depth_pos_embed_ip = self.depth_predictor(
            srcs,
            masks[1],
            pos[1],
        )

        (
            hs,
            init_reference,
            inter_references,
            inter_references_dim,
            enc_outputs_class,
            enc_outputs_coord_unact,
        ) = self.depthaware_transformer(
            srcs,
            masks,
            pos,
            query_embeds,
            depth_pos_embed,
            depth_pos_embed_ip,
        )

        outputs_coords = []
        outputs_classes = []
        outputs_3d_dims = []
        outputs_depths = []
        outputs_angles = []

        for lvl in range(hs.shape[0]):
            reference = init_reference if lvl == 0 else inter_references[lvl - 1]
            reference = inverse_sigmoid(reference)
            tmp = self.bbox_embed[lvl](hs[lvl])
            if reference.shape[-1] == 6:
                tmp += reference
            else:
                tmp[..., :2] += reference  # 3d center + 2d box
            outputs_coord = tmp.sigmoid()
            outputs_coords.append(outputs_coord)

            # classes
            outputs_class = self.class_embed[lvl](hs[lvl])
            outputs_classes.append(outputs_class)

            # 3D sizes
            size3d = inter_references_dim[lvl]
            outputs_3d_dims.append(size3d)

            # depth_geo
            box2d_height_norm = outputs_coord[:, :, 4] + outputs_coord[:, :, 5]
            box2d_height = torch.clamp(box2d_height_norm * img_sizes[:, :1], min=1.0)
            depth_geo = size3d[:, :, 0] / box2d_height * calibs[:, 0, 0].unsqueeze(1)

            # depth_reg
            depth_reg = self.depth_embed[lvl](hs[lvl])

            # depth_map: sample the weighted depth map at the predicted 3D centers
            outputs_center3d = ((outputs_coord[..., :2] - 0.5) * 2).unsqueeze(2).detach()
            depth_map = functional.grid_sample(
                weighted_depth.unsqueeze(1),
                outputs_center3d,
                mode="bilinear",
                align_corners=True,
            ).squeeze(1)

            # depth average + sigma: average of the regressed depth (decoded from its
            # sigmoid-space parametrization), the geometric depth and the map-sampled depth;
            # the second channel carries the predicted deviation
            depth_ave = torch.cat(
                [
                    ((1.0 / (depth_reg[:, :, 0:1].sigmoid() + 1e-6) - 1.0) + depth_geo.unsqueeze(-1) + depth_map) / 3,
                    depth_reg[:, :, 1:2],
                ],
                -1,
            )
            outputs_depths.append(depth_ave)

            # angles
            outputs_angle = self.angle_embed[lvl](hs[lvl])
            outputs_angles.append(outputs_angle)

        outputs_coord = torch.stack(outputs_coords)
        outputs_class = torch.stack(outputs_classes)
        outputs_3d_dim = torch.stack(outputs_3d_dims)
        outputs_depth = torch.stack(outputs_depths)
        outputs_angle = torch.stack(outputs_angles)

        out = {"scores": outputs_class[-1], "boxes_3d": outputs_coord[-1]}
        out["size_3d"] = outputs_3d_dim[-1]
out["depth"] = outputs_depth[-1] out["heading_angle"] = outputs_angle[-1] if mode == "export": out["scores"] = out["scores"].sigmoid() return out out["pred_depth_map_logits"] = pred_depth_map_logits if self.aux_loss: out["aux_outputs"] = self._set_aux_loss( outputs_class, outputs_coord, outputs_3d_dim, outputs_angle, outputs_depth, ) if mode == "loss": if self.criterion is None: msg = "Criterion is not set for the model" raise ValueError(msg) return self.criterion(outputs=out, targets=targets) return out
    @torch.jit.unused
    def _set_aux_loss(
        self,
        outputs_class: Tensor,
        outputs_coord: Tensor,
        outputs_3d_dim: Tensor,
        outputs_angle: Tensor,
        outputs_depth: Tensor,
    ) -> list[dict[str, Tensor]]:
        # this is a workaround to make torchscript happy, as torchscript
        # doesn't support dictionary with non-homogeneous values, such
        # as a dict having both a Tensor and a list.
        return [
            {"scores": a, "boxes_3d": b, "size_3d": c, "heading_angle": d, "depth": e}
            for a, b, c, d, e in zip(
                outputs_class[:-1],
                outputs_coord[:-1],
                outputs_3d_dim[:-1],
                outputs_angle[:-1],
                outputs_depth[:-1],
            )
        ]
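

# Minimal usage sketch (illustration only, not executed). The shapes and the (height, width)
# ordering of `img_sizes` are assumptions for a KITTI-like setup; `model` stands for an
# already constructed MonoDETR instance with its backbone, depth-aware transformer and
# depth predictor.
#
#     images = torch.randn(2, 3, 384, 1280)       # batched RGB images
#     calibs = torch.randn(2, 3, 4)                # camera projection matrices; calibs[:, 0, 0] is the focal length
#     img_sizes = torch.tensor([[384.0, 1280.0]] * 2)  # per-sample image sizes, shape (B, 2)
#     model.eval()
#     with torch.no_grad():
#         out = model(images, calibs, img_sizes, mode="predict")
#     # out["scores"]:   raw class logits from the last decoder layer
#     # out["boxes_3d"]: normalized 3D centers + 2D box offsets (6 values per query)
#     # out["size_3d"], out["depth"], out["heading_angle"]: 3D size, depth (+ deviation), angle bins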