"""Collection of utils for data in Detection Task."""
# Copyright (C) 2021-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import json
import os.path as osp
from typing import Any, Dict, List, Optional, Sequence
import numpy as np
from mmdet.datasets.api_wrappers.coco_api import COCO
from otx.algorithms.common.utils.data import compute_robust_dataset_statistics
from otx.algorithms.detection.configs.base.configuration import DetectionConfig
from otx.api.entities.annotation import (
Annotation,
AnnotationSceneEntity,
AnnotationSceneKind,
)
from otx.api.entities.dataset_item import DatasetItemEntity
from otx.api.entities.datasets import DatasetEntity
from otx.api.entities.id import ID
from otx.api.entities.image import Image
from otx.api.entities.label import Domain, LabelEntity
from otx.api.entities.scored_label import ScoredLabel
from otx.api.entities.shapes.polygon import Point, Polygon
from otx.api.entities.shapes.rectangle import Rectangle
from otx.api.entities.subset import Subset
from otx.api.utils.shape_factory import ShapeFactory
from otx.utils.logger import get_logger
logger = get_logger()
# pylint: disable=too-many-instance-attributes, too-many-arguments
def get_classes_from_annotation(path):
"""Return classes from annotation."""
with open(path, encoding="UTF-8") as read_file:
content = json.load(read_file)
categories = [v["name"] for v in sorted(content["categories"], key=lambda x: x["id"])]
return categories
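# A minimal usage sketch (not part of the original module): demonstrates `get_classes_from_annotation`
# on a throwaway COCO-style JSON file; the category names below are made up.
def _example_get_classes_from_annotation():  # pragma: no cover
    """Sketch: class names are returned sorted by category id."""
    import tempfile

    content = {"categories": [{"id": 2, "name": "dog"}, {"id": 1, "name": "cat"}]}
    with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as tmp:
        json.dump(content, tmp)
        path = tmp.name
    return get_classes_from_annotation(path)  # -> ["cat", "dog"]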
class LoadAnnotations:
"""Load Annotations class."""
def __init__(self, with_bbox: bool = True, with_label: bool = True, with_mask: bool = False):
self.with_bbox = with_bbox
self.with_label = with_label
self.with_mask = with_mask
@staticmethod
def _load_bboxes(results):
ann_info = results["ann_info"]
results["gt_bboxes"] = ann_info["bboxes"].copy()
gt_bboxes_ignore = ann_info.get("bboxes_ignore", None)
if gt_bboxes_ignore is not None:
results["gt_bboxes_ignore"] = gt_bboxes_ignore.copy()
results["bbox_fields"].append("gt_bboxes_ignore")
results["bbox_fields"].append("gt_bboxes")
return results
@staticmethod
def _load_labels(results):
results["gt_labels"] = results["ann_info"]["labels"].copy()
return results
@staticmethod
def _load_masks(results):
gt_masks = results["ann_info"]["masks"]
results["gt_masks"] = gt_masks
results["mask_fields"].append("gt_masks")
return results
def __call__(self, results: Dict[str, Any]):
"""Callback function of LoadAnnotations."""
if self.with_bbox:
results = LoadAnnotations._load_bboxes(results)
if results is None:
return None
if self.with_label:
results = LoadAnnotations._load_labels(results)
if self.with_mask:
results = LoadAnnotations._load_masks(results)
return results
def __repr__(self):
"""String function of LoadAnnotations."""
repr_str = self.__class__.__name__
repr_str += f"(with_bbox={self.with_bbox}, "
repr_str += f"with_label={self.with_label})"
return repr_str
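# A minimal usage sketch (not part of the original module): shows the mmdet-style `results` dict
# that LoadAnnotations expects; all values below are made up.
def _example_load_annotations_usage():  # pragma: no cover
    """Sketch: load bboxes and labels from a pre-built `results` dict."""
    results = {
        "ann_info": {
            "bboxes": np.array([[10.0, 20.0, 110.0, 70.0]], dtype=np.float32),
            "labels": np.array([0], dtype=np.int64),
        },
        "bbox_fields": [],
        "mask_fields": [],
        "seg_fields": [],
    }
    loader = LoadAnnotations(with_bbox=True, with_label=True)
    results = loader(results)
    return results["gt_bboxes"], results["gt_labels"]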
class CocoDataset:
"""CocoDataset."""
def __init__(
self,
ann_file: str,
classes: Optional[Sequence[str]] = None,
data_root: Optional[str] = None,
img_prefix: str = "",
test_mode: bool = False,
filter_empty_gt: bool = True,
min_size: Optional[int] = None,
with_mask: bool = False,
):
self.ann_file = ann_file
self.data_root = data_root
self.img_prefix = img_prefix
self.test_mode = test_mode
self.filter_empty_gt = filter_empty_gt
self.classes = self.get_classes(classes)
self.min_size = min_size
self.with_mask = with_mask
if self.data_root is not None:
# if not osp.isabs(self.ann_file):
# self.ann_file = osp.join(self.data_root, self.ann_file)
if not (self.img_prefix is None or osp.isabs(self.img_prefix)):
self.img_prefix = osp.join(self.data_root, self.img_prefix)
self.data_infos = self.load_annotations(self.ann_file)
if not test_mode:
valid_inds = self._filter_imgs()
self.data_infos = [self.data_infos[i] for i in valid_inds]
def __len__(self):
"""Length of CocoDataset."""
return len(self.data_infos)
def pre_pipeline(self, results: Dict[str, Any]):
"""Initialize pipeline."""
results["img_prefix"] = self.img_prefix
results["bbox_fields"] = []
results["mask_fields"] = []
results["seg_fields"] = []
def _rand_another(self, idx):
"""Get Random indices."""
pool = np.where(self.flag == self.flag[idx])[0]
return np.random.choice(pool)
def __getitem__(self, idx: int):
"""Return dataset item from index."""
return self.prepare_img(idx)
def __iter__(self):
"""Iterator of CocoDataset."""
for i in range(len(self)):
yield self[i]
def prepare_img(self, idx: int):
"""Load Annotations function with images."""
img_info = self.data_infos[idx]
ann_info = self.get_ann_info(idx)
results = dict(img_info=img_info, ann_info=ann_info)
self.pre_pipeline(results)
return LoadAnnotations(with_mask=self.with_mask)(results)
def get_classes(self, classes: Optional[Sequence[str]] = None):
"""Return classes function."""
if classes is None:
return get_classes_from_annotation(self.ann_file)
if isinstance(classes, (tuple, list)):
return classes
raise ValueError(f"Unsupported type {type(classes)} of classes.")
def load_annotations(self, ann_file):
"""Load annotations function from coco."""
self.coco = COCO(ann_file)
self.cat_ids = self.coco.get_cat_ids(cat_names=self.classes)
self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}
self.img_ids = self.coco.get_img_ids()
data_infos = []
for i in self.img_ids:
info = self.coco.load_imgs([i])[0]
info["filename"] = info["file_name"]
data_infos.append(info)
return data_infos
def get_ann_info(self, idx: int):
"""Getting Annotation info."""
img_id = self.data_infos[idx]["id"]
ann_ids = self.coco.get_ann_ids(img_ids=[img_id])
ann_info = self.coco.load_anns(ann_ids)
return self._parse_ann_info(self.data_infos[idx], ann_info)
def get_cat_ids(self, idx: int):
"""Getting cat_ids."""
img_id = self.data_infos[idx]["id"]
ann_ids = self.coco.get_ann_ids(img_ids=[img_id])
ann_info = self.coco.load_anns(ann_ids)
return [ann["category_id"] for ann in ann_info]
def _filter_imgs(self, min_size=32):
"""Filter images too small or without ground truths."""
valid_inds = []
        # obtain images that contain annotations
        ids_with_ann = {ann["image_id"] for ann in self.coco.anns.values()}
        # obtain images that contain annotations of the required categories
        ids_in_cat = set()
        for class_id in self.cat_ids:
            ids_in_cat |= set(self.coco.cat_img_map[class_id])
# merge the image id sets of the two conditions and use the merged set
# to filter out images if self.filter_empty_gt=True
ids_in_cat &= ids_with_ann
valid_img_ids = []
for i, img_info in enumerate(self.data_infos):
img_id = self.img_ids[i]
if self.filter_empty_gt and img_id not in ids_in_cat:
continue
if min(img_info["width"], img_info["height"]) >= min_size:
valid_inds.append(i)
valid_img_ids.append(img_id)
self.img_ids = valid_img_ids
return valid_inds
def _parse_ann_info(self, img_info, ann_info): # pylint: disable=too-many-locals, too-many-branches
"""Parse annotation info."""
gt_bboxes = []
gt_labels = []
gt_bboxes_ignore = []
gt_masks_ann = []
for ann in ann_info:
if ann.get("ignore", False):
continue
x1, y1, width, height = ann["bbox"]
inter_w = max(0, min(x1 + width, img_info["width"]) - max(x1, 0))
inter_h = max(0, min(y1 + height, img_info["height"]) - max(y1, 0))
if inter_w * inter_h == 0:
continue
if ann["area"] <= 0 or width < 1 or height < 1:
continue
if self.min_size is not None:
if width < self.min_size or height < self.min_size:
continue
if ann["category_id"] not in self.cat_ids:
continue
bbox = [x1, y1, x1 + width, y1 + height]
if ann.get("iscrowd", False):
gt_bboxes_ignore.append(bbox)
else:
gt_bboxes.append(bbox)
gt_labels.append(self.cat2label[ann["category_id"]])
gt_masks_ann.append(ann.get("segmentation", None))
if gt_bboxes:
gt_bboxes = np.array(gt_bboxes, dtype=np.float32)
gt_labels = np.array(gt_labels, dtype=np.int64)
else:
gt_bboxes = np.zeros((0, 4), dtype=np.float32)
gt_labels = np.array([], dtype=np.int64)
if gt_bboxes_ignore:
gt_bboxes_ignore = np.array(gt_bboxes_ignore, dtype=np.float32)
else:
gt_bboxes_ignore = np.zeros((0, 4), dtype=np.float32)
seg_map = img_info["filename"].replace("jpg", "png")
ann = dict(
bboxes=gt_bboxes,
labels=gt_labels,
bboxes_ignore=gt_bboxes_ignore,
masks=gt_masks_ann,
seg_map=seg_map,
)
return ann
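# A minimal usage sketch (not part of the original module): builds a tiny COCO annotation file
# on the fly and loads it with CocoDataset; every image/annotation value below is made up.
def _example_coco_dataset_usage():  # pragma: no cover
    """Sketch: indexing the dataset yields an mmdet-style results dict with gt_bboxes/gt_labels."""
    import tempfile

    coco_dict = {
        "images": [{"id": 1, "file_name": "img_0001.jpg", "width": 640, "height": 480}],
        "annotations": [
            {"id": 1, "image_id": 1, "category_id": 1, "bbox": [10, 20, 100, 50], "area": 5000, "iscrowd": 0}
        ],
        "categories": [{"id": 1, "name": "person"}],
    }
    with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as tmp:
        json.dump(coco_dict, tmp)
        ann_file = tmp.name
    dataset = CocoDataset(ann_file=ann_file, classes=["person"])
    return dataset[0]  # dict with "gt_bboxes", "gt_labels", "bbox_fields", ...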
def find_label_by_name(labels: List[LabelEntity], name: str, domain: Domain):
"""Return label from name."""
matching_labels = [label for label in labels if label.name == name]
if len(matching_labels) == 1:
return matching_labels[0]
if len(matching_labels) == 0:
label = LabelEntity(name=name, domain=domain, id=ID(len(labels)))
labels.append(label)
return label
raise ValueError("Found multiple matching labels")
def get_sizes_from_dataset_entity(dataset: DatasetEntity, target_wh: List[int]):
"""Function to get sizes of instances in DatasetEntity and to resize it to the target size.
:param dataset: DatasetEntity in which to get statistics
:param target_wh: target width and height of the dataset
:return list: tuples with width and height of each instance
"""
wh_stats = []
for item in dataset:
for ann in item.get_annotations(include_empty=False):
has_detection_labels = any(
label.domain == Domain.DETECTION for label in ann.get_labels(include_empty=False)
)
if has_detection_labels:
box = ShapeFactory.shape_as_rectangle(ann.shape)
width = box.width * target_wh[0]
height = box.height * target_wh[1]
wh_stats.append((width, height))
return wh_stats
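# A minimal usage sketch (not part of the original module): a one-item DatasetEntity with a single
# detection box; all names, sizes and coordinates below are made up.
def _example_get_sizes_from_dataset_entity():  # pragma: no cover
    """Sketch: a normalized 0.3 x 0.4 box scaled to target_wh=[512, 512] yields (153.6, 204.8)."""
    label = LabelEntity(name="person", domain=Domain.DETECTION, id=ID(0))
    annotation = Annotation(
        shape=Rectangle(x1=0.1, y1=0.2, x2=0.4, y2=0.6),  # normalized coordinates
        labels=[ScoredLabel(label)],
    )
    item = DatasetItemEntity(
        media=Image(data=np.zeros((480, 640, 3), dtype=np.uint8)),
        annotation_scene=AnnotationSceneEntity(annotations=[annotation], kind=AnnotationSceneKind.ANNOTATION),
        subset=Subset.TRAINING,
    )
    return get_sizes_from_dataset_entity(DatasetEntity(items=[item]), target_wh=[512, 512])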
def get_anchor_boxes(wh_stats: List[tuple], group_as: List[int]):
"""Get anchor box widths & heights."""
from sklearn.cluster import KMeans
kmeans = KMeans(init="k-means++", n_clusters=sum(group_as), random_state=0).fit(wh_stats)
centers = kmeans.cluster_centers_
areas = np.sqrt(np.prod(centers, axis=1))
idx = np.argsort(areas)
widths = centers[idx, 0]
heights = centers[idx, 1]
group_as = np.cumsum(group_as[:-1])
widths, heights = np.split(widths, group_as), np.split(heights, group_as)
widths = [width.tolist() for width in widths]
heights = [height.tolist() for height in heights]
return widths, heights
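# A minimal usage sketch (not part of the original module): clusters synthetic width/height pairs
# into three groups of two anchors each; the sizes are randomly generated, not real statistics.
def _example_get_anchor_boxes():  # pragma: no cover
    """Sketch: widths/heights come back as per-group lists, sorted by anchor area."""
    rng = np.random.default_rng(0)
    wh_stats = [(float(w), float(h)) for w, h in rng.uniform(8, 256, size=(100, 2))]
    widths, heights = get_anchor_boxes(wh_stats, group_as=[2, 2, 2])
    return widths, heights  # three lists of two widths and two heights each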
def adaptive_tile_params(
tiling_parameters: DetectionConfig.BaseTilingParameters, dataset: DatasetEntity, object_tile_ratio=0.03, rule="avg"
):
"""Config tile parameters.
Adapt based on annotation statistics.
i.e. tile size, tile overlap, ratio and max objects per sample
Args:
tiling_parameters (BaseTilingParameters): tiling parameters of the model
dataset (DatasetEntity): training dataset
object_tile_ratio (float, optional): The desired ratio of object size and tile size. Defaults to 16/512=0.03.
rule (str, optional): min or avg. In `min` mode, tile size is computed based on the smallest object, and in
`avg` mode tile size is computed by averaging all the object areas. Defaults to "avg".
"""
assert rule in ["min", "avg"], f"Unknown rule: {rule}"
stat = compute_robust_dataset_statistics(dataset, ann_stat=True)
max_num_objects = round(stat["annotation"]["num_per_image"]["max"])
avg_size = stat["annotation"]["size_of_shape"]["avg"]
min_size = stat["annotation"]["size_of_shape"]["robust_min"]
max_size = stat["annotation"]["size_of_shape"]["robust_max"]
logger.info(f"----> [stat] scale avg: {avg_size}")
logger.info(f"----> [stat] scale min: {min_size}")
logger.info(f"----> [stat] scale max: {max_size}")
if rule == "min":
object_size = min_size
elif rule == "avg":
object_size = avg_size
logger.info("[Adaptive tiling pararms]")
object_tile_ratio = tiling_parameters.object_tile_ratio
tile_size = int(object_size / object_tile_ratio)
tile_overlap = max_size / tile_size
logger.info(f"----> {rule}_object_size: {object_size}")
logger.info(f"----> max_object_size: {max_size}")
logger.info(f"----> object_tile_ratio: {object_tile_ratio}")
logger.info(f"----> tile_size: {object_size} / {object_tile_ratio} = {tile_size}")
logger.info(f"----> tile_overlap: {max_size} / {tile_size} = {tile_overlap}")
if tile_overlap >= tiling_parameters.get_metadata("tile_overlap")["max_value"]:
        # If the tile overlap would be too large (leading to a zero stride), fall back to the object size.
tile_overlap = object_size / tile_size
logger.info(f"----> (too big) tile_overlap: {object_size} / {tile_size} = {tile_overlap}")
# validate parameters are in range
tile_size = max(
tiling_parameters.get_metadata("tile_size")["min_value"],
min(tiling_parameters.get_metadata("tile_size")["max_value"], tile_size),
)
tile_overlap = max(
tiling_parameters.get_metadata("tile_overlap")["min_value"],
min(tiling_parameters.get_metadata("tile_overlap")["max_value"], tile_overlap),
)
max_num_objects = max(
tiling_parameters.get_metadata("tile_max_number")["min_value"],
min(tiling_parameters.get_metadata("tile_max_number")["max_value"], max_num_objects),
)
tiling_parameters.tile_size = tile_size
tiling_parameters.tile_max_number = max_num_objects
tiling_parameters.tile_overlap = tile_overlap
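# Worked example with hypothetical statistics: with an average object size of 24 px and
# `tiling_parameters.object_tile_ratio` of 0.03, tile_size = int(24 / 0.03) ~= 800; with a robust
# max object size of 160 px, tile_overlap = 160 / 800 = 0.2. A usage sketch (the DetectionConfig
# construction below is an assumption; in practice the config comes from the task's hyper-parameters):
#
#     config = DetectionConfig(header="Detection hyper-parameters")
#     adaptive_tile_params(config.tiling_parameters, train_dataset)
#     print(config.tiling_parameters.tile_size, config.tiling_parameters.tile_overlap)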