# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
"""OTX tile merge module."""
from __future__ import annotations
from abc import abstractmethod
from collections import defaultdict
from typing import Generic
import cv2
import numpy as np
import torch
from torchvision import tv_tensors
from torchvision.ops import batched_nms
from otx.algo.explain.explain_algo import InstSegExplainAlgo
from otx.core.config.data import TileConfig
from otx.core.data.entity.base import ImageInfo, T_OTXBatchPredEntity, T_OTXDataEntity
from otx.core.data.entity.detection import DetBatchPredEntity, DetPredEntity
from otx.core.data.entity.instance_segmentation import InstanceSegBatchPredEntity, InstanceSegPredEntity
from otx.core.data.entity.segmentation import SegBatchPredEntity, SegPredEntity
class TileMerge(Generic[T_OTXDataEntity, T_OTXBatchPredEntity]):
    """Shared machinery for merging per-tile predictions back into full-image predictions.

    Args:
        img_infos (list[ImageInfo]): Information about each original image before tiling.
        num_classes (int): Number of classes.
        tile_config (TileConfig): Tiling configuration; ``tile_size``, ``iou_threshold``,
            ``max_num_instances`` and ``with_full_img`` are read from it.
        explain_mode (bool, optional): Whether tile predictions carry explain outputs
            (saliency maps / feature vectors). Default: False.

    NOTE: the class does not use ``ABCMeta``, so ``@abstractmethod`` below is documentation
    only; subclasses are still expected to override both abstract methods.
    """

    def __init__(
        self,
        img_infos: list[ImageInfo],
        num_classes: int,
        tile_config: TileConfig,
        explain_mode: bool = False,
    ) -> None:
        self.img_infos = img_infos
        self.num_classes = num_classes
        self.explain_mode = explain_mode
        # Tiling parameters consumed by the subclasses and by ``nms_postprocess``.
        self.tile_size = tile_config.tile_size
        self.iou_threshold = tile_config.iou_threshold
        self.max_num_instances = tile_config.max_num_instances
        self.with_full_img = tile_config.with_full_img

    @abstractmethod
    def _merge_entities(
        self,
        img_info: ImageInfo,
        entities: list[T_OTXDataEntity],
        explain_mode: bool = False,
    ) -> T_OTXDataEntity:
        """Combine all tile predictions of one image into a single full-size prediction.

        Args:
            img_info (ImageInfo): Information about the original image before tiling.
            entities (list[T_OTXDataEntity]): Prediction entities of that image's tiles.
            explain_mode (bool): Whether the tiles carry explain outputs. Default: False.

        Returns:
            T_OTXDataEntity: The merged prediction entity.
        """
        raise NotImplementedError

    @abstractmethod
    def merge(
        self,
        batch_tile_preds: list[T_OTXBatchPredEntity],
        batch_tile_attrs: list[list[dict]],
    ) -> list[T_OTXDataEntity]:
        """Turn batches of tile predictions into one full-size prediction per original image.

        Args:
            batch_tile_preds (list): Tile predictions, one entry per batch.
            batch_tile_attrs (list): Tile attributes, one list per batch.
        """
        raise NotImplementedError

    def nms_postprocess(
        self,
        bboxes: torch.Tensor,
        scores: torch.Tensor,
        labels: torch.Tensor,
        masks: None | list[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, None | torch.Tensor]:
        """Apply class-aware NMS and cap the result at ``max_num_instances`` detections.

        Args:
            bboxes (torch.Tensor): Candidate boxes.
            scores (torch.Tensor): Score of each box.
            labels (torch.Tensor): Label of each box; boxes with different labels never
                suppress each other.
            masks (None | list[torch.Tensor]): Optional per-box masks, indexed like the boxes.

        Returns:
            tuple: ``(bboxes, labels, scores, masks)`` restricted to the surviving boxes (note
            the order differs from the argument order). Non-empty ``masks`` are stacked,
            coalesced and densified; otherwise ``masks`` is returned unchanged.
        """
        kept = batched_nms(bboxes, scores, labels, self.iou_threshold)
        kept = kept[: self.max_num_instances] if len(kept) > self.max_num_instances else kept

        kept_masks = masks
        if masks is not None and len(masks) > 0:
            # Coalesce the stacked sparse masks (keeps them from growing too large) before densifying.
            kept_masks = torch.stack([masks[i] for i in kept]).coalesce().to_dense()

        return bboxes[kept], labels[kept], scores[kept], kept_masks
class DetectionTileMerge(TileMerge):
    """Detection tile merge."""

    def merge(
        self,
        batch_tile_preds: list[DetBatchPredEntity],
        batch_tile_attrs: list[list[dict]],
    ) -> list[DetPredEntity]:
        """Merge batch tile predictions to a list of full-size prediction data entities.

        Every tile's boxes are shifted by the tile's ``roi`` offset into original-image
        coordinates, grouped by ``tile_id`` (in order of first appearance) and merged per image
        with ``_merge_entities``.

        Args:
            batch_tile_preds (list): detection tile predictions.
            batch_tile_attrs (list): detection tile attributes.

        Returns:
            list[DetPredEntity]: One merged prediction per group, paired in order with
            ``self.img_infos``.
        """
        # dicts keep insertion order, so the keys double as the ordered list of image ids
        # (replaces a list with O(n) membership checks).
        entities_to_merge = defaultdict(list)
        explain_mode = self.explain_mode
        for tile_preds, tile_attrs in zip(batch_tile_preds, batch_tile_attrs, strict=True):
            batch_size = len(tile_attrs)
            saliency_maps = tile_preds.saliency_map if explain_mode else [[] for _ in range(batch_size)]
            feature_vectors = tile_preds.feature_vector if explain_mode else [[] for _ in range(batch_size)]
            for tile_attr, tile_img_info, tile_bboxes, tile_labels, tile_scores, tile_s_map, tile_f_vect in zip(
                tile_attrs,
                tile_preds.imgs_info,
                tile_preds.bboxes,
                tile_preds.labels,
                tile_preds.scores,
                saliency_maps,
                feature_vectors,
                strict=True,
            ):
                offset_x, offset_y, _, _ = tile_attr["roi"]
                # Shift a copy so the caller's tile predictions are not modified in place.
                tile_bboxes = tile_bboxes.clone()
                tile_bboxes[:, 0::2] += offset_x
                tile_bboxes[:, 1::2] += offset_y

                # Remember the tile's ROI; _merge_entities reads it back as the tile coordinates.
                tile_img_info.padding = tile_attr["roi"]
                det_pred_entity = DetPredEntity(
                    image=torch.empty(tile_img_info.ori_shape),
                    img_info=tile_img_info,
                    bboxes=tile_bboxes,
                    labels=tile_labels,
                    score=tile_scores,
                )

                if explain_mode:
                    det_pred_entity.feature_vector = tile_f_vect
                    det_pred_entity.saliency_map = tile_s_map
                entities_to_merge[tile_attr["tile_id"]].append(det_pred_entity)

        return [
            self._merge_entities(image_info, entities, explain_mode)
            for entities, image_info in zip(entities_to_merge.values(), self.img_infos, strict=True)
        ]

    def _merge_entities(
        self,
        img_info: ImageInfo,
        entities: list[DetPredEntity],
        explain_mode: bool = False,
    ) -> DetPredEntity:
        """Merge tile predictions to one single prediction.

        Boxes from all tiles are pooled and filtered with ``nms_postprocess``. In explain mode,
        the feature vectors are averaged and the saliency maps merged spatially.

        Args:
            img_info (ImageInfo): Image information about the original image before tiling.
            entities (list[DetPredEntity]): List of tile prediction entities.
            explain_mode (bool): Whether or not tiles have explain features. Default: False.

        Returns:
            DetPredEntity: Merged prediction entity.
        """
        bboxes: list | torch.Tensor = []
        labels: list | torch.Tensor = []
        scores: list | torch.Tensor = []
        feature_vectors = []
        saliency_maps = []
        tiles_coords = []
        img_size = img_info.ori_shape
        for tile_entity in entities:
            if len(tile_entity.bboxes) > 0:
                bboxes.extend(tile_entity.bboxes)
                labels.extend(tile_entity.labels)
                scores.extend(tile_entity.score)
            if explain_mode:
                tiles_coords.append(tile_entity.img_info.padding)
                feature_vectors.append(tile_entity.feature_vector)
                saliency_maps.append(tile_entity.saliency_map)

        device = img_info.device
        bboxes = torch.stack(bboxes) if bboxes else torch.empty((0, 4), device=device)
        # Labels are class indices: the empty fallback must be an integer tensor, not the float default.
        labels = torch.stack(labels) if labels else torch.empty((0,), dtype=torch.long, device=device)
        scores = torch.stack(scores) if scores else torch.empty((0,), device=device)

        bboxes, labels, scores, _ = self.nms_postprocess(bboxes, scores, labels)

        det_pred_entity = DetPredEntity(
            image=torch.empty(img_size),
            img_info=img_info,
            score=scores,
            bboxes=tv_tensors.BoundingBoxes(bboxes, canvas_size=img_size, format="XYXY"),
            labels=labels,
        )

        if explain_mode:
            det_pred_entity.feature_vector = np.mean(feature_vectors, axis=0)
            det_pred_entity.saliency_map = self._merge_saliency_maps(saliency_maps, img_size, tiles_coords)

        return det_pred_entity

    def _merge_saliency_maps(
        self,
        saliency_maps: list[np.ndarray],
        shape: tuple[int, int],
        tiles_coords: list[tuple[int, int, int, int]],
    ) -> np.ndarray:
        """Merge per-tile saliency maps into one full-image map (PyTorch implementation).

        The OpenVINO implementation is on the ModelAPI side.

        Args:
            saliency_maps: Saliency maps in tile order, each of shape (Nc, H, W). When
                ``self.with_full_img`` is set, the first entry is the full-image map.
            shape: (height, width) of the original image.
            tiles_coords: (x, y, width, height) of each tile in the original image, aligned
                with ``saliency_maps``.

        Returns:
            Merged saliency map of shape (Nc, H', W') as uint8, where H'/W' is the image size
            scaled by the saliency-map-to-tile ratio. A single map, or a 1D first map, is
            returned unchanged.

        Raises:
            IndexError: If a tile's saliency map does not cover the tile's region in the merged map.
        """
        if len(saliency_maps) == 1:
            return saliency_maps[0]

        image_saliency_map = saliency_maps[0]
        if len(image_saliency_map.shape) == 1:
            return image_saliency_map

        num_classes = image_saliency_map.shape[0]
        feat_h, feat_w = image_saliency_map.shape[1:]
        image_h, image_w = shape
        # Saliency-map pixels per image pixel, measured on one tile (or on the whole image when
        # it is smaller than a tile).
        ratio_h = feat_h / min(image_h, self.tile_size[0])
        ratio_w = feat_w / min(image_w, self.tile_size[1])
        image_map_h = int(image_h * ratio_h)
        image_map_w = int(image_w * ratio_w)
        merged_map = np.zeros((num_classes, image_map_h, image_map_w))

        # Skip the first saliency map when it is the full-image one; it is blended in below.
        start_idx = 1 if self.with_full_img else 0
        for tile_idx in range(start_idx, len(saliency_maps)):
            saliency_map = saliency_maps[tile_idx]
            tile_x, tile_y, tile_w, tile_h = tiles_coords[tile_idx]
            # Tile footprint expressed in merged-map pixels.
            y_1, x_1 = int(tile_y * ratio_h), int(tile_x * ratio_w)
            y_2, x_2 = int((tile_y + tile_h) * ratio_h), int((tile_x + tile_w) * ratio_w)
            region_h, region_w = y_2 - y_1, x_2 - x_1
            if region_h <= 0 or region_w <= 0:
                continue
            for class_idx in range(num_classes):
                cls_map = saliency_map[class_idx]
                cls_h, cls_w = cls_map.shape
                # Shrink the tile map to its footprint when it is larger in both dimensions;
                # otherwise the top-left part matching the footprint is used as is.
                if cls_h > region_h and cls_w > region_w:
                    cls_map = cv2.resize(cls_map, (region_w, region_h))
                tile_values = np.asarray(cls_map, dtype=merged_map.dtype)[:region_h, :region_w]
                target = merged_map[class_idx, y_1:y_2, x_1:x_2]
                if tile_values.shape != (region_h, region_w) or target.shape != (region_h, region_w):
                    msg = (
                        f"Saliency map of tile {tile_idx} does not cover its {region_h}x{region_w} "
                        "region in the merged saliency map."
                    )
                    raise IndexError(msg)
                # Average with values written by earlier overlapping tiles; copy where still empty.
                merged_map[class_idx, y_1:y_2, x_1:x_2] = np.where(
                    target != 0,
                    0.5 * (tile_values + target),
                    tile_values,
                )

        for class_idx in range(num_classes):
            if self.with_full_img:
                image_map_cls = cv2.resize(image_saliency_map[class_idx], (image_map_w, image_map_h))
                merged_map[class_idx] += 0.5 * image_map_cls
            merged_map[class_idx] = _non_linear_normalization(merged_map[class_idx])

        return merged_map.astype(np.uint8)
def _non_linear_normalization(saliency_map: np.ndarray) -> np.ndarray:
    """Rescale a (typically 2D) saliency map to [0, 255] through a y = x**1.5 curve.

    The map is shifted so that its minimum becomes zero (the power needs non-negative input),
    raised to the power 1.5, multiplied by ``255 / (max + 1e-12)`` and floored.

    Args:
        saliency_map (np.ndarray): Saliency map to normalize.

    Returns:
        np.ndarray: Floored float array with the same shape as the input.
    """
    curved = (saliency_map - np.min(saliency_map)) ** 1.5
    scale = 255.0 / (np.max(curved) + 1e-12)
    return np.floor(scale * curved)
class InstanceSegTileMerge(TileMerge):
    """Instance segmentation tile merge."""

    def merge(
        self,
        batch_tile_preds: list[InstanceSegBatchPredEntity],
        batch_tile_attrs: list[list[dict]],
    ) -> list[InstanceSegPredEntity]:
        """Merge inst-seg tile predictions to one single prediction per original image.

        Instances whose mask has no positive pixel sum are dropped. The remaining boxes are
        shifted by the tile's ``roi`` offset, and the results are grouped by ``tile_id`` (in order
        of first appearance) and merged with ``_merge_entities``.

        Args:
            batch_tile_preds (list): instance-seg tile predictions.
            batch_tile_attrs (list): instance-seg tile attributes.

        Returns:
            list[InstanceSegPredEntity]: One merged prediction per group, paired in order with
            ``self.img_infos``.
        """
        # dicts keep insertion order, so the keys double as the ordered list of image ids.
        entities_to_merge = defaultdict(list)
        explain_mode = self.explain_mode
        for tile_preds, tile_attrs in zip(batch_tile_preds, batch_tile_attrs, strict=True):
            feature_vectors = tile_preds.feature_vector if explain_mode else [[] for _ in range(len(tile_attrs))]
            for tile_attr, tile_img_info, tile_bboxes, tile_labels, tile_scores, tile_masks, tile_f_vect in zip(
                tile_attrs,
                tile_preds.imgs_info,
                tile_preds.bboxes,
                tile_preds.labels,
                tile_preds.scores,
                tile_preds.masks,
                feature_vectors,
                strict=True,
            ):
                # Keep only instances whose mask has a positive pixel sum.
                keep_indices = tile_masks.to_sparse().sum((1, 2)).to_dense() > 0
                keep_indices = keep_indices.nonzero(as_tuple=True)[0]
                # Index-tensor selection returns copies, so the shifts below leave the caller's tensors intact.
                _bboxes = tile_bboxes[keep_indices]
                _labels = tile_labels[keep_indices]
                _scores = tile_scores[keep_indices]
                _masks = tile_masks[keep_indices]

                offset_x, offset_y, _, _ = tile_attr["roi"]
                _bboxes[:, 0::2] += offset_x
                _bboxes[:, 1::2] += offset_y

                # Remember the tile's ROI; _merge_entities reads it back to offset the masks.
                tile_img_info.padding = tile_attr["roi"]
                inst_seg_pred_entity = InstanceSegPredEntity(
                    image=torch.empty(tile_img_info.ori_shape),
                    img_info=tile_img_info,
                    bboxes=_bboxes,
                    labels=_labels,
                    score=_scores,
                    masks=_masks.to_sparse(),
                    polygons=[],
                )

                if explain_mode:
                    inst_seg_pred_entity.feature_vector = tile_f_vect
                    inst_seg_pred_entity.saliency_map = []
                entities_to_merge[tile_attr["tile_id"]].append(inst_seg_pred_entity)

        return [
            self._merge_entities(image_info, entities, explain_mode)
            for entities, image_info in zip(entities_to_merge.values(), self.img_infos, strict=True)
        ]

    def _merge_entities(
        self,
        img_info: ImageInfo,
        entities: list[InstanceSegPredEntity],
        explain_mode: bool = False,
    ) -> InstanceSegPredEntity:
        """Merge tile predictions to one single prediction.

        Tile masks, stored as sparse tensors, are re-anchored into the full image by offsetting
        their indices with the tile ROI. All instances are then filtered together with
        ``nms_postprocess``.

        Args:
            img_info (ImageInfo): Image information about the original image before tiling.
            entities (list[InstanceSegPredEntity]): List of tile prediction entities.
            explain_mode (bool): Whether or not tiles have explain features. Default: False.

        Returns:
            InstanceSegPredEntity: Merged prediction entity.
        """
        bboxes: list | torch.Tensor = []
        labels: list | torch.Tensor = []
        scores: list | torch.Tensor = []
        masks: list | torch.Tensor = []
        feature_vectors = []
        img_size = img_info.ori_shape
        for tile_entity in entities:
            num_preds = len(tile_entity.bboxes)
            if num_preds > 0:
                bboxes.extend(tile_entity.bboxes)
                labels.extend(tile_entity.labels)
                scores.extend(tile_entity.score)

                offset_x, offset_y, _, _ = tile_entity.img_info.padding
                # indices() may alias the sparse tensor's own index storage; clone before shifting
                # so the tile entity's mask is not corrupted.
                mask_indices = tile_entity.masks.indices().clone()
                mask_values = tile_entity.masks.values()
                mask_indices[1] += offset_y
                mask_indices[2] += offset_x
                masks.extend(
                    torch.sparse_coo_tensor(mask_indices, mask_values, (num_preds, *img_size)),
                )
            if explain_mode:
                feature_vectors.append(tile_entity.feature_vector)

        device = img_info.device
        bboxes = torch.stack(bboxes) if bboxes else torch.empty((0, 4), device=device)
        # Labels are class indices: the empty fallback must be an integer tensor, not the float default.
        labels = torch.stack(labels) if labels else torch.empty((0,), dtype=torch.long, device=device)
        scores = torch.stack(scores) if scores else torch.empty((0,), device=device)
        masks = masks if masks else torch.empty((0, *img_size), device=device)

        bboxes, labels, scores, masks = self.nms_postprocess(bboxes, scores, labels, masks)

        inst_seg_pred_entity = InstanceSegPredEntity(
            image=torch.empty(img_size),
            img_info=img_info,
            score=scores,
            bboxes=tv_tensors.BoundingBoxes(bboxes, canvas_size=img_size, format="XYXY"),
            labels=labels,
            masks=tv_tensors.Mask(masks, dtype=bool),
            polygons=[],
        )

        if explain_mode:
            inst_seg_pred_entity.feature_vector = np.mean(feature_vectors, axis=0)
            inst_seg_pred_entity.saliency_map = self.get_saliency_maps_from_masks(
                labels,
                scores,
                masks,
                self.num_classes,
            )

        return inst_seg_pred_entity

    def get_saliency_maps_from_masks(
        self,
        labels: torch.Tensor,
        scores: torch.Tensor,
        masks: None | torch.Tensor,
        num_classes: int,
    ) -> np.ndarray:
        """Average and normalize predicted masks in per-class.

        Args:
            labels (torch.Tensor): Predicted labels.
            scores (torch.Tensor): Predicted scores.
            masks (None | torch.Tensor): Predicted masks.
            num_classes (int): Number of classes.

        Returns:
            np.array: Class-wise Saliency Maps. One saliency map per each class - [class_id, H, W].
            An empty array when ``masks`` is None.
        """
        if masks is None:
            # np.ndarray([]) would allocate an uninitialized 0-d array, not an empty one.
            return np.array([])
        pred = {"labels": labels, "scores": scores, "masks": masks}
        return InstSegExplainAlgo.average_and_normalize(pred, num_classes)
class SegmentationTileMerge(TileMerge):
    """Semantic segmentation tile merge."""

    def __init__(
        self,
        img_infos: list[ImageInfo],
        num_classes: int,
        tile_config: TileConfig,
        explain_mode: bool = False,
    ) -> None:
        """Create the merger; explain mode is rejected for segmentation.

        Raises:
            ValueError: If ``explain_mode`` is True.
        """
        super().__init__(img_infos, num_classes, tile_config, explain_mode)
        if explain_mode:
            msg = "Explain mode is not supported for segmentation"
            raise ValueError(msg)

    def merge(
        self,
        batch_tile_preds: list[SegBatchPredEntity],
        batch_tile_attrs: list[list[dict]],
    ) -> list[SegPredEntity]:
        """Merge batch tile predictions to a list of full-size prediction data entities.

        Args:
            batch_tile_preds (list[SegBatchPredEntity]): segmentation tile predictions.
            batch_tile_attrs (list[list[dict]]): segmentation tile attributes.

        Returns:
            list[SegPredEntity]: List of full-size prediction data entities after merging.

        Raises:
            ValueError: If the prediction batches, tile attributes or image infos differ in length.
        """
        # dicts keep insertion order, so the keys double as the ordered list of image ids.
        entities_to_merge = defaultdict(list)
        explain_mode = self.explain_mode
        for tile_preds, tile_attrs in zip(batch_tile_preds, batch_tile_attrs, strict=True):
            batch_size = tile_preds.batch_size
            saliency_maps = tile_preds.saliency_map if explain_mode else [[] for _ in range(batch_size)]
            feature_vectors = tile_preds.feature_vector if explain_mode else [[] for _ in range(batch_size)]
            for tile_attr, tile_img_info, tile_masks, tile_s_map, tile_f_vect in zip(
                tile_attrs,
                tile_preds.imgs_info,
                tile_preds.masks,
                saliency_maps,
                feature_vectors,
                strict=True,
            ):
                # Remember the tile's ROI; _merge_entities reads it back to place the tile logits.
                tile_img_info.padding = tile_attr["roi"]
                seg_pred_entity = SegPredEntity(
                    image=torch.empty(tile_img_info.ori_shape),
                    img_info=tile_img_info,
                    masks=tile_masks,
                    score=[],
                )

                if explain_mode:
                    seg_pred_entity.feature_vector = tile_f_vect
                    seg_pred_entity.saliency_map = tile_s_map
                entities_to_merge[tile_attr["tile_id"]].append(seg_pred_entity)

        return [
            self._merge_entities(image_info, entities, explain_mode)
            for entities, image_info in zip(entities_to_merge.values(), self.img_infos, strict=True)
        ]

    def _merge_entities(
        self,
        img_info: ImageInfo,
        entities: list[SegPredEntity],
        explain_mode: bool = False,
    ) -> SegPredEntity:
        """Merge tile predictions to one single prediction.

        Per-class tile logits are accumulated into a full-image canvas, averaged over the number
        of tiles covering each pixel, and reduced with argmax.

        Args:
            img_info (ImageInfo): Image information about the original image before tiling.
            entities (list[SegPredEntity]): List of tile prediction entities.
            explain_mode (bool): Whether or not tiles have explain features. Default: False.

        Returns:
            SegPredEntity: Merged prediction entity.
        """
        img_size = img_info.ori_shape
        num_classes = len(entities[0].masks)

        # Number of tiles covering each pixel, used to average overlapping logits.
        vote_mask = torch.zeros(img_size, dtype=torch.int, device=img_info.device)
        full_logits_mask = torch.zeros((num_classes, *img_size), device=img_info.device)
        for tile_entity in entities:
            offset_x, offset_y, tile_w, tile_h = tile_entity.img_info.padding
            vote_mask[offset_y : offset_y + tile_h, offset_x : offset_x + tile_w] += 1
            full_logits_mask[:, offset_y : offset_y + tile_h, offset_x : offset_x + tile_w] += tile_entity.masks[
                :,
                :tile_h,
                :tile_w,
            ]

        # Clamp so pixels no tile covered become 0 instead of 0 / 0 = NaN.
        full_logits_mask = full_logits_mask / vote_mask.clamp(min=1).unsqueeze(0)

        return SegPredEntity(
            image=torch.empty(img_size),
            img_info=img_info,
            masks=full_logits_mask.argmax(0).unsqueeze(0),
            score=[],
        )