Source code for datumaro.components.comparator

# Copyright (C) 2023-2024 Intel Corporation
#
# SPDX-License-Identifier: MIT

import logging as log
import os
import os.path as osp
from textwrap import wrap
from typing import Dict, List, Set, Tuple
from unittest import TestCase

from attr import attrib, attrs
from tabulate import tabulate

from datumaro.cli.util.project import generate_next_file_name
from datumaro.components.annotation import AnnotationType, LabelCategories
from datumaro.components.annotations.matcher import LineMatcher, PointsMatcher, match_segments_pair
from datumaro.components.dataset import Dataset
from datumaro.components.operations import (
    compute_ann_statistics,
    compute_image_statistics,
    match_items_by_id,
    match_items_by_image_hash,
)
from datumaro.components.shift_analyzer import ShiftAnalyzer
from datumaro.util import dump_json_file, filter_dict, find
from datumaro.util.annotation_util import find_instances, max_bbox
from datumaro.util.attrs_util import default_if_none


@attrs
class DistanceComparator:
    iou_threshold = attrib(converter=float, default=0.5)

    def match_annotations(self, item_a, item_b):
        return {t: self._match_ann_type(t, item_a, item_b) for t in AnnotationType}

    def _match_ann_type(self, t, *args):
        # pylint: disable=no-value-for-parameter
        if t == AnnotationType.label:
            return self.match_labels(*args)
        elif t == AnnotationType.bbox:
            return self.match_boxes(*args)
        elif t == AnnotationType.polygon:
            return self.match_polygons(*args)
        elif t == AnnotationType.mask:
            return self.match_masks(*args)
        elif t == AnnotationType.points:
            return self.match_points(*args)
        elif t == AnnotationType.polyline:
            return self.match_lines(*args)
        # pylint: enable=no-value-for-parameter
        else:
            raise NotImplementedError("Unexpected annotation type %s" % t)

    @staticmethod
    def _get_ann_type(t, item):
        return [a for a in item.annotations if a.type == t]

    def match_labels(self, item_a, item_b):
        a_labels = set(a.label for a in self._get_ann_type(AnnotationType.label, item_a))
        b_labels = set(a.label for a in self._get_ann_type(AnnotationType.label, item_b))

        matches = a_labels & b_labels
        a_unmatched = a_labels - b_labels
        b_unmatched = b_labels - a_labels

        return matches, a_unmatched, b_unmatched

    def _match_segments(self, t, item_a, item_b):
        a_boxes = self._get_ann_type(t, item_a)
        b_boxes = self._get_ann_type(t, item_b)
        return match_segments_pair(a_boxes, b_boxes, dist_thresh=self.iou_threshold)

    def match_polygons(self, item_a, item_b):
        return self._match_segments(AnnotationType.polygon, item_a, item_b)

    def match_masks(self, item_a, item_b):
        return self._match_segments(AnnotationType.mask, item_a, item_b)

    def match_boxes(self, item_a, item_b):
        return self._match_segments(AnnotationType.bbox, item_a, item_b)

    def match_points(self, item_a, item_b):
        a_points = self._get_ann_type(AnnotationType.points, item_a)
        b_points = self._get_ann_type(AnnotationType.points, item_b)

        instance_map = {}
        for s in [item_a.annotations, item_b.annotations]:
            s_instances = find_instances(s)
            for inst in s_instances:
                inst_bbox = max_bbox(inst)
                for ann in inst:
                    instance_map[id(ann)] = [inst, inst_bbox]
        matcher = PointsMatcher(instance_map=instance_map)

        return match_segments_pair(
            a_points, b_points, dist_thresh=self.iou_threshold, distance=matcher.distance
        )

    def match_lines(self, item_a, item_b):
        a_lines = self._get_ann_type(AnnotationType.polyline, item_a)
        b_lines = self._get_ann_type(AnnotationType.polyline, item_b)

        matcher = LineMatcher()

        return match_segments_pair(
            a_lines, b_lines, dist_thresh=self.iou_threshold, distance=matcher.distance
        )
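
# --- Usage sketch (editorial illustration, not part of the module) -------------------
# A minimal example of matching the annotations of two corresponding items with
# DistanceComparator. The dataset paths and the "coco" format below are hypothetical
# assumptions for illustration; only the comparator calls come from this module.
#
#   from datumaro.components.comparator import DistanceComparator
#   from datumaro.components.dataset import Dataset
#
#   ds_a = Dataset.import_from("path/to/dataset_a", "coco")  # hypothetical path/format
#   ds_b = Dataset.import_from("path/to/dataset_b", "coco")  # hypothetical path/format
#
#   comparator = DistanceComparator(iou_threshold=0.5)
#   item_a = next(iter(ds_a))                    # some item from the first dataset
#   item_b = ds_b.get(item_a.id, item_a.subset)  # its counterpart by id and subset
#   if item_b is not None:
#       # One entry per AnnotationType, as produced by the match_* methods above.
#       per_type_matches = comparator.match_annotations(item_a, item_b)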


@attrs
class EqualityComparator:
    match_images: bool = attrib(kw_only=True, default=False)
    ignored_fields = attrib(kw_only=True, factory=set, validator=default_if_none(set))
    ignored_attrs = attrib(kw_only=True, factory=set, validator=default_if_none(set))
    ignored_item_attrs = attrib(kw_only=True, factory=set, validator=default_if_none(set))
    all = attrib(kw_only=True, default=False)

    _test: TestCase = attrib(init=False)
    errors: list = attrib(init=False)

    def __attrs_post_init__(self):
        self._test = TestCase()
        self._test.maxDiff = None

    def _match_items(self, a, b):
        if self.match_images:
            return match_items_by_image_hash(a, b)
        else:
            return match_items_by_id(a, b)

    def _compare_categories(self, a, b):
        test = self._test
        errors = self.errors

        try:
            test.assertEqual(sorted(a, key=lambda t: t.value), sorted(b, key=lambda t: t.value))
        except AssertionError as e:
            errors.append({"type": "categories", "message": str(e)})

        if AnnotationType.label in a:
            try:
                test.assertEqual(
                    a[AnnotationType.label].items,
                    b[AnnotationType.label].items,
                )
            except AssertionError as e:
                errors.append({"type": "labels", "message": str(e)})
        if AnnotationType.mask in a:
            try:
                test.assertEqual(
                    a[AnnotationType.mask].colormap,
                    b[AnnotationType.mask].colormap,
                )
            except AssertionError as e:
                errors.append({"type": "colormap", "message": str(e)})
        if AnnotationType.points in a:
            try:
                test.assertEqual(
                    a[AnnotationType.points].items,
                    b[AnnotationType.points].items,
                )
            except AssertionError as e:
                errors.append({"type": "points", "message": str(e)})

    def _compare_annotations(self, a, b):
        ignored_fields = self.ignored_fields
        ignored_attrs = self.ignored_attrs

        a_fields = {k: None for k in a.as_dict() if k in ignored_fields}
        b_fields = {k: None for k in b.as_dict() if k in ignored_fields}
        if "attributes" not in ignored_fields:
            a_fields["attributes"] = filter_dict(a.attributes, ignored_attrs)
            b_fields["attributes"] = filter_dict(b.attributes, ignored_attrs)

        result = a.wrap(**a_fields) == b.wrap(**b_fields)

        return result

    def _compare_items(self, item_a, item_b):
        test = self._test

        a_id = (item_a.id, item_a.subset)
        b_id = (item_b.id, item_b.subset)

        matched = []
        unmatched = []
        errors = []

        try:
            test.assertEqual(
                filter_dict(item_a.attributes, self.ignored_item_attrs),
                filter_dict(item_b.attributes, self.ignored_item_attrs),
            )
        except AssertionError as e:
            errors.append({"type": "item_attr", "a_item": a_id, "b_item": b_id, "message": str(e)})

        b_annotations = item_b.annotations[:]
        for ann_a in item_a.annotations:
            ann_b_candidates = [x for x in item_b.annotations if x.type == ann_a.type]

            ann_b = find(
                enumerate(self._compare_annotations(ann_a, x) for x in ann_b_candidates),
                lambda x: x[1],
            )
            if ann_b is None:
                unmatched.append(
                    {
                        "item": a_id,
                        "source": "a",
                        "ann": str(ann_a),
                    }
                )
                continue
            else:
                ann_b = ann_b_candidates[ann_b[0]]

            b_annotations.remove(ann_b)  # avoid repeats
            matched.append({"a_item": a_id, "b_item": b_id, "a": str(ann_a), "b": str(ann_b)})

        for ann_b in b_annotations:
            unmatched.append({"item": b_id, "source": "b", "ann": str(ann_b)})

        return matched, unmatched, errors

    @staticmethod
    def _print_output(output: dict):
        print("Found:")
        print("The first project has %s unmatched items" % len(output.get("a_extra_items", [])))
        print("The second project has %s unmatched items" % len(output.get("b_extra_items", [])))
        print("%s item conflicts" % len(output.get("errors", [])))
        print("%s matching annotations" % len(output.get("matches", [])))
        print("%s mismatching annotations" % len(output.get("mismatches", [])))

    def compare_datasets(self, a, b):
        self.errors = []
        errors = self.errors

        self._compare_categories(a.categories(), b.categories())

        matched = []
        unmatched = []

        matches, a_unmatched, b_unmatched = self._match_items(a, b)

        if a.categories().get(AnnotationType.label) != b.categories().get(AnnotationType.label):
            output = {
                "mismatches": unmatched,
                "a_extra_items": sorted(a_unmatched),
                "b_extra_items": sorted(b_unmatched),
                "errors": errors,
            }
            if self.all:
                output["matches"] = matched

            self._print_output(output)
            return output

        _dist = lambda s: len(s[1]) + len(s[2])
        for a_ids, b_ids in matches:
            # build distance matrix
            match_status = {}  # (a_id, b_id): [matched, unmatched, errors]
            a_matches = {a_id: None for a_id in a_ids}
            b_matches = {b_id: None for b_id in b_ids}

            for a_id in a_ids:
                item_a = a.get(*a_id)
                candidates = {}

                for b_id in b_ids:
                    item_b = b.get(*b_id)

                    i_m, i_um, i_err = self._compare_items(item_a, item_b)
                    candidates[b_id] = [i_m, i_um, i_err]

                    if len(i_um) == 0:
                        a_matches[a_id] = b_id
                        b_matches[b_id] = a_id
                        matched.extend(i_m)
                        errors.extend(i_err)
                        break

                match_status[a_id] = candidates

            # assign
            for a_id in a_ids:
                if len(b_ids) == 0:
                    break

                # find the closest, ignore already assigned
                matched_b = a_matches[a_id]
                if matched_b is not None:
                    continue
                min_dist = -1
                for b_id in b_ids:
                    if b_matches[b_id] is not None:
                        continue
                    d = _dist(match_status[a_id][b_id])
                    if d < min_dist and 0 <= min_dist:
                        continue
                    min_dist = d
                    matched_b = b_id

                if matched_b is None:
                    continue

                a_matches[a_id] = matched_b
                b_matches[matched_b] = a_id

                m = match_status[a_id][matched_b]
                matched.extend(m[0])
                unmatched.extend(m[1])
                errors.extend(m[2])

            a_unmatched |= set(a_id for a_id, m in a_matches.items() if not m)
            b_unmatched |= set(b_id for b_id, m in b_matches.items() if not m)

        output = {
            "mismatches": unmatched,
            "a_extra_items": sorted(a_unmatched),
            "b_extra_items": sorted(b_unmatched),
            "errors": errors,
        }
        if self.all:
            output["matches"] = matched

        self._print_output(output)

        return output

    @staticmethod
    def save_compare_report(
        output: Dict,
        report_dir: str,
    ) -> None:
        """Saves the comparison report to a JSON file.

        Args:
            output: A dictionary containing the comparison data.
            report_dir: A string representing the directory to save the report file.
        """
        os.makedirs(report_dir, exist_ok=True)

        output_file = osp.join(
            report_dir,
            generate_next_file_name("equality_compare", ext=".json", basedir=report_dir),
        )

        log.info(f"Saving compare json to {output_file}")

        dump_json_file(output_file, output, indent=True)
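
# --- Usage sketch (editorial illustration, not part of the module) -------------------
# A minimal example of an equality check between two datasets, assuming hypothetical
# dataset paths/formats and an output directory; only EqualityComparator and its
# methods come from this module.
#
#   from datumaro.components.comparator import EqualityComparator
#   from datumaro.components.dataset import Dataset
#
#   ds_a = Dataset.import_from("path/to/dataset_a", "voc")  # hypothetical path/format
#   ds_b = Dataset.import_from("path/to/dataset_b", "voc")  # hypothetical path/format
#
#   comparator = EqualityComparator(match_images=False, all=True)
#   output = comparator.compare_datasets(ds_a, ds_b)  # also prints a summary
#   EqualityComparator.save_compare_report(output, report_dir="reports")  # hypothetical dir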


@attrs
class TableComparator:
    """
    Class for comparing datasets and generating a comparison report table.
    """

    @staticmethod
    def _extract_labels(dataset: Dataset) -> Set[str]:
        """Extracts labels from the dataset.

        Args:
            dataset: An instance of a Dataset class.

        Returns:
            A set of labels present in the dataset.
        """
        label_cat = dataset.categories().get(AnnotationType.label, LabelCategories())
        return set(c.name for c in label_cat)

    @staticmethod
    def _compute_statistics(dataset: Dataset) -> Tuple[Dict, Dict]:
        """Computes image and annotation statistics of the dataset.

        Args:
            dataset: An instance of a Dataset class.

        Returns:
            A tuple containing image statistics and annotation statistics.
        """
        image_stats = compute_image_statistics(dataset)
        ann_stats = compute_ann_statistics(dataset)
        return image_stats, ann_stats

    def _analyze_dataset(self, dataset: Dataset) -> Tuple[str, Set[str], Dict, Dict]:
        """Analyzes the dataset to get its format, labels, and statistics.

        Args:
            dataset: An instance of a Dataset class.

        Returns:
            A tuple containing the dataset format, the set of label names,
            image statistics, and annotation statistics.
        """
        dataset_format = dataset.format
        dataset_labels = self._extract_labels(dataset)
        image_stats, ann_stats = self._compute_statistics(dataset)
        return dataset_format, dataset_labels, image_stats, ann_stats

    @staticmethod
    def _create_table(headers: List[str], rows: List[List[str]]) -> str:
        """Creates a table with the given headers and rows using the tabulate module.

        Args:
            headers: A list containing table headers.
            rows: A list containing table rows.

        Returns:
            A string representation of the table.
        """

        def wrapfunc(item):
            """Wraps a text item, returning the wrapped lines joined by newlines."""
            max_len = 35
            return "\n".join(wrap(item, max_len))

        wrapped_rows = []
        for row in rows:
            new_row = [wrapfunc(item) for item in row]
            wrapped_rows.append(new_row)

        return tabulate(wrapped_rows, headers, tablefmt="grid")

    @staticmethod
    def _create_dict(rows: List[List[str]]) -> Dict[str, List[str]]:
        """Creates a dictionary from the rows of the table.

        Args:
            rows: A list containing table rows.

        Returns:
            A dictionary where the key is the first element of a row and
            the value is the rest of the row.
        """
        data_dict = {row[0]: row[1:] for row in rows[1:]}
        return data_dict

    def _create_high_level_comparison_table(
        self, first_info: Tuple, second_info: Tuple
    ) -> Tuple[str, Dict]:
        """Generates a high-level comparison table.

        Args:
            first_info: A tuple containing information about the first dataset.
            second_info: A tuple containing information about the second dataset.

        Returns:
            A tuple containing the table as a string and a dictionary representing
            the data of the table.
        """
        first_format, first_labels, first_image_stats, first_ann_stats = first_info
        second_format, second_labels, second_image_stats, second_ann_stats = second_info

        headers = ["Field", "First", "Second"]

        rows = [
            ["Format", first_format, second_format],
            ["Number of classes", str(len(first_labels)), str(len(second_labels))],
            [
                "Common classes",
                ", ".join(sorted(list(first_labels.intersection(second_labels)))),
                ", ".join(sorted(list(second_labels.intersection(first_labels)))),
            ],
            ["Classes", ", ".join(sorted(first_labels)), ", ".join(sorted(second_labels))],
            [
                "Images count",
                str(first_image_stats["dataset"]["images count"]),
                str(second_image_stats["dataset"]["images count"]),
            ],
            [
                "Unique images count",
                str(first_image_stats["dataset"]["unique images count"]),
                str(second_image_stats["dataset"]["unique images count"]),
            ],
            [
                "Repeated images count",
                str(first_image_stats["dataset"]["repeated images count"]),
                str(second_image_stats["dataset"]["repeated images count"]),
            ],
            [
                "Annotations count",
                str(first_ann_stats["annotations count"]),
                str(second_ann_stats["annotations count"]),
            ],
            [
                "Unannotated images count",
                str(first_ann_stats["unannotated images count"]),
                str(second_ann_stats["unannotated images count"]),
            ],
        ]

        table = self._create_table(headers, rows)
        data_dict = self._create_dict(rows)

        return table, data_dict

    def _create_mid_level_comparison_table(
        self, first_info: Tuple, second_info: Tuple
    ) -> Tuple[str, Dict]:
        """Generates a mid-level comparison table.

        Args:
            first_info: A tuple containing information about the first dataset.
            second_info: A tuple containing information about the second dataset.

        Returns:
            A tuple containing the table as a string and a dictionary representing
            the data of the table.
        """
        _, _, first_image_stats, first_ann_stats = first_info
        _, _, second_image_stats, second_ann_stats = second_info

        headers = ["Field", "First", "Second"]

        rows = []

        first_subsets = sorted(list(first_image_stats["subsets"].keys()))
        second_subsets = sorted(list(second_image_stats["subsets"].keys()))

        subset_names = first_subsets.copy()
        subset_names.extend(item for item in second_subsets if item not in first_subsets)

        for subset_name in subset_names:
            first_subset_data = first_image_stats["subsets"].get(subset_name, {})
            second_subset_data = second_image_stats["subsets"].get(subset_name, {})
            mean_str_first = (
                ", ".join(f"{val:6.2f}" for val in first_subset_data.get("image mean (RGB)", []))
                if "image mean (RGB)" in first_subset_data
                else ""
            )
            std_str_first = (
                ", ".join(f"{val:6.2f}" for val in first_subset_data.get("image std (RGB)", []))
                if "image std (RGB)" in first_subset_data
                else ""
            )
            mean_str_second = (
                ", ".join(f"{val:6.2f}" for val in second_subset_data.get("image mean (RGB)", []))
                if "image mean (RGB)" in second_subset_data
                else ""
            )
            std_str_second = (
                ", ".join(f"{val:6.2f}" for val in second_subset_data.get("image std (RGB)", []))
                if "image std (RGB)" in second_subset_data
                else ""
            )
            rows.append([f"{subset_name} - Image Mean (RGB)", mean_str_first, mean_str_second])
            rows.append([f"{subset_name} - Image Std (RGB)", std_str_first, std_str_second])

        first_labels = sorted(list(first_ann_stats["annotations"]["labels"]["distribution"].keys()))
        second_labels = sorted(
            list(second_ann_stats["annotations"]["labels"]["distribution"].keys())
        )

        label_names = first_labels.copy()
        label_names.extend(item for item in second_labels if item not in first_labels)

        for label_name in label_names:
            count_dist_first = first_ann_stats["annotations"]["labels"]["distribution"].get(
                label_name, [0, 0.0]
            )
            count_dist_second = second_ann_stats["annotations"]["labels"]["distribution"].get(
                label_name, [0, 0.0]
            )
            count_first, dist_first = count_dist_first if count_dist_first[0] != 0 else ["", ""]
            count_second, dist_second = count_dist_second if count_dist_second[0] != 0 else ["", ""]

            rows.append(
                [
                    f"Label - {label_name}",
                    f"imgs: {count_first}, percent: {dist_first:.4f}" if count_first != "" else "",
                    f"imgs: {count_second}, percent: {dist_second:.4f}"
                    if count_second != ""
                    else "",
                ]
            )

        table = self._create_table(headers, rows)
        data_dict = self._create_dict(rows)

        return table, data_dict

    def _create_low_level_comparison_table(
        self, first_dataset: Dataset, second_dataset: Dataset
    ) -> Tuple[str, Dict]:
        """Generates a low-level comparison table.

        Args:
            first_dataset: The first dataset to compare.
            second_dataset: The second dataset to compare.

        Returns:
            A tuple containing the table as a string and a dictionary representing
            the data of the table.
        """
        shift_analyzer = ShiftAnalyzer()
        cov_shift = shift_analyzer.compute_covariate_shift([first_dataset, second_dataset])
        label_shift = shift_analyzer.compute_label_shift([first_dataset, second_dataset])

        headers = ["Field", "Value"]

        rows = [
            ["Covariate shift", str(cov_shift)],
            ["Label shift", str(label_shift)],
        ]

        table = self._create_table(headers, rows)
        data_dict = self._create_dict(rows)

        return table, data_dict

    def compare_datasets(
        self, first: Dataset, second: Dataset, mode: str = "all"
    ) -> Tuple[str, str, str, Dict]:
        """Compares two datasets and generates comparison reports.

        Args:
            first: The first dataset to compare.
            second: The second dataset to compare.
            mode: The comparison level(s) to generate: "high", "mid", "low", or "all".

        Returns:
            A tuple containing the high-level table, mid-level table, low-level table,
            and a dictionary representation of the comparison.
        """
        first_info = self._analyze_dataset(first)
        second_info = self._analyze_dataset(second)

        high_level_table, high_level_dict = None, {}
        mid_level_table, mid_level_dict = None, {}
        low_level_table, low_level_dict = None, {}

        if mode in ["high", "all"]:
            high_level_table, high_level_dict = self._create_high_level_comparison_table(
                first_info, second_info
            )
        if mode in ["mid", "all"]:
            mid_level_table, mid_level_dict = self._create_mid_level_comparison_table(
                first_info, second_info
            )
        if mode in ["low", "all"]:
            low_level_table, low_level_dict = self._create_low_level_comparison_table(first, second)

        comparison_dict = dict(
            high_level=high_level_dict, mid_level=mid_level_dict, low_level=low_level_dict
        )

        print(f"High-level comparison:\n{high_level_table}\n")
        print(f"Mid-level comparison:\n{mid_level_table}\n")
        print(f"Low-level comparison:\n{low_level_table}\n")

        return high_level_table, mid_level_table, low_level_table, comparison_dict

    @staticmethod
    def save_compare_report(
        high_level_table: str,
        mid_level_table: str,
        low_level_table: str,
        comparison_dict: Dict,
        report_dir: str,
    ) -> None:
        """Saves the comparison report to JSON and text files.

        Args:
            high_level_table: High-level comparison table as a string.
            mid_level_table: Mid-level comparison table as a string.
            low_level_table: Low-level comparison table as a string.
            comparison_dict: A dictionary containing the comparison data.
            report_dir: A string representing the directory to save the report files.
        """
        os.makedirs(report_dir, exist_ok=True)

        json_output_file = osp.join(
            report_dir, generate_next_file_name("table_compare", ext=".json", basedir=report_dir)
        )
        txt_output_file = osp.join(
            report_dir, generate_next_file_name("table_compare", ext=".txt", basedir=report_dir)
        )

        log.info(f"Saving compare json to {json_output_file}")
        log.info(f"Saving compare table to {txt_output_file}")

        dump_json_file(json_output_file, comparison_dict, indent=True)
        with open(txt_output_file, "w") as f:
            f.write(f"High-level Comparison:\n{high_level_table}\n\n")
            f.write(f"Mid-level Comparison:\n{mid_level_table}\n\n")
            f.write(f"Low-level Comparison:\n{low_level_table}\n\n")
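
# --- Usage sketch (editorial illustration, not part of the module) -------------------
# A minimal example of producing and saving the table-based comparison report,
# assuming hypothetical dataset paths/formats and an output directory; only
# TableComparator and its methods come from this module.
#
#   from datumaro.components.comparator import TableComparator
#   from datumaro.components.dataset import Dataset
#
#   first = Dataset.import_from("path/to/first", "datumaro")    # hypothetical path/format
#   second = Dataset.import_from("path/to/second", "datumaro")  # hypothetical path/format
#
#   comparator = TableComparator()
#   high, mid, low, comparison_dict = comparator.compare_datasets(first, second, mode="all")
#   TableComparator.save_compare_report(high, mid, low, comparison_dict, report_dir="reports")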