Source code for otx.data.torch
"""Torch-specific data item implementations."""
# Copyright (C) 2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
from collections.abc import Iterator, Mapping
from dataclasses import dataclass, fields
from typing import TYPE_CHECKING, Any
import torch
from otx.core.data.entity.utils import register_pytree_node
from .validations import (
ValidateBatchMixin,
ValidateItemMixin,
)
if TYPE_CHECKING:
from torchvision.tv_tensors import BoundingBoxes, Mask
from otx.core.data.entity.base import ImageInfo
# NOTE: register_pytree_node and Mapping are required for torchvision.transforms.v2 to work with OTXDataEntity
# TODO(ashwinvaidya17): Remove this once custom transforms are removed
[docs]
@register_pytree_node
@dataclass
class TorchDataItem(ValidateItemMixin, Mapping):
    """Torch data item implementation.

    Besides being a dataclass, an item is a read-only ``Mapping`` from
    dataclass field name to field value: iterating yields the field names,
    ``item[name]`` returns the value of that field and ``len(item)`` is the
    number of fields.

    Attributes:
        image (torch.Tensor): The image tensor.
        label (torch.Tensor | None): The label tensor, optional.
        masks (Mask | None): The masks, optional.
        bboxes (BoundingBoxes | None): The bounding boxes, optional.
        keypoints (torch.Tensor | None): The keypoints, optional.
        img_info (ImageInfo | None): Additional image information, optional.
    """

    image: torch.Tensor
    label: torch.Tensor | None = None
    masks: Mask | None = None
    bboxes: BoundingBoxes | None = None
    keypoints: torch.Tensor | None = None
    img_info: ImageInfo | None = None  # TODO(ashwinvaidya17): revisit and try to remove this

    @staticmethod
    def collate_fn(items: list[TorchDataItem]) -> TorchDataBatch:
        """Collate TorchDataItems into a batch.

        Args:
            items: List of TorchDataItems to batch

        Returns:
            A TorchDataBatch whose ``images`` is a single stacked tensor when
            every item image has the same shape, and a plain list of the
            per-item image tensors otherwise. All other fields are collected
            into per-item lists, in the order of ``items``.
        """
        # Check if all images have the same size. TODO(kprokofi): remove this check once OV IR models are moved.
        if all(item.image.shape == items[0].image.shape for item in items):
            images = torch.stack([item.image for item in items])
        else:
            # we need this only in case of OV inference, where no resize
            images = [item.image for item in items]

        return TorchDataBatch(
            batch_size=len(items),
            images=images,
            labels=[item.label for item in items],
            bboxes=[item.bboxes for item in items],
            keypoints=[item.keypoints for item in items],
            masks=[item.masks for item in items],
            imgs_info=[item.img_info for item in items],
        )

    def __iter__(self) -> Iterator[str]:
        """Yield the name of every dataclass field, in declaration order."""
        for field_ in fields(self):
            yield field_.name

    def __getitem__(self, key: str) -> Any:  # noqa: ANN401
        """Return the value of the dataclass field named ``key``.

        Raises:
            KeyError: If ``key`` is not the name of a dataclass field.
        """
        # The Mapping mixins (``get``, ``__contains__``) only treat KeyError as
        # "missing key", so unknown keys must not leak AttributeError. Limiting
        # lookups to field names also keeps indexing consistent with __iter__.
        if not any(field_.name == key for field_ in fields(self)):
            raise KeyError(key)
        return getattr(self, key)

    def __len__(self) -> int:
        """Return the number of dataclass fields."""
        return len(fields(self))
[docs]
@dataclass
class TorchDataBatch(ValidateBatchMixin):
    """Batch of torch data items.

    Attributes:
        batch_size (int): Number of items in the batch.
        images (torch.Tensor | list[torch.Tensor]): Either one tensor holding
            the whole batch or a list with one image tensor per item.
        labels (list[torch.Tensor] | None): Per-item labels, optional.
        masks (list[Mask] | None): Per-item masks, optional.
        bboxes (list[BoundingBoxes] | None): Per-item bounding boxes, optional.
        keypoints (list[torch.Tensor] | None): Per-item keypoints, optional.
        imgs_info (list[ImageInfo | None] | None): Per-item image information,
            optional.
    """

    batch_size: int  # TODO(ashwinvaidya17): Remove this
    images: torch.Tensor | list[torch.Tensor]
    labels: list[torch.Tensor] | None = None
    masks: list[Mask] | None = None
    bboxes: list[BoundingBoxes] | None = None
    keypoints: list[torch.Tensor] | None = None
    imgs_info: list[ImageInfo | None] | None = None  # TODO(ashwinvaidya17): revisit
[docs]
@dataclass
class TorchPredItem(TorchDataItem):
    """Torch prediction data item implementation.

    Adds three optional tensor fields to ``TorchDataItem``.

    Attributes:
        scores (torch.Tensor | None): The scores tensor, optional.
        feature_vector (torch.Tensor | None): The feature vector tensor, optional.
        saliency_map (torch.Tensor | None): The saliency map tensor, optional.
    """

    scores: torch.Tensor | None = None
    feature_vector: torch.Tensor | None = None
    saliency_map: torch.Tensor | None = None
[docs]
@dataclass
class TorchPredBatch(TorchDataBatch):
    """Torch prediction data item batch implementation.

    Adds three optional list fields to ``TorchDataBatch``.

    Attributes:
        scores (list[torch.Tensor] | None): List of score tensors, optional.
        feature_vector (list[torch.Tensor] | None): List of feature vector
            tensors, optional.
        saliency_map (list[torch.Tensor] | None): List of saliency map
            tensors, optional.
    """

    scores: list[torch.Tensor] | None = None
    feature_vector: list[torch.Tensor] | None = None
    saliency_map: list[torch.Tensor] | None = None

    @property
    def has_xai_outputs(self) -> bool:
        """Tell whether the batch carries any saliency maps.

        Necessary for compatibility with tests.

        Returns:
            True when ``saliency_map`` is set and is not empty, False otherwise.
        """
        # TODO(ashwinvaidya17): the tests should directly refer to saliency map.
        saliency = self.saliency_map
        if saliency is None:
            return False
        return len(saliency) > 0