Source code for otx.algo.object_detection_3d.utils.utils

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
"""utils for object detection 3D models."""
from __future__ import annotations

import torch
from torch import Tensor


# TODO(Kirill): try to remove this class
class NestedTensor:
    """Nested tensor class for object detection 3D models.

    Bundles a tensor with an optional companion mask so both can be passed
    around and moved between devices together.
    """

    def __init__(self, tensors: Tensor, mask: Tensor | None) -> None:
        """Initialize a NestedTensor object.

        Args:
            tensors (Tensor): The tensors representing the nested structure.
            mask (Tensor | None): The mask indicating the valid elements in the
                tensors, or ``None`` when no mask is available.
        """
        self.tensors = tensors
        self.mask = mask

    def to(self, device: torch.device | str) -> NestedTensor:
        """Move the NestedTensor object to the specified device.

        Args:
            device (torch.device | str): The device to move the tensors to.

        Returns:
            NestedTensor: A new NestedTensor wrapping the results of calling
                ``.to(device)`` on the tensors and, when present, the mask.
        """
        cast_tensor = self.tensors.to(device)
        # The mask is optional, so only transfer it when one was provided.
        cast_mask = None if self.mask is None else self.mask.to(device)
        return NestedTensor(cast_tensor, cast_mask)

    def decompose(self) -> tuple[Tensor, Tensor | None]:
        """Decompose the NestedTensor object into its constituent tensors and mask.

        Returns:
            tuple[Tensor, Tensor | None]: The stored ``(tensors, mask)`` pair; the
                mask is ``None`` if the object was created without one.
        """
        return self.tensors, self.mask

    def __repr__(self) -> str:
        """Return a string representation of the NestedTensor object.

        Only the wrapped tensors are rendered; the mask is not included.
        """
        return str(self.tensors)
def box_cxcylrtb_to_xyxy(x: Tensor) -> Tensor:
    """Transform bbox from cxcylrtb to xyxy representation.

    Args:
        x (Tensor): Boxes whose last dimension holds six values in the order
            ``(cx, cy, l, r, t, b)``: a center point followed by the offsets
            subtracted (``l``, ``t``) and added (``r``, ``b``) to reach the
            box corners.

    Returns:
        Tensor: Boxes with the same leading dimensions and a last dimension of
            four values ``(cx - l, cy - t, cx + r, cy + b)``.
    """
    center_x, center_y, left, right, top, bottom = x.unbind(-1)

    x_min = center_x - left
    y_min = center_y - top
    x_max = center_x + right
    y_max = center_y + bottom

    return torch.stack((x_min, y_min, x_max, y_max), dim=-1)