# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
"""Class definition for 3d object detection model entity used in OTX."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, NamedTuple
import numpy as np
import torch
from model_api.models import ImageModel
from torchvision.ops import box_convert
from otx.algo.object_detection_3d.utils.utils import box_cxcylrtb_to_xyxy
from otx.algo.utils.mmengine_utils import load_checkpoint
from otx.core.data.entity.base import ImageInfo, OTXBatchLossEntity
from otx.core.data.entity.object_detection_3d import Det3DBatchDataEntity, Det3DBatchPredEntity
from otx.core.metrics import MetricInput
from otx.core.metrics.average_precision_3d import KittiMetric
from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable, OTXModel, OVModel
from otx.core.types.export import TaskLevelExportParameters
if TYPE_CHECKING:
from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable
from model_api.adapters.inference_adapter import InferenceAdapter
from torch import nn
from otx.core.metrics import MetricCallable
from otx.core.schedulers import LRSchedulerListCallable
from otx.core.types.label import LabelInfoTypes
[docs]
class OTX3DDetectionModel(OTXModel[Det3DBatchDataEntity, Det3DBatchPredEntity]):
    """Base class for the 3d detection models used in OTX."""

    # Only annotated here (never assigned in this class); concrete subclasses are
    # expected to provide these values.
    mean: tuple[float, float, float]
    std: tuple[float, float, float]
    load_from: str | None

    def __init__(
        self,
        label_info: LabelInfoTypes,
        model_name: str,
        input_size: tuple[int, int],
        optimizer: OptimizerCallable = DefaultOptimizerCallable,
        scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
        metric: MetricCallable = KittiMetric,
        torch_compile: bool = False,
        score_threshold: float = 0.1,
    ) -> None:
        """Initialize the 3d detection model.

        Args:
            label_info: Label information forwarded to ``OTXModel``.
            model_name: Name of the model variant, stored on ``self.model_name``.
            input_size: Model input size forwarded to ``OTXModel``.
            optimizer: Optimizer factory forwarded to ``OTXModel``.
            scheduler: Learning-rate scheduler factory (or list of factories).
            metric: Metric factory; defaults to ``KittiMetric``.
            torch_compile: Forwarded to ``OTXModel``.
            score_threshold: Threshold passed to the KITTI-format decoding when
                predictions are converted for metric computation.
        """
        # Set before super().__init__() — presumably because the base initializer
        # may build the model and need these attributes. TODO confirm.
        self.model_name = model_name
        self.score_threshold = score_threshold
        super().__init__(
            label_info=label_info,
            input_size=input_size,
            optimizer=optimizer,
            scheduler=scheduler,
            metric=metric,
            torch_compile=torch_compile,
        )

    def _create_model(self) -> nn.Module:
        """Build the detector, initialize its weights and optionally load a checkpoint.

        Also records the classification-layer layout on ``self.classification_layers``.

        Returns:
            nn.Module: The detector returned by ``self._build_model``.
        """
        detector = self._build_model(num_classes=self.label_info.num_classes)
        if hasattr(detector, "init_weights"):
            detector.init_weights()
        self.classification_layers = self.get_classification_layers(prefix="model.")
        # ``load_from`` is only annotated on this class; treat a missing attribute as
        # "no checkpoint" instead of failing with AttributeError.
        load_from = getattr(self, "load_from", None)
        if load_from is not None:
            load_checkpoint(detector, load_from, map_location="cpu")
        return detector

    @property
    def _export_parameters(self) -> TaskLevelExportParameters:
        """Defines parameters required to export a particular model implementation."""
        return super()._export_parameters.wrap(
            model_type="mono_3d_det",
            task_type="3d_detection",
        )

    def _customize_inputs(
        self,
        entity: Det3DBatchDataEntity,
    ) -> dict[str, Any]:
        """Convert a data batch into the keyword arguments of the underlying model.

        Args:
            entity: Input batch.

        Returns:
            dict[str, Any]: ``images``; ``calibs`` (calibration matrices concatenated
            along a new leading dimension); ``targets`` (one dict per sample);
            ``img_sizes`` (original image shapes on the images' device); and ``mode``
            (``"loss"`` while training, ``"predict"`` otherwise).
        """
        img_sizes = torch.from_numpy(np.array([img_info.ori_shape for img_info in entity.imgs_info])).to(
            device=entity.images.device,
        )
        # One target dict per sample, holding that sample's entry of each field.
        key_list = ["labels", "boxes", "depth", "size_3d", "heading_angle", "boxes_3d"]
        targets_list = [
            {key: getattr(entity, key)[bz] for key in key_list} for bz in range(len(entity.imgs_info))
        ]
        return {
            "images": entity.images,
            "calibs": torch.cat([p2.unsqueeze(0) for p2 in entity.calib_matrix], dim=0),
            "targets": targets_list,
            "img_sizes": img_sizes,
            "mode": "loss" if self.training else "predict",
        }

    def _customize_outputs(
        self,
        outputs: dict[str, torch.Tensor],
        inputs: Det3DBatchDataEntity,
    ) -> Det3DBatchPredEntity | OTXBatchLossEntity:
        """Convert raw model outputs into a loss entity (training) or a prediction entity.

        Args:
            outputs: Raw model outputs. While training this must be a dict of losses.
            inputs: The input batch the outputs were produced from.

        Returns:
            OTXBatchLossEntity while training, otherwise Det3DBatchPredEntity.

        Raises:
            TypeError: If training outputs are not a dict, or a loss value is neither
                a list nor a tensor.
        """
        if self.training:
            if not isinstance(outputs, dict):
                raise TypeError(outputs)
            losses = OTXBatchLossEntity()
            for k, v in outputs.items():
                if isinstance(v, list):
                    # List-valued losses are summed into a single entry.
                    losses[k] = sum(v)
                elif isinstance(v, torch.Tensor):
                    losses[k] = v
                else:
                    # f-prefix was missing, so the message never showed the actual type.
                    msg = f"Loss output should be list or torch.tensor but got {type(v)}"
                    raise TypeError(msg)
            return losses

        labels, scores, size_3d, heading_angle, boxes_3d, depth = self.extract_dets_from_outputs(outputs)
        # bbox 2d decoding
        boxes_2d = box_cxcylrtb_to_xyxy(boxes_3d)
        xywh_2d = box_convert(boxes_2d, "xyxy", "cxcywh")
        # size 2d decoding: width/height columns of the cxcywh boxes
        size_2d = xywh_2d[:, :, 2:4]
        return Det3DBatchPredEntity(
            batch_size=inputs.batch_size,
            images=inputs.images,
            imgs_info=inputs.imgs_info,
            calib_matrix=inputs.calib_matrix,
            boxes=boxes_2d,
            labels=labels,
            boxes_3d=boxes_3d,
            size_2d=size_2d,
            size_3d=size_3d,
            depth=depth,
            heading_angle=heading_angle,
            scores=scores,
            original_kitti_format=[None],
        )

    def _convert_pred_entity_to_compute_metric(
        self,
        preds: Det3DBatchPredEntity,
        inputs: Det3DBatchDataEntity,
    ) -> MetricInput:
        """Convert predictions to metric input using this model's labels and score threshold."""
        return _convert_pred_entity_to_compute_metric(preds, inputs, self.label_info.label_names, self.score_threshold)

    def get_classification_layers(self, prefix: str = "model.") -> dict[str, dict[str, int]]:
        """Get final classification layer information for incremental learning case.

        Builds the model with 5 and with 6 classes and compares their state dicts.
        For each parameter whose shape differs, and assuming its first dimension has
        the form ``stride * num_classes + extra``:
        ``stride = d6 - d5`` and ``extra = 6 * d5 - 5 * d6``.

        Args:
            prefix: Prefix prepended to every returned state-dict key.

        Returns:
            dict mapping prefixed parameter names to ``{"stride", "num_extra_classes"}``.
        """
        sample_model_dict = self._build_model(num_classes=5).state_dict()
        incremental_model_dict = self._build_model(num_classes=6).state_dict()

        classification_layers = {}
        for key in sample_model_dict:
            if sample_model_dict[key].shape != incremental_model_dict[key].shape:
                sample_model_dim = sample_model_dict[key].shape[0]
                incremental_model_dim = incremental_model_dict[key].shape[0]
                stride = incremental_model_dim - sample_model_dim
                num_extra_classes = 6 * sample_model_dim - 5 * incremental_model_dim
                classification_layers[prefix + key] = {"stride": stride, "num_extra_classes": num_extra_classes}
        return classification_layers
[docs]
class MonoDETRModel(ImageModel):
    """A wrapper for MonoDETR 3d object detection model."""

    __model__ = "mono_3d_det"

    def __init__(
        self,
        inference_adapter: InferenceAdapter,
        configuration: dict[str, Any],
        preload: bool = False,
    ) -> None:
        """Initializes a 3d detection model.

        Args:
            inference_adapter (InferenceAdapter): inference adapter containing the underlying model.
            configuration (dict, optional): configuration overrides the model parameters (see parameters() method).
            preload (bool, optional): forces inference adapter to load the model. Defaults to False.
        """
        super().__init__(inference_adapter, configuration, preload)
        # NOTE(review): presumably validates 3 model inputs and 5 model outputs — see model_api.
        self._check_io_number(3, 5)

    def preprocess(self, inputs: dict[str, np.ndarray]) -> tuple[dict[str, Any], ...]:
        """Preprocesses the input data for the model.

        Args:
            inputs (dict[str, np.ndarray]): a dict with ``image``, ``calib`` and ``img_size`` entries.

        Returns:
            tuple[dict[str, Any], ...]: a pair ``(model_inputs, meta)``. ``model_inputs``
            holds the image and image size with a leading batch axis added, plus the
            calibration matrix as given. ``meta`` holds the original image shape and
            the model's ``(h, w, c)`` input shape.
        """
        return {
            self.image_blob_name: inputs["image"][None],
            "calib_matrix": inputs["calib"],
            "img_sizes": inputs["img_size"][None],
        }, {
            "original_shape": inputs["image"].shape,
            "resized_shape": (self.h, self.w, self.c),
        }

    def _get_inputs(self) -> tuple[list[Any], list[Any]]:
        """Defines the model inputs for images and additional info.

        Inputs with a 4D shape are treated as image inputs; inputs with a 2D shape
        are treated as additional-info inputs. Other ranks are ignored.

        Raises:
            WrapperError: (via ``self.raise_error``) if no 4D image input is found.

        Returns:
            - list of inputs names for images
            - list of inputs names for additional info
        """
        image_blob_names, image_info_blob_names = [], []
        for name, metadata in self.inputs.items():
            if len(metadata.shape) == 4:
                image_blob_names.append(name)
            elif len(metadata.shape) == 2:
                image_info_blob_names.append(name)

        if not image_blob_names:
            self.raise_error(
                "Failed to identify the input for the image: no 4D input layer found",
            )

        return image_blob_names, image_info_blob_names

    def postprocess(
        self,
        outputs: dict[str, np.ndarray],
        meta: dict[str, Any],
    ) -> dict[str, Any]:
        """Copy the raw model outputs into independent numpy arrays.

        No decoding is performed here; decoding happens downstream.

        Args:
            outputs (dict[str, np.ndarray]): raw outputs of the model
            meta (dict[str, Any]): meta information about the input data (unused)

        Returns:
            dict[str, Any]: a copy of every output array, keyed like ``outputs``
        """
        # Copy so the result does not alias buffers owned by the inference backend.
        return {k: np.copy(outputs[k]) for k in outputs}
[docs]
class OV3DDetectionModel(OVModel[Det3DBatchDataEntity, Det3DBatchPredEntity]):
    """3d detection model compatible for OpenVINO IR inference.

    It can consume OpenVINO IR model path or model name from Intel OMZ repository
    and create the OTX 3d detection model compatible for OTX testing pipeline.
    """

    def __init__(
        self,
        model_name: str,
        model_type: str = "mono_3d_det",
        async_inference: bool = True,
        max_num_requests: int | None = None,
        use_throughput_mode: bool = True,
        model_api_configuration: dict[str, Any] | None = None,
        metric: MetricCallable = KittiMetric,
        score_threshold: float = 0.2,
        **kwargs,
    ) -> None:
        """Initialize the OpenVINO 3d detection model.

        Args:
            model_name: IR model path or model name, forwarded to ``OVModel``.
            model_type: model_api model type; defaults to ``"mono_3d_det"``.
            async_inference: Whether ``_forward`` uses ``infer_batch``.
            max_num_requests: Forwarded to ``OVModel``.
            use_throughput_mode: Forwarded to ``OVModel``.
            model_api_configuration: Forwarded to ``OVModel``.
            metric: Metric factory; defaults to ``KittiMetric``.
            score_threshold: Threshold used when converting predictions for metrics.
            **kwargs: Accepted but NOT forwarded to ``OVModel``.
        """
        # NOTE(review): ``kwargs`` is silently dropped — confirm this is intentional.
        super().__init__(
            model_name=model_name,
            model_type=model_type,
            async_inference=async_inference,
            max_num_requests=max_num_requests,
            use_throughput_mode=use_throughput_mode,
            model_api_configuration=model_api_configuration,
            metric=metric,
        )
        self.score_threshold = score_threshold

    def _customize_inputs(
        self,
        entity: Det3DBatchDataEntity,
    ) -> dict[str, Any]:
        """Convert a data batch into per-sample numpy inputs for the OpenVINO model.

        Images are moved to CPU and transposed from (C, H, W) to (H, W, C); each
        calibration matrix gets a leading axis of size 1.
        """
        img_sizes = np.array([img_info.ori_shape for img_info in entity.imgs_info])
        images = [np.transpose(im.cpu().numpy(), (1, 2, 0)) for im in entity.images]
        return {
            "images": images,
            "calibs": [p2.unsqueeze(0).cpu().numpy() for p2 in entity.calib_matrix],
            "targets": [],
            "img_sizes": img_sizes,
            "mode": "predict",
        }

    def _customize_outputs(
        self,
        outputs: list[NamedTuple],
        inputs: Det3DBatchDataEntity,
    ) -> Det3DBatchPredEntity | OTXBatchLossEntity:
        """Stack the per-sample model outputs along dim 0 and decode them into predictions.

        NOTE(review): each element is indexed by key below, so the elements look like
        dicts (as returned by ``MonoDETRModel.postprocess``) rather than NamedTuples —
        confirm and update the annotation.
        """
        # Gather each key's per-sample tensors and concatenate once, instead of
        # re-concatenating the growing tensor for every sample (quadratic copying).
        per_key: dict[str, list[torch.Tensor]] = {}
        for output in outputs:
            for k in output:
                per_key.setdefault(k, []).append(torch.tensor(output[k]))
        # A key seen only once is used as-is (no cat), exactly as before.
        stacked_outputs: dict[str, Any] = {
            k: tensors[0] if len(tensors) == 1 else torch.cat(tensors, 0) for k, tensors in per_key.items()
        }

        labels, scores, size_3d, heading_angle, boxes_3d, depth = self.extract_dets_from_outputs(stacked_outputs)
        # bbox 2d decoding
        boxes_2d = box_cxcylrtb_to_xyxy(boxes_3d)
        xywh_2d = box_convert(boxes_2d, "xyxy", "cxcywh")
        # size 2d decoding: width/height columns of the cxcywh boxes
        size_2d = xywh_2d[:, :, 2:4]

        return Det3DBatchPredEntity(
            batch_size=len(outputs),
            images=inputs.images,
            imgs_info=inputs.imgs_info,
            calib_matrix=inputs.calib_matrix,
            boxes=boxes_2d,
            labels=labels,
            boxes_3d=boxes_3d,
            size_2d=size_2d,
            size_3d=size_3d,
            depth=depth,
            heading_angle=heading_angle,
            scores=scores,
            original_kitti_format=[None],
        )

    def _forward(self, inputs: Det3DBatchDataEntity) -> Det3DBatchPredEntity:
        """Model forward function.

        Raises:
            TypeError: If output customization unexpectedly yields a loss entity.
        """
        all_inputs = self._customize_inputs(inputs)
        model_ready_inputs = [
            {"image": image, "calib": calib, "img_size": img_size}
            for image, calib, img_size in zip(all_inputs["images"], all_inputs["calibs"], all_inputs["img_sizes"])
        ]

        if self.async_inference:
            outputs = self.model.infer_batch(model_ready_inputs)
        else:
            outputs = [self.model(model_input) for model_input in model_ready_inputs]

        customized_outputs = self._customize_outputs(outputs, inputs)

        if isinstance(customized_outputs, OTXBatchLossEntity):
            raise TypeError(customized_outputs)

        return customized_outputs

    def _convert_pred_entity_to_compute_metric(
        self,
        preds: Det3DBatchPredEntity,
        inputs: Det3DBatchDataEntity,
    ) -> MetricInput:
        """Convert predictions to metric input using this model's labels and score threshold."""
        return _convert_pred_entity_to_compute_metric(preds, inputs, self.label_info.label_names, self.score_threshold)
def _convert_pred_entity_to_compute_metric(
    preds: Det3DBatchPredEntity,
    inputs: Det3DBatchDataEntity,
    label_names: list[str],
    score_threshold: float,
) -> MetricInput:
    """Build the metric payload from a prediction batch and its input batch.

    Per-detection values are packed column-wise (along dim 2) in this order:
    label, score, 2D center x, 2D center y, ``size_2d``, depth channel 0,
    ``heading_angle``, ``size_3d``, ``boxes_3d`` channels 0 and 1, and
    ``exp(-depth channel 1)``. The packed array is then decoded into KITTI format by
    ``OTX3DDetectionModel.decode_detections_for_kitti_format``.

    Args:
        preds (Det3DBatchPredEntity): Prediction entity.
        inputs (Det3DBatchDataEntity): Input data entity.
        label_names (list[str]): List of label names.
        score_threshold (float): Score threshold for filtering the predictions.

    Returns:
        MetricInput: ``{"preds": <decoded detections>, "target": inputs.original_kitti_format}``.
    """
    boxes_3d = preds.boxes_3d
    centers_2d = box_convert(preds.boxes, "xyxy", "cxcywh")
    n_images = len(boxes_3d)

    def as_column(tensor: torch.Tensor) -> torch.Tensor:
        # Reshape to (n_images, -1, 1) so every field becomes a single channel.
        return tensor.view(n_images, -1, 1)

    columns = [
        as_column(preds.labels),
        as_column(preds.scores),
        as_column(centers_2d[:, :, 0:1]),
        as_column(centers_2d[:, :, 1:2]),
        preds.size_2d,
        preds.depth[:, :, 0:1],
        preds.heading_angle,
        preds.size_3d,
        as_column(boxes_3d[:, :, 0:1]),
        as_column(boxes_3d[:, :, 1:2]),
        # NOTE(review): presumably a confidence derived from a depth uncertainty channel — confirm.
        torch.exp(-preds.depth[:, :, 1:2]),
    ]
    packed = torch.cat(columns, dim=2)
    detections = packed.detach().cpu().numpy()

    original_sizes = np.array([info.ori_shape for info in inputs.imgs_info])
    calibrations = [matrix.detach().cpu().numpy() for matrix in inputs.calib_matrix]
    decoded = OTX3DDetectionModel.decode_detections_for_kitti_format(
        detections,
        original_sizes,
        calibrations,
        class_names=label_names,
        threshold=score_threshold,
    )
    return {
        "preds": decoded,
        "target": inputs.original_kitti_format,  # type: ignore[dict-item]
    }
def _generate_dummy_input(input_size: tuple[int, ...], batch_size: int = 1) -> Det3DBatchDataEntity:
    """Returns a dummy input for 3d object detection model.

    Images are uniform-random tensors of shape ``(batch_size, 3, *input_size)``, each
    sample gets a random 3x4 calibration matrix, and every annotation field holds
    empty tensors.

    Args:
        input_size: Spatial size appended after the channel dimension.
        batch_size: Number of samples in the dummy batch.

    Returns:
        Det3DBatchDataEntity: The dummy batch.
    """
    # Draw images before calibration matrices to keep the RNG consumption order.
    images = torch.rand(batch_size, 3, *input_size)
    calib_matrix = [torch.rand(3, 4) for _ in range(batch_size)]
    img_infos = [
        ImageInfo(img_idx=idx, img_shape=image.shape[1:], ori_shape=image.shape[1:])
        for idx, image in enumerate(images)
    ]

    # Each field's list repeats a single empty tensor object batch_size times.
    empty_boxes = [torch.Tensor(0)] * batch_size
    empty_labels = [torch.LongTensor(0)] * batch_size
    empty_boxes_3d = [torch.LongTensor(0)] * batch_size
    empty_size_3d = [torch.LongTensor(0)] * batch_size
    empty_depth = [torch.LongTensor(0)] * batch_size
    empty_heading = [torch.LongTensor(0)] * batch_size

    return Det3DBatchDataEntity(
        batch_size,
        images,
        img_infos,
        boxes=empty_boxes,
        labels=empty_labels,
        calib_matrix=calib_matrix,
        boxes_3d=empty_boxes_3d,
        size_2d=[],
        size_3d=empty_size_3d,
        depth=empty_depth,
        heading_angle=empty_heading,
        original_kitti_format=[],
    )