# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
"""OTX tile merge module."""
from __future__ import annotations
from abc import abstractmethod
from collections import defaultdict
from typing import Callable
import cv2
import numpy as np
import torch
from packaging import version
from torchvision import tv_tensors
from torchvision.ops import batched_nms
from otx.algo.explain.explain_algo import InstSegExplainAlgo
from otx.core.config.data import TileConfig
from otx.core.data.entity.base import ImageInfo, T_OTXBatchPredEntity, T_OTXDataEntity
from otx.core.data.entity.detection import DetBatchPredEntity, DetPredEntity
from otx.core.data.entity.instance_segmentation import InstanceSegBatchPredEntity, InstanceSegPredEntity
from otx.data import TorchPredBatch, TorchPredItem

# Maximum number of elements: 2**31 - 1
MAX_ELEMENTS: int = np.iinfo(np.int32).max

# NOTE: RuntimeError: nonzero is not supported for tensors with more than INT_MAX elements.
# See https://github.com/pytorch/pytorch/issues/51871
int_max_check_condition: Callable[[torch.Tensor], bool] = (
    lambda tile_masks: version.parse(torch.__version__) < version.parse("2.6")
    and torch.numel(tile_masks) > MAX_ELEMENTS
)
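# For example (illustrative numbers): a batch of 1200 instance masks at
# 1344x1344 resolution holds ~2.17e9 elements, exceeding MAX_ELEMENTS (~2.15e9),
# so the chunked fallback below is used on PyTorch < 2.6.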
def keep_chunkify(tensor: torch.Tensor, max_element: int = MAX_ELEMENTS) -> torch.Tensor:
    """Split the tensor into chunks and process each chunk separately.

    Args:
        tensor (torch.Tensor): Input tensor of shape (B, H, W).
        max_element (int): Maximum number of elements to process at once.
            Defaults to MAX_ELEMENTS.

    Returns:
        torch.Tensor: Boolean mask of shape (B,) indicating which entries have
            a nonzero sum.
    """
    _, h, w = tensor.shape
    max_batch_size = int(max_element) // (h * w)
    chunk_size = max(1, min(max_batch_size, tensor.shape[0]))

    keep_indices = []
    for i in range(0, tensor.shape[0], chunk_size):
        chunk = tensor[i : i + chunk_size]
        keep_indices.append(chunk.sum(dim=(1, 2)) > 0)  # Process one chunk at a time.
    return torch.cat(keep_indices, dim=0)
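
# Illustrative example: with max_element=8 and 2x2 masks, keep_chunkify scans the
# batch two masks at a time and flags the masks that contain any nonzero pixel:
#   >>> masks = torch.tensor([[[1, 0], [0, 0]], [[0, 0], [0, 0]], [[0, 1], [1, 0]]])
#   >>> keep_chunkify(masks, max_element=8)
#   tensor([ True, False,  True])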


class TileMerge:
    """Base class for tile merge.

    Args:
        img_infos (list[ImageInfo]): Original image information before tiling.
        num_classes (int): Number of classes.
        tile_config (TileConfig): Tile configuration.
        explain_mode (bool, optional): Whether or not tiles have explain features. Default: False.
    """

    def __init__(
        self,
        img_infos: list[ImageInfo],
        num_classes: int,
        tile_config: TileConfig,
        explain_mode: bool = False,
    ) -> None:
        self.img_infos = img_infos
        self.num_classes = num_classes
        self.tile_size = tile_config.tile_size
        self.iou_threshold = tile_config.iou_threshold
        self.max_num_instances = tile_config.max_num_instances
        self.with_full_img = tile_config.with_full_img
        self.explain_mode = explain_mode

    @abstractmethod
    def _merge_entities(
        self,
        img_info: ImageInfo,
        entities: list[T_OTXDataEntity],
        explain_mode: bool = False,
    ) -> T_OTXDataEntity:
        """Merge tile predictions to one single full-size prediction data entity.

        Args:
            img_info (ImageInfo): Image information about the original image before tiling.
            entities (list[T_OTXDataEntity]): List of tile prediction entities.
            explain_mode (bool): Whether or not tiles have explain features. Default: False.

        Returns:
            T_OTXDataEntity: Merged prediction entity.
        """
        raise NotImplementedError

    @abstractmethod
    def merge(
        self,
        batch_tile_preds: list[T_OTXBatchPredEntity],
        batch_tile_attrs: list[list[dict]],
    ) -> list[T_OTXDataEntity]:
        """Merge batch tile predictions to a list of full-size prediction data entities.

        Args:
            batch_tile_preds (list): list of tile predictions.
            batch_tile_attrs (list): list of tile attributes.
        """
        raise NotImplementedError

    def nms_postprocess(
        self,
        bboxes: torch.Tensor,
        scores: torch.Tensor,
        labels: torch.Tensor,
        masks: None | list[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, None | torch.Tensor]:
        """Non-maximum suppression and post-processing."""
        keep = batched_nms(bboxes, scores, labels, self.iou_threshold)
        if len(keep) > self.max_num_instances:
            keep = keep[: self.max_num_instances]

        bboxes = bboxes[keep]
        labels = labels[keep]
        scores = scores[keep]
        if masks is not None and len(masks) > 0:
            # Coalesce sparse tensors to prevent them from growing too large.
            masks = torch.stack([masks[idx] for idx in keep]).coalesce().to_dense()
        return bboxes, labels, scores, masks


class DetectionTileMerge(TileMerge):
    """Detection tile merge."""

    def merge(
        self,
        batch_tile_preds: list[DetBatchPredEntity],
        batch_tile_attrs: list[list[dict]],
    ) -> list[DetPredEntity]:
        """Merge batch tile predictions to a list of full-size prediction data entities.

        Args:
            batch_tile_preds (list): detection tile predictions.
            batch_tile_attrs (list): detection tile attributes.
        """
        entities_to_merge = defaultdict(list)
        img_ids = []
        explain_mode = self.explain_mode

        for tile_preds, tile_attrs in zip(batch_tile_preds, batch_tile_attrs, strict=True):
            batch_size = len(tile_attrs)
            saliency_maps = tile_preds.saliency_map if explain_mode else [[] for _ in range(batch_size)]
            feature_vectors = tile_preds.feature_vector if explain_mode else [[] for _ in range(batch_size)]
            for tile_attr, tile_img_info, tile_bboxes, tile_labels, tile_scores, tile_s_map, tile_f_vect in zip(
                tile_attrs,
                tile_preds.imgs_info,
                tile_preds.bboxes,
                tile_preds.labels,
                tile_preds.scores,
                saliency_maps,
                feature_vectors,
                strict=True,
            ):
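                # "roi" is the tile's (x, y, w, h) placement within the full image;
                # shifting the box corners by (x, y) converts tile-local coordinates
                # to full-image coordinates.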
                offset_x, offset_y, _, _ = tile_attr["roi"]
                tile_bboxes[:, 0::2] += offset_x
                tile_bboxes[:, 1::2] += offset_y

                tile_id = tile_attr["tile_id"]
                if tile_id not in img_ids:
                    img_ids.append(tile_id)
                tile_img_info.padding = tile_attr["roi"]

                det_pred_entity = DetPredEntity(
                    image=torch.empty(tile_img_info.ori_shape),
                    img_info=tile_img_info,
                    bboxes=tile_bboxes,
                    labels=tile_labels,
                    score=tile_scores,
                )
                if explain_mode:
                    det_pred_entity.feature_vector = tile_f_vect
                    det_pred_entity.saliency_map = tile_s_map
                entities_to_merge[tile_id].append(det_pred_entity)

        return [
            self._merge_entities(image_info, entities_to_merge[img_id], explain_mode)
            for img_id, image_info in zip(img_ids, self.img_infos, strict=True)
        ]

    def _merge_entities(
        self,
        img_info: ImageInfo,
        entities: list[DetPredEntity],
        explain_mode: bool = False,
    ) -> DetPredEntity:
        """Merge tile predictions to one single prediction.

        Args:
            img_info (ImageInfo): Image information about the original image before tiling.
            entities (list[DetPredEntity]): List of tile prediction entities.
            explain_mode (bool): Whether or not tiles have explain features. Default: False.

        Returns:
            DetPredEntity: Merged prediction entity.
        """
        bboxes: list | torch.Tensor = []
        labels: list | torch.Tensor = []
        scores: list | torch.Tensor = []
        feature_vectors = []
        saliency_maps = []
        tiles_coords = []
        img_size = img_info.ori_shape
        for tile_entity in entities:
            num_preds = len(tile_entity.bboxes)
            if num_preds > 0:
                bboxes.extend(tile_entity.bboxes)
                labels.extend(tile_entity.labels)
                scores.extend(tile_entity.score)

            if explain_mode:
                tiles_coords.append(tile_entity.img_info.padding)
                feature_vectors.append(tile_entity.feature_vector)
                saliency_maps.append(tile_entity.saliency_map)

        bboxes = torch.stack(bboxes) if len(bboxes) > 0 else torch.empty((0, 4), device=img_info.device)
        labels = torch.stack(labels) if len(labels) > 0 else torch.empty((0,), device=img_info.device)
        scores = torch.stack(scores) if len(scores) > 0 else torch.empty((0,), device=img_info.device)
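        # Overlapping tiles can each predict the same object near tile borders;
        # this cross-tile NMS pass resolves those duplicates.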
        bboxes, labels, scores, _ = self.nms_postprocess(bboxes, scores, labels)

        det_pred_entity = DetPredEntity(
            image=torch.empty(img_size),
            img_info=img_info,
            score=scores,
            bboxes=tv_tensors.BoundingBoxes(bboxes, canvas_size=img_size, format="XYXY"),
            labels=labels,
        )
        if explain_mode:
            det_pred_entity.feature_vector = np.mean(feature_vectors, axis=0)
            det_pred_entity.saliency_map = self._merge_saliency_maps(saliency_maps, img_size, tiles_coords)
        return det_pred_entity

    def _merge_saliency_maps(
        self,
        saliency_maps: list[np.ndarray],
        shape: tuple[int, int],
        tiles_coords: list[tuple[int, int, int, int]],
    ) -> np.ndarray:
        """Merge the saliency maps from each tile (PyTorch implementation).

        The OpenVINO implementation is on the ModelAPI side.

        Args:
            saliency_maps: List of saliency maps; the shape of each map is (Nc, H, W).
            shape: Shape of the original image.
            tiles_coords: Coordinates of the tiles.

        Returns:
            Merged saliency map of shape (Nc, H, W).
        """
        if len(saliency_maps) == 1:
            return saliency_maps[0]

        image_saliency_map = saliency_maps[0]
        if len(image_saliency_map.shape) == 1:
            return image_saliency_map

        num_classes = saliency_maps[0].shape[0]
        map_h, map_w = saliency_maps[0].shape[1:]
        image_h, image_w = shape
        ratio = map_h / min(image_h, self.tile_size[0]), map_w / min(image_w, self.tile_size[1])

        image_map_h = int(image_h * ratio[0])
        image_map_w = int(image_w * ratio[1])
        merged_map = np.zeros((num_classes, image_map_h, image_map_w))
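        # Pixels covered by more than one tile are averaged (0.5 * (new + existing))
        # below, so tile seams do not double-count activations.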
        # NOTE: Skip the first saliency map as it is the full image value.
        saliency_maps, start_idx = (saliency_maps[1:], 1) if self.with_full_img else (saliency_maps, 0)
        for i, saliency_map in enumerate(saliency_maps, start_idx):
            for class_idx in range(num_classes):
                cls_map = saliency_map[class_idx]

                x_1, y_1, map_w, map_h = tiles_coords[i]
                x_2, y_2 = x_1 + map_w, y_1 + map_h
                y_1, x_1 = int(y_1 * ratio[0]), int(x_1 * ratio[1])
                y_2, x_2 = int(y_2 * ratio[0]), int(x_2 * ratio[1])

                map_h, map_w = cls_map.shape
                if (map_h > y_2 - y_1 > 0) and (map_w > x_2 - x_1 > 0):
                    cls_map = cv2.resize(cls_map, (x_2 - x_1, y_2 - y_1))
                    map_h, map_w = y_2 - y_1, x_2 - x_1

                for hi in range(map_h):
                    for wi in range(map_w):
                        map_pixel = cls_map[hi, wi]
                        merged_pixel = merged_map[class_idx][y_1 + hi, x_1 + wi]
                        if merged_pixel != 0:
                            merged_map[class_idx][y_1 + hi, x_1 + wi] = 0.5 * (map_pixel + merged_pixel)
                        else:
                            merged_map[class_idx][y_1 + hi, x_1 + wi] = map_pixel

        for class_idx in range(num_classes):
            if self.with_full_img:
                image_map_cls = image_saliency_map[class_idx]
                image_map_cls = cv2.resize(image_map_cls, (image_map_w, image_map_h))
                merged_map[class_idx] += 0.5 * image_map_cls
            merged_map[class_idx] = _non_linear_normalization(merged_map[class_idx])

        return merged_map.astype(np.uint8)


def _non_linear_normalization(saliency_map: np.ndarray) -> np.ndarray:
    """Use non-linear normalization y=x**1.5 for 2D saliency maps."""
    min_soft_score = np.min(saliency_map)
    # Make the distribution positive to perform the non-linear normalization y=x**1.5.
    saliency_map = (saliency_map - min_soft_score) ** 1.5

    max_soft_score = np.max(saliency_map)
    saliency_map = 255.0 / (max_soft_score + 1e-12) * saliency_map

    return np.floor(saliency_map)
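
# Illustrative example: a map spanning [-2.0, 2.0] is shifted to [0.0, 4.0],
# raised to the power 1.5 (range [0.0, 8.0]), and rescaled so the maximum maps
# to (approximately) 255 before flooring to integer values.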


class InstanceSegTileMerge(TileMerge):
    """Instance segmentation tile merge."""

    def merge(
        self,
        batch_tile_preds: list[InstanceSegBatchPredEntity],
        batch_tile_attrs: list[list[dict]],
    ) -> list[InstanceSegPredEntity]:
        """Merge batch tile predictions to a list of full-size prediction data entities.

        Args:
            batch_tile_preds (list): instance-seg tile predictions.
            batch_tile_attrs (list): instance-seg tile attributes.
        """
        entities_to_merge = defaultdict(list)
        img_ids = []
        explain_mode = self.explain_mode

        for tile_preds, tile_attrs in zip(batch_tile_preds, batch_tile_attrs, strict=True):
            feature_vectors = tile_preds.feature_vector if explain_mode else [[] for _ in range(len(tile_attrs))]
            for tile_attr, tile_img_info, tile_bboxes, tile_labels, tile_scores, tile_masks, tile_f_vect in zip(
                tile_attrs,
                tile_preds.imgs_info,
                tile_preds.bboxes,
                tile_preds.labels,
                tile_preds.scores,
                tile_preds.masks,
                feature_vectors,
                strict=True,
            ):
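                # Keep only predictions whose masks contain at least one nonzero
                # pixel. to_sparse() calls torch.nonzero internally, which fails
                # above INT_MAX elements on PyTorch < 2.6, so keep_chunkify is
                # used as the fallback there.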
                if int_max_check_condition(tile_masks):
                    keep_indices = keep_chunkify(tile_masks)
                else:
                    keep_indices = tile_masks.to_sparse().sum((1, 2)).to_dense() > 0
                keep_indices = keep_indices.nonzero(as_tuple=True)[0]

                _bboxes = tile_bboxes[keep_indices]
                _labels = tile_labels[keep_indices]
                _scores = tile_scores[keep_indices]
                _masks = tile_masks[keep_indices]

                offset_x, offset_y, _, _ = tile_attr["roi"]
                _bboxes[:, 0::2] += offset_x
                _bboxes[:, 1::2] += offset_y

                tile_id = tile_attr["tile_id"]
                if tile_id not in img_ids:
                    img_ids.append(tile_id)
                tile_img_info.padding = tile_attr["roi"]

                inst_seg_pred_entity = InstanceSegPredEntity(
                    image=torch.empty(tile_img_info.ori_shape),
                    img_info=tile_img_info,
                    bboxes=_bboxes,
                    labels=_labels,
                    score=_scores,
                    masks=_masks.to_sparse(),
                    polygons=[],
                )
                if explain_mode:
                    inst_seg_pred_entity.feature_vector = tile_f_vect
                    inst_seg_pred_entity.saliency_map = []
                entities_to_merge[tile_id].append(inst_seg_pred_entity)

        return [
            self._merge_entities(image_info, entities_to_merge[img_id], explain_mode)
            for img_id, image_info in zip(img_ids, self.img_infos, strict=True)
        ]

    def _merge_entities(
        self,
        img_info: ImageInfo,
        entities: list[InstanceSegPredEntity],
        explain_mode: bool = False,
    ) -> InstanceSegPredEntity:
        """Merge tile predictions to one single prediction.

        Args:
            img_info (ImageInfo): Image information about the original image before tiling.
            entities (list[InstanceSegPredEntity]): List of tile prediction entities.
            explain_mode (bool): Whether or not tiles have explain features. Default: False.

        Returns:
            InstanceSegPredEntity: Merged prediction entity.
        """
        bboxes: list | torch.Tensor = []
        labels: list | torch.Tensor = []
        scores: list | torch.Tensor = []
        masks: list | torch.Tensor = []
        feature_vectors = []
        img_size = img_info.ori_shape
        for tile_entity in entities:
            num_preds = len(tile_entity.bboxes)
            if num_preds > 0:
                bboxes.extend(tile_entity.bboxes)
                labels.extend(tile_entity.labels)
                scores.extend(tile_entity.score)

                offset_x, offset_y, _, _ = tile_entity.img_info.padding
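                # Sparse COO indices are rows of (instance, y, x); offsetting the
                # y/x rows moves every mask pixel into full-image coordinates
                # without densifying the masks.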
                mask_indices = tile_entity.masks.indices()
                mask_values = tile_entity.masks.values()
                mask_indices[1] += offset_y
                mask_indices[2] += offset_x
                masks.extend(
                    torch.sparse_coo_tensor(mask_indices, mask_values, (num_preds, *img_size)),
                )

            if explain_mode:
                feature_vectors.append(tile_entity.feature_vector)

        bboxes = torch.stack(bboxes) if len(bboxes) > 0 else torch.empty((0, 4), device=img_info.device)
        labels = torch.stack(labels) if len(labels) > 0 else torch.empty((0,), device=img_info.device)
        scores = torch.stack(scores) if len(scores) > 0 else torch.empty((0,), device=img_info.device)
        masks = masks if len(masks) > 0 else torch.empty((0, *img_size))
        bboxes, labels, scores, masks = self.nms_postprocess(bboxes, scores, labels, masks)

        inst_seg_pred_entity = InstanceSegPredEntity(
            image=torch.empty(img_size),
            img_info=img_info,
            score=scores,
            bboxes=tv_tensors.BoundingBoxes(bboxes, canvas_size=img_size, format="XYXY"),
            labels=labels,
            masks=tv_tensors.Mask(masks, dtype=bool),
            polygons=[],
        )
        if explain_mode:
            inst_seg_pred_entity.feature_vector = np.mean(feature_vectors, axis=0)
            inst_seg_pred_entity.saliency_map = self.get_saliency_maps_from_masks(
                labels,
                scores,
                masks,
                self.num_classes,
            )
        return inst_seg_pred_entity

    def get_saliency_maps_from_masks(
        self,
        labels: torch.Tensor,
        scores: torch.Tensor,
        masks: None | torch.Tensor,
        num_classes: int,
    ) -> np.ndarray:
        """Average and normalize the predicted masks per class.

        Returns:
            np.ndarray: Class-wise saliency maps, one saliency map per class - [class_id, H, W].
        """
        if masks is None:
            return np.array([])

        pred = {"labels": labels, "scores": scores, "masks": masks}
        return InstSegExplainAlgo.average_and_normalize(pred, num_classes)


class SegmentationTileMerge(TileMerge):
    """Semantic segmentation tile merge."""

    def __init__(
        self,
        img_infos: list[ImageInfo],
        num_classes: int,
        tile_config: TileConfig,
        explain_mode: bool = False,
    ) -> None:
        super().__init__(img_infos, num_classes, tile_config, explain_mode)
        if explain_mode:
            msg = "Explain mode is not supported for segmentation"
            raise ValueError(msg)

    def merge(
        self,
        batch_tile_preds: list[TorchPredBatch],
        batch_tile_attrs: list[list[dict]],
    ) -> list[TorchPredItem]:
        """Merge batch tile predictions to a list of full-size prediction data entities.

        Args:
            batch_tile_preds (list[TorchPredBatch]): segmentation tile predictions.
            batch_tile_attrs (list[list[dict]]): segmentation tile attributes.

        Returns:
            list[TorchPredItem]: List of full-size prediction data entities after merging.
        """
        entities_to_merge = defaultdict(list)
        img_ids = []
        explain_mode = self.explain_mode

        for tile_preds, tile_attrs in zip(batch_tile_preds, batch_tile_attrs):
            batch_size = tile_preds.batch_size
            saliency_maps = tile_preds.saliency_map if explain_mode else [[] for _ in range(batch_size)]
            feature_vectors = tile_preds.feature_vector if explain_mode else [[] for _ in range(batch_size)]
            if saliency_maps is None or feature_vectors is None:
                msg = "The saliency maps or feature vectors are not provided."
                raise ValueError(msg)
            if tile_preds.imgs_info is None:
                msg = "Image information is not provided."
                raise ValueError(msg)
            if tile_preds.masks is None:
                msg = "The predicted masks are not provided."
                raise ValueError(msg)

            for tile_attr, tile_img_info, tile_masks, tile_s_map, tile_f_vect in zip(
                tile_attrs,
                tile_preds.imgs_info,
                tile_preds.masks,
                saliency_maps,
                feature_vectors,
            ):
                if tile_img_info is None:
                    msg = f"Image information is not provided : {tile_preds.imgs_info}."
                    raise ValueError(msg)

                tile_id = tile_attr["tile_id"]
                if tile_id not in img_ids:
                    img_ids.append(tile_id)
                tile_img_info.padding = tile_attr["roi"]

                seg_pred_entity = TorchPredItem(
                    image=torch.empty((3, *tile_img_info.ori_shape)),
                    img_info=tile_img_info,
                    masks=tv_tensors.Mask(tile_masks),
                    scores=torch.tensor([]),
                )
                if explain_mode:
                    seg_pred_entity.feature_vector = tile_f_vect
                    seg_pred_entity.saliency_map = tile_s_map
                entities_to_merge[tile_id].append(seg_pred_entity)

        return [
            self._merge_entities(image_info, entities_to_merge[img_id], explain_mode)
            for img_id, image_info in zip(img_ids, self.img_infos)
        ]

    def _merge_entities(
        self,
        img_info: ImageInfo,
        entities: list[TorchPredItem],
        explain_mode: bool = False,
    ) -> TorchPredItem:
        """Merge tile predictions to one single prediction.

        Args:
            img_info (ImageInfo): Image information about the original image before tiling.
            entities (list[TorchPredItem]): List of tile prediction entities.
            explain_mode (bool): Whether or not tiles have explain features. Default: False.

        Returns:
            TorchPredItem: Merged prediction entity.
        """
        img_size = img_info.ori_shape
        if any(entity is None for entity in entities):
            msg = f"Some entities are None: {entities}."
            raise ValueError(msg)
        if entities[0].masks is None:
            msg = "The predicted masks are not provided."
            raise ValueError(msg)

        num_classes = len(entities[0].masks)
        # Create a vote map for overlapping tiles.
        vote_mask = torch.zeros(img_size, dtype=torch.int, device=img_info.device)
        full_logits_mask = torch.zeros((num_classes, *img_size), device=img_info.device)
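        # Each tile casts one vote per pixel it covers; dividing the accumulated
        # logits by the vote count averages predictions where tiles overlap.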
        for tile_entity in entities:
            if tile_entity.img_info is None:
                msg = "Image information is not provided."
                raise ValueError(msg)
            if tile_entity.masks is None:
                msg = "The predicted masks are not provided."
                raise ValueError(msg)

            offset_x, offset_y, tile_w, tile_h = tile_entity.img_info.padding
            vote_mask[offset_y : offset_y + tile_h, offset_x : offset_x + tile_w] += 1
            full_logits_mask[:, offset_y : offset_y + tile_h, offset_x : offset_x + tile_w] += tile_entity.masks[
                :,
                :tile_h,
                :tile_w,
            ]

        full_logits_mask = full_logits_mask / vote_mask.unsqueeze(0)

        return TorchPredItem(
            image=torch.empty((3, *img_size)),
            img_info=img_info,
            masks=tv_tensors.Mask(full_logits_mask.argmax(0).unsqueeze(0)),
            scores=torch.tensor([]),
        )