Source code for datumaro.plugins.missing_annotation_detection

# Copyright (C) 2023 Intel Corporation
#
# SPDX-License-Identifier: MIT

from typing import List, Optional, Set

from datumaro.components.abstracts.merger import IMatcherContext
from datumaro.components.annotation import Annotation, AnnotationType, LabelCategories
from datumaro.components.annotations.matcher import BboxMatcher, match_segments_more_than_pair
from datumaro.components.dataset_base import DatasetItem, IDataset
from datumaro.components.launcher import Launcher
from datumaro.components.transformer import ModelTransform


[docs] class MissingAnnotationDetection(ModelTransform): """This class is used to find annotations that are missing from the ground truth annotations. To accomplish this, it generates model predictions for the dataset using the given launcher. However, not all of these model predictions can be considered missing annotations since the dataset already contains ground truth annotations and some of the model predictions can be duplicated with them. To identify the missing annotations from the model predictions, this class filters out the predictions that spatially overlap with the ground truth annotations. For example, the follwing example will produce ``[Bbox(1, 1, 1, 1, label=1, attributes={"score": 0.5})]`` as the missing annotations since ``Bbox(0, 0, 1, 1, label=1, attributes={"score": 1.0})`` is overlapped with the ground-truth annotation, ``Bbox(0, 0, 1, 1, label=0)`` (``label_agnostic_matching=True``) .. code-block:: python ground_truth_annotations = [ Bbox(0, 0, 1, 1, label=0), Bbox(1, 0, 1, 1, label=1), Bbox(0, 1, 1, 1, label=2), ] model_predictions = [ Bbox(0, 0, 1, 1, label=1, attributes={"score": 1.0}), Bbox(1, 1, 1, 1, label=1, attributes={"score": 0.5}), ] Args: extractor: The dataset used to find missing labeled annotations. launcher: The launcher used to generate model predictions from the dataset. batch_size: The size of the batches used during processing. pairwise_dist: The distance metric used to measure the distance between two annotations. Typically, the distance metric is Intersection over Union (IoU), which is bounded between 0 and 1. score_threshold: The minimum score required for an annotation to be considered a candidate for missing annotations. label_agnostic_matching: If set to false, annotations with different labels are not matched to determine their spatial overlap. In the above example, ``label_agnostic_matching=False`` will produce ``model_predictions`` as is since ``Bbox(0, 0, 1, 1, label=1, attributes={"score": 1.0})`` has different label with ``Bbox(0, 0, 1, 1, label=0)``. """ def __init__( self, extractor: IDataset, launcher: Launcher, batch_size: int = 1, pairwise_dist: float = 0.75, score_threshold: Optional[float] = None, label_agnostic_matching: bool = True, ): super().__init__(extractor, launcher, batch_size, append_annotation=False) self._score_threshold = score_threshold class LabelAgnosticMatcherContext(IMatcherContext): def get_any_label_name(self, ann: Annotation, label_id: int) -> str: return "" label_categories: LabelCategories = self.categories()[AnnotationType.label] class LabelSpecificMatcherContext(IMatcherContext): def get_any_label_name(self, ann: Annotation, label_id: int) -> str: return label_categories[label_id] self._support_matchers = { AnnotationType.bbox: BboxMatcher( pairwise_dist=pairwise_dist, context=LabelAgnosticMatcherContext() if label_agnostic_matching else LabelSpecificMatcherContext(), match_segments=match_segments_more_than_pair, ), } def _process_batch( self, batch: List[DatasetItem], ) -> List[DatasetItem]: inference = self._launcher.launch( batch=[item for item in batch if self._launcher.type_check(item)] ) for annotations in inference: self._check_annotations(annotations) return [ self.wrap_item( item, annotations=self._find_missing_anns( gt_anns=item.annotations, pseudo_anns=self._apply_score_threshold(annotations), ), ) for item, annotations in zip(batch, inference) ] def _apply_score_threshold(self, annotations: List[Annotation]) -> List[Annotation]: if self._score_threshold is None: return annotations return [ ann for ann in annotations if ann.attributes.get("score", 1.0) > self._score_threshold ] def _find_missing_anns( self, gt_anns: List[Annotation], pseudo_anns: List[Annotation] ) -> List[Annotation]: ids_of_pseudo_anns = set(id(ann) for ann in pseudo_anns) missing_labeled_anns = [] for ann_type, matcher in self._support_matchers.items(): clusters = matcher.match_annotations( [ [ann for ann in gt_anns if ann.type == ann_type], [ann for ann in pseudo_anns if ann.type == ann_type], ] ) for cluster in clusters: ann = self._pick_missing_ann_from_cluster(cluster, ids_of_pseudo_anns) if ann is not None: missing_labeled_anns.append(ann) return missing_labeled_anns @staticmethod def _pick_missing_ann_from_cluster( cluster: List[Annotation], ids_of_pseudo_anns: Set[int] ) -> Optional[Annotation]: assert len(cluster) > 0, "cluster should not be empty." pseudo_label_anns = [] gt_label_anns = [] for ann in cluster: if id(ann) in ids_of_pseudo_anns: pseudo_label_anns.append(ann) else: gt_label_anns.append(ann) if len(gt_label_anns) > 0: return None return max(pseudo_label_anns, key=lambda ann: ann.attributes.get("score", -1))