# Copyright (C) 2021-2024 Intel Corporation
#
# SPDX-License-Identifier: MIT
from collections import defaultdict
from copy import deepcopy
from dataclasses import dataclass, field
from typing import Any, List, Optional

import numpy as np
from nltk.corpus import stopwords
from datumaro.components.annotation import (
AnnotationType,
GroupType,
LabelCategories,
TabularCategories,
)
from datumaro.components.cli_plugin import CliPlugin
from datumaro.components.errors import (
AttributeDefinedButNotFound,
BrokenAnnotation,
EmptyCaption,
EmptyLabel,
FarFromAttrMean,
FarFromCaptionMean,
FarFromLabelMean,
FewSamplesInAttribute,
FewSamplesInCaption,
FewSamplesInLabel,
ImbalancedAttribute,
ImbalancedCaptions,
ImbalancedDistInAttribute,
ImbalancedDistInCaption,
ImbalancedDistInLabel,
ImbalancedLabels,
InvalidValue,
LabelDefinedButNotFound,
MissingAnnotation,
MissingAttribute,
MissingLabelCategories,
MultiLabelAnnotations,
NegativeLength,
OnlyOneAttributeValue,
OnlyOneLabel,
OutlierInCaption,
RedundanciesInCaption,
UndefinedAttribute,
UndefinedLabel,
)
from datumaro.components.validator import Severity, TaskType, Validator
from datumaro.util import parse_str_enum_value
from datumaro.util.tabular_util import emoji_pattern
DEFAULT_LABEL_GROUP = "default"
class _TaskValidator(Validator, CliPlugin):
"""
A base class for task-specific validators.

Attributes
----------
task_type : str or TaskType
task type (i.e. classification, detection, segmentation, or tabular)
"""

DEFAULT_FEW_SAMPLES_THR = 1
DEFAULT_IMBALANCE_RATIO_THR = 50
DEFAULT_FAR_FROM_MEAN_THR = 5
DEFAULT_DOMINANCE_RATIO_THR = 0.8
DEFAULT_TOPK_BINS = 0.1

# statistics template
numerical_stat_template = {
"items_far_from_mean": {},
"mean": None,
"stdev": None,
"min": None,
"max": None,
"median": None,
"histogram": {
"bins": [],
"counts": [],
},
"distribution": [],
}
@classmethod
def build_cmdline_parser(cls, **kwargs):
parser = super().build_cmdline_parser(**kwargs)
parser.add_argument(
"-fs",
"--few-samples-thr",
default=cls.DEFAULT_FEW_SAMPLES_THR,
type=int,
help="Threshold for giving a warning for minimum number of "
"samples per class (default: %(default)s)",
)
parser.add_argument(
"-ir",
"--imbalance-ratio-thr",
default=cls.DEFAULT_IMBALANCE_RATIO_THR,
type=int,
help="Threshold for giving data imbalance warning. "
"IR(imbalance ratio) = majority/minority "
"(default: %(default)s)",
)
parser.add_argument(
"-m",
"--far-from-mean-thr",
default=cls.DEFAULT_FAR_FROM_MEAN_THR,
type=float,
help="Threshold for giving a warning that data is far from mean. "
"A constant used to define mean +/- k * standard deviation "
"(default: %(default)s)",
)
parser.add_argument(
"-dr",
"--dominance-ratio-thr",
default=cls.DEFAULT_DOMINANCE_RATIO_THR,
type=float,
help="Threshold for giving a warning for bounding box imbalance. "
"Dominace_ratio = ratio of Top-k bin to total in histogram "
"(default: %(default)s)",
)
parser.add_argument(
"-k",
"--topk-bins",
default=cls.DEFAULT_TOPK_BINS,
type=float,
help="Ratio of bins with the highest number of data "
"to total bins in the histogram. A value in the range [0, 1] "
"(default: %(default)s)",
)
return parser
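# A sketch of how these options are reached: the validator is a CLI plugin,
# so the thresholds can be supplied either programmatically (see __init__
# below) or on the `datum` command line; the exact `datum validate`
# invocation here is an assumption, not a verbatim transcript:
#
#   datum validate -t classification <project-path> -- -fs 2 -ir 20
#
# which corresponds to constructing the validator with few_samples_thr=2 and
# imbalance_ratio_thr=20.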
def __init__(
self,
task_type,
few_samples_thr=None,
imbalance_ratio_thr=None,
far_from_mean_thr=None,
dominance_ratio_thr=None,
topk_bins=None,
):
"""
Validator
Parameters
---------------
few_samples_thr: int
minimum number of samples per class
warn user when samples per class is less than threshold
imbalance_ratio_thr: int
ratio of majority attribute to minority attribute
warn user when annotations are unevenly distributed
far_from_mean_thr: float
constant used to define mean +/- m * stddev
warn user when there are too big or small values
dominance_ratio_thr: float
ratio of Top-k bin to total
warn user when dominance ratio is over threshold
topk_bins: float
ratio of selected bins with most item number to total bins
warn user when values are not evenly distributed
"""
self.task_type = parse_str_enum_value(task_type, TaskType, default=TaskType.classification)
if self.task_type == TaskType.classification:
self.ann_types = {AnnotationType.label}
self.str_ann_type = "label"
elif self.task_type == TaskType.detection:
self.ann_types = {AnnotationType.bbox}
self.str_ann_type = "bounding box"
elif self.task_type == TaskType.segmentation:
self.ann_types = {AnnotationType.mask, AnnotationType.polygon, AnnotationType.ellipse}
self.str_ann_type = "mask or polygon or ellipse"
elif self.task_type == TaskType.tabular:
self.ann_types = {AnnotationType.label, AnnotationType.caption}
self.str_ann_type = "label or caption"
if few_samples_thr is None:
few_samples_thr = self.DEFAULT_FEW_SAMPLES_THR
if imbalance_ratio_thr is None:
imbalance_ratio_thr = self.DEFAULT_IMBALANCE_RATIO_THR
if far_from_mean_thr is None:
far_from_mean_thr = self.DEFAULT_FAR_FROM_MEAN_THR
if dominance_ratio_thr is None:
dominance_ratio_thr = self.DEFAULT_DOMINANCE_RATIO_THR
if topk_bins is None:
topk_bins = self.DEFAULT_TOPK_BINS
self.few_samples_thr = few_samples_thr
self.imbalance_ratio_thr = imbalance_ratio_thr
self.far_from_mean_thr = far_from_mean_thr
self.dominance_thr = dominance_ratio_thr
self.topk_bins_ratio = topk_bins
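# Example (a minimal sketch; the threshold values are illustrative only,
# not recommendations; ClassificationValidator is defined later in this
# module):
#
#   validator = ClassificationValidator(
#       few_samples_thr=2,       # warn when a label has 1 or 2 samples
#       imbalance_ratio_thr=20,  # warn when majority/minority >= 20
#       far_from_mean_thr=3.0,   # flag values outside mean +/- 3 * stdev
#   )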
def _compute_common_statistics(self, dataset):
defined_attr_template = {"items_missing_attribute": [], "distribution": {}}
undefined_attr_template = {"items_with_undefined_attr": [], "distribution": {}}
undefined_label_template = {
"count": 0,
"items_with_undefined_label": [],
}
stats = {
"label_distribution": {
"defined_labels": {},
"undefined_labels": {},
},
"attribute_distribution": {"defined_attributes": {}, "undefined_attributes": {}},
}
stats["total_ann_count"] = 0
stats["items_missing_annotation"] = []
label_dist = stats["label_distribution"]
defined_label_dist = label_dist["defined_labels"]
undefined_label_dist = label_dist["undefined_labels"]
attr_dist = stats["attribute_distribution"]
defined_attr_dist = attr_dist["defined_attributes"]
undefined_attr_dist = attr_dist["undefined_attributes"]
label_categories = dataset.categories().get(AnnotationType.label, LabelCategories())
base_valid_attrs = label_categories.attributes
for category in label_categories:
defined_label_dist[category.name] = 0
filtered_anns = []
for item in dataset:
item_key = (item.id, item.subset)
annotations = []
for ann in item.annotations:
if ann.type in self.ann_types:
annotations.append(ann)
ann_count = len(annotations)
filtered_anns.append((item_key, annotations))
if ann_count == 0:
stats["items_missing_annotation"].append(item_key)
stats["total_ann_count"] += ann_count
for ann in annotations:
if not 0 <= ann.label < len(label_categories):
label_name = ann.label
label_stats = undefined_label_dist.setdefault(
ann.label, deepcopy(undefined_label_template)
)
label_stats["items_with_undefined_label"].append(item_key)
label_stats["count"] += 1
valid_attrs = set()
missing_attrs = set()
else:
label_name = label_categories[ann.label].name
defined_label_dist[label_name] += 1
defined_attr_stats = defined_attr_dist.setdefault(label_name, {})
valid_attrs = base_valid_attrs.union(label_categories[ann.label].attributes)
ann_attrs = getattr(ann, "attributes", {}).keys()
missing_attrs = valid_attrs.difference(ann_attrs)
for attr in valid_attrs:
defined_attr_stats.setdefault(attr, deepcopy(defined_attr_template))
for attr in missing_attrs:
attr_dets = defined_attr_stats[attr]
attr_dets["items_missing_attribute"].append(item_key)
for attr, value in ann.attributes.items():
if attr not in valid_attrs:
undefined_attr_stats = undefined_attr_dist.setdefault(label_name, {})
attr_dets = undefined_attr_stats.setdefault(
attr, deepcopy(undefined_attr_template)
)
attr_dets["items_with_undefined_attr"].append(item_key)
else:
attr_dets = defined_attr_stats[attr]
attr_dets["distribution"].setdefault(str(value), 0)
attr_dets["distribution"][str(value)] += 1
return stats, filtered_anns
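# The returned `stats` dict has roughly the following shape (a sketch with
# illustrative values, not output from a real dataset):
#
#   {
#       "label_distribution": {
#           "defined_labels": {"cat": 10, "dog": 0},
#           "undefined_labels": {
#               5: {"count": 1, "items_with_undefined_label": [("img1", "train")]},
#           },
#       },
#       "attribute_distribution": {
#           "defined_attributes": {...},
#           "undefined_attributes": {...},
#       },
#       "total_ann_count": 11,
#       "items_missing_annotation": [("img2", "train")],
#   }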
def _generate_common_reports(self, stats):
"""
Validates the dataset for classification tasks based on its statistics.
Parameters
----------
dataset : IDataset object
stats: Dict object
Returns
-------
reports (list): List of validation reports (DatasetValidationError).
"""
reports = []
# report for dataset
reports += self._check_missing_label_categories(stats)
# report for item
reports += self._check_missing_annotation(stats)
# report for label
reports += self._check_undefined_label(stats)
reports += self._check_label_defined_but_not_found(stats)
reports += self._check_only_one_label(stats)
reports += self._check_few_samples_in_label(stats)
reports += self._check_imbalanced_labels(stats)
# report for attributes
attr_dist = stats["attribute_distribution"]
defined_attr_dist = attr_dist["defined_attributes"]
undefined_attr_dist = attr_dist["undefined_attributes"]
defined_labels = defined_attr_dist.keys()
for label_name in defined_labels:
attr_stats = defined_attr_dist[label_name]
reports += self._check_attribute_defined_but_not_found(label_name, attr_stats)
for attr_name, attr_dets in attr_stats.items():
reports += self._check_missing_attribute(label_name, attr_name, attr_dets)
reports += self._check_only_one_attribute(label_name, attr_name, attr_dets)
reports += self._check_few_samples_in_attribute(label_name, attr_name, attr_dets)
reports += self._check_imbalanced_attribute(label_name, attr_name, attr_dets)
for label_name, attr_stats in undefined_attr_dist.items():
for attr_name, attr_dets in attr_stats.items():
reports += self._check_undefined_attribute(label_name, attr_name, attr_dets)
return reports
def _generate_validation_report(self, error, *args, **kwargs):
return [error(*args, **kwargs)]
def _check_missing_label_categories(self, stats):
validation_reports = []
if len(stats["label_distribution"]["defined_labels"]) == 0:
validation_reports += self._generate_validation_report(
MissingLabelCategories, Severity.error
)
return validation_reports
def _check_missing_annotation(self, stats):
validation_reports = []
items_missing = stats["items_missing_annotation"]
for item_id, item_subset in items_missing:
validation_reports += self._generate_validation_report(
MissingAnnotation, Severity.warning, item_id, item_subset, self.str_ann_type
)
return validation_reports
def _check_missing_attribute(self, label_name, attr_name, attr_dets):
validation_reports = []
items_missing_attr = attr_dets["items_missing_attribute"]
for item_id, item_subset in items_missing_attr:
details = (item_subset, label_name, attr_name)
validation_reports += self._generate_validation_report(
MissingAttribute, Severity.warning, item_id, *details
)
return validation_reports
def _check_undefined_label(self, stats):
validation_reports = []
undefined_label_dist = stats["label_distribution"]["undefined_labels"]
for label_name, label_stats in undefined_label_dist.items():
for item_id, item_subset in label_stats["items_with_undefined_label"]:
details = (item_subset, label_name)
validation_reports += self._generate_validation_report(
UndefinedLabel, Severity.error, item_id, *details
)
return validation_reports
def _check_undefined_attribute(self, label_name, attr_name, attr_dets):
validation_reports = []
items_with_undefined_attr = attr_dets["items_with_undefined_attr"]
for item_id, item_subset in items_with_undefined_attr:
details = (item_subset, label_name, attr_name)
validation_reports += self._generate_validation_report(
UndefinedAttribute, Severity.error, item_id, *details
)
return validation_reports
def _check_label_defined_but_not_found(self, stats):
validation_reports = []
count_by_defined_labels = stats["label_distribution"]["defined_labels"]
labels_not_found = [
label_name for label_name, count in count_by_defined_labels.items() if count == 0
]
for label_name in labels_not_found:
validation_reports += self._generate_validation_report(
LabelDefinedButNotFound, Severity.warning, label_name
)
return validation_reports
def _check_attribute_defined_but_not_found(self, label_name, attr_stats):
validation_reports = []
attrs_not_found = [
attr_name
for attr_name, attr_dets in attr_stats.items()
if len(attr_dets["distribution"]) == 0
]
for attr_name in attrs_not_found:
details = (label_name, attr_name)
validation_reports += self._generate_validation_report(
AttributeDefinedButNotFound, Severity.warning, *details
)
return validation_reports
def _check_only_one_label(self, stats):
validation_reports = []
count_by_defined_labels = stats["label_distribution"]["defined_labels"]
labels_found = [
label_name for label_name, count in count_by_defined_labels.items() if count > 0
]
if len(labels_found) == 1:
validation_reports += self._generate_validation_report(
OnlyOneLabel, Severity.info, labels_found[0]
)
return validation_reports
def _check_only_one_attribute(self, label_name, attr_name, attr_dets):
validation_reports = []
values = list(attr_dets["distribution"].keys())
if len(values) == 1:
details = (label_name, attr_name, values[0])
validation_reports += self._generate_validation_report(
OnlyOneAttributeValue, Severity.info, *details
)
return validation_reports
def _check_few_samples_in_label(self, stats):
validation_reports = []
thr = self.few_samples_thr
defined_label_dist = stats["label_distribution"]["defined_labels"]
labels_with_few_samples = [
(label_name, count)
for label_name, count in defined_label_dist.items()
if 0 < count <= thr
]
for label_name, count in labels_with_few_samples:
validation_reports += self._generate_validation_report(
FewSamplesInLabel, Severity.info, label_name, count
)
return validation_reports
def _check_few_samples_in_attribute(self, label_name, attr_name, attr_dets):
validation_reports = []
thr = self.few_samples_thr
attr_values_with_few_samples = [
(attr_value, count)
for attr_value, count in attr_dets["distribution"].items()
if count <= thr
]
for attr_value, count in attr_values_with_few_samples:
details = (label_name, attr_name, attr_value, count)
validation_reports += self._generate_validation_report(
FewSamplesInAttribute, Severity.info, *details
)
return validation_reports
def _check_imbalanced_labels(self, stats):
validation_reports = []
thr = self.imbalance_ratio_thr
defined_label_dist = stats["label_distribution"]["defined_labels"]
count_by_defined_labels = list(defined_label_dist.values())
if len(count_by_defined_labels) == 0:
return validation_reports
count_max = np.max(count_by_defined_labels)
count_min = np.min(count_by_defined_labels)
balance = count_max / count_min if count_min > 0 else float("inf")
if balance >= thr:
validation_reports += self._generate_validation_report(ImbalancedLabels, Severity.info)
return validation_reports
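# Worked example of the imbalance-ratio rule above (the same rule is used
# for attributes in _check_imbalanced_attribute below): with the default
# threshold of 50, label counts {"cat": 100, "dog": 2} give
# IR = 100 / 2 = 50 >= 50, so ImbalancedLabels is reported; a zero minority
# count yields IR = inf and always triggers the report.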
def _check_imbalanced_attribute(self, label_name, attr_name, attr_dets):
validation_reports = []
thr = self.imbalance_ratio_thr
count_by_defined_attr = list(attr_dets["distribution"].values())
if len(count_by_defined_attr) == 0:
return validation_reports
count_max = np.max(count_by_defined_attr)
count_min = np.min(count_by_defined_attr)
balance = count_max / count_min if count_min > 0 else float("inf")
if balance >= thr:
validation_reports += self._generate_validation_report(
ImbalancedAttribute, Severity.info, label_name, attr_name
)
return validation_reports
class ClassificationValidator(_TaskValidator):
"""
A specific validator class for the classification task.
"""
def __init__(
self,
task_type=TaskType.classification,
few_samples_thr=None,
imbalance_ratio_thr=None,
far_from_mean_thr=None,
dominance_ratio_thr=None,
topk_bins=None,
):
super().__init__(
task_type=task_type,
few_samples_thr=few_samples_thr,
imbalance_ratio_thr=imbalance_ratio_thr,
far_from_mean_thr=far_from_mean_thr,
dominance_ratio_thr=dominance_ratio_thr,
topk_bins=topk_bins,
)
def compute_statistics(self, dataset):
"""
Computes statistics of the dataset for the classification task.
Parameters
----------
dataset : IDataset object
Returns
-------
stats (dict): A dict object containing statistics of the dataset.
"""
stats, filtered_anns = self._compute_common_statistics(dataset)
label_cat = dataset.categories()[AnnotationType.label]
label_groups = label_cat.label_groups
label_name_to_group = {}
for label_group in label_groups:
for idx, label_name in enumerate(label_group.labels):
if label_group.group_type == GroupType.EXCLUSIVE:
label_name_to_group[label_name] = label_group.name
else:
label_name_to_group[label_name] = label_group.name + f"_{idx}"
undefined_label_ids = list(stats["label_distribution"]["undefined_labels"].keys())
stats["items_with_multiple_labels"] = []
for item_key, anns in filtered_anns:
occupied_groups = set()
for ann in anns:
if ann.label in undefined_label_ids:
continue
label_name = label_cat[ann.label].name
label_group = label_name_to_group.get(label_name, DEFAULT_LABEL_GROUP)
if label_group in occupied_groups:
stats["items_with_multiple_labels"].append(item_key)
break
occupied_groups.add(label_group)
return stats
def generate_reports(self, stats):
"""
Validates the dataset for classification tasks based on its statistics.
Parameters
----------
dataset : IDataset object
stats: Dict object
Returns
-------
reports (list): List of validation reports (DatasetValidationError).
"""
reports = self._generate_common_reports(stats)
reports += self._check_multi_label_annotations(stats)
return reports
def _check_multi_label_annotations(self, stats):
validation_reports = []
items_with_multiple_labels = stats["items_with_multiple_labels"]
for item_id, item_subset in items_with_multiple_labels:
validation_reports += self._generate_validation_report(
MultiLabelAnnotations, Severity.error, item_id, item_subset
)
return validation_reports
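# Example usage (a minimal sketch; `Dataset.from_iterable`, `DatasetItem`,
# and `Label` follow the public datumaro API, but treat the exact signatures
# and import paths as assumptions):
#
#   from datumaro.components.annotation import Label
#   from datumaro.components.dataset import Dataset
#   from datumaro.components.dataset_base import DatasetItem
#
#   dataset = Dataset.from_iterable(
#       [
#           DatasetItem(id="a", annotations=[Label(0)]),
#           DatasetItem(id="b"),  # no annotations -> MissingAnnotation report
#       ],
#       categories=["cat", "dog"],
#   )
#   validator = ClassificationValidator()
#   reports = validator.generate_reports(validator.compute_statistics(dataset))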
class DetectionValidator(_TaskValidator):
"""
A specific validator class for the detection task.
"""
def __init__(
self,
task_type=TaskType.detection,
few_samples_thr=None,
imbalance_ratio_thr=None,
far_from_mean_thr=None,
dominance_ratio_thr=None,
topk_bins=None,
):
super().__init__(
task_type=task_type,
few_samples_thr=few_samples_thr,
imbalance_ratio_thr=imbalance_ratio_thr,
far_from_mean_thr=far_from_mean_thr,
dominance_ratio_thr=dominance_ratio_thr,
topk_bins=topk_bins,
)
self.point_template = {
"width": deepcopy(self.numerical_stat_template),
"height": deepcopy(self.numerical_stat_template),
"area(wxh)": deepcopy(self.numerical_stat_template),
"ratio(w/h)": deepcopy(self.numerical_stat_template),
"short": deepcopy(self.numerical_stat_template),
"long": deepcopy(self.numerical_stat_template),
}
def compute_statistics(self, dataset):
"""
Computes statistics of the dataset for the detection task.
Parameters
----------
dataset : IDataset object
Returns
-------
stats (dict): A dict object containing statistics of the dataset.
"""
stats, filtered_items = self._compute_common_statistics(dataset)
stats["items_with_negative_length"] = {}
stats["items_with_invalid_value"] = {}
stats["point_distribution_in_label"] = {}
stats["point_distribution_in_attribute"] = {}
stats["point_distribution_in_dataset_item"] = {}
self.items = filtered_items
def _generate_ann_bbox_info(_x, _y, _w, _h, area, ratio, _short, _long):
return {
"x": _x,
"y": _y,
"width": _w,
"height": _h,
"area(wxh)": area,
"ratio(w/h)": ratio,
"short": _short,
"long": _long,
}
def _update_bbox_stats_by_label(item_key, ann, bbox_label_stats):
bbox_has_error = False
_x, _y, _w, _h = ann.get_bbox()
area = ann.get_area()
if _h != 0 and _h != float("inf"):
ratio = _w / _h
else:
ratio = float("nan")
_short = _w if _w < _h else _h
_long = _w if _w > _h else _h
ann_bbox_info = _generate_ann_bbox_info(_x, _y, _w, _h, area, ratio, _short, _long)
items_w_invalid_val = stats["items_with_invalid_value"]
for prop, val in ann_bbox_info.items():
if val == float("inf") or np.isnan(val):
bbox_has_error = True
anns_w_invalid_val = items_w_invalid_val.setdefault(item_key, {})
invalid_props = anns_w_invalid_val.setdefault(ann.id, [])
invalid_props.append(prop)
items_w_neg_len = stats["items_with_negative_length"]
for prop in ["width", "height"]:
val = ann_bbox_info[prop]
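# sides shorter than 1 pixel are treated as degenerate ("negative") lengths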
if val < 1:
bbox_has_error = True
anns_w_neg_len = items_w_neg_len.setdefault(item_key, {})
neg_props = anns_w_neg_len.setdefault(ann.id, {})
neg_props[prop] = val
if not bbox_has_error:
ann_bbox_info.pop("x")
ann_bbox_info.pop("y")
self._update_prop_distributions(ann_bbox_info, bbox_label_stats)
return ann_bbox_info, bbox_has_error
# Collect property distribution
label_categories = dataset.categories().get(AnnotationType.label, LabelCategories())
self._compute_prop_dist(label_categories, stats, _update_bbox_stats_by_label)
# Compute property statistics from distribution
dist_by_label = stats["point_distribution_in_label"]
dist_by_attr = stats["point_distribution_in_attribute"]
self._compute_prop_stats_from_dist(dist_by_label, dist_by_attr)
def _is_valid_bbox(item_key, ann):
has_defined_label = 0 <= ann.label < len(label_categories)
if not has_defined_label:
return False
bbox_has_neg_len = ann.id in stats["items_with_negative_length"].get(item_key, {})
bbox_has_invalid_val = ann.id in stats["items_with_invalid_value"].get(item_key, {})
return not (bbox_has_neg_len or bbox_has_invalid_val)
def _update_bbox_props_far_from_mean(item_key, ann):
base_valid_attrs = label_categories.attributes
valid_attrs = base_valid_attrs.union(label_categories[ann.label].attributes)
label_name = label_categories[ann.label].name
bbox_label_stats = dist_by_label[label_name]
_x, _y, _w, _h = ann.get_bbox()
area = ann.get_area()
ratio = _w / _h
_short = _w if _w < _h else _h
_long = _w if _w > _h else _h
ann_bbox_info = _generate_ann_bbox_info(_x, _y, _w, _h, area, ratio, _short, _long)
ann_bbox_info.pop("x")
ann_bbox_info.pop("y")
for prop, val in ann_bbox_info.items():
prop_stats = bbox_label_stats[prop]
self._compute_far_from_mean(prop_stats, val, item_key, ann)
for attr, value in ann.attributes.items():
if attr in valid_attrs:
bbox_attr_stats = dist_by_attr[label_name][attr]
bbox_val_stats = bbox_attr_stats[str(value)]
for prop, val in ann_bbox_info.items():
prop_stats = bbox_val_stats[prop]
self._compute_far_from_mean(prop_stats, val, item_key, ann)
# Compute far_from_mean from property
for item_key, annotations in self.items:
for ann in annotations:
if _is_valid_bbox(item_key, ann):
_update_bbox_props_far_from_mean(item_key, ann)
return stats
def generate_reports(self, stats):
"""
Validates the dataset for detection tasks based on its statistics.
Parameters
----------
dataset : IDataset object
stats : Dict object
Returns
-------
reports (list): List of validation reports (DatasetValidationError).
"""
reports = self._generate_common_reports(stats)
reports += self._check_negative_length(stats)
reports += self._check_invalid_value(stats)
defined_attr_dist = stats["attribute_distribution"]["defined_attributes"]
dist_by_label = stats["point_distribution_in_label"]
dist_by_attr = stats["point_distribution_in_attribute"]
defined_labels = defined_attr_dist.keys()
for label_name in defined_labels:
bbox_label_stats = dist_by_label[label_name]
bbox_attr_label = dist_by_attr.get(label_name, {})
reports += self._check_far_from_label_mean(label_name, bbox_label_stats)
reports += self._check_imbalanced_dist_in_label(label_name, bbox_label_stats)
for attr_name, bbox_attr_stats in bbox_attr_label.items():
reports += self._check_far_from_attr_mean(label_name, attr_name, bbox_attr_stats)
reports += self._check_imbalanced_dist_in_attr(
label_name, attr_name, bbox_attr_stats
)
return reports
def _update_prop_distributions(self, curr_stats, target_stats):
for prop, val in curr_stats.items():
prop_stats = target_stats[prop]
prop_stats["distribution"].append(val)
def _compute_prop_dist(self, label_categories, stats, update_stats_by_label):
dist_by_label = stats["point_distribution_in_label"]
dist_by_attr = stats["point_distribution_in_attribute"]
point_dist_in_item = stats["point_distribution_in_dataset_item"]
base_valid_attrs = label_categories.attributes
for item_key, annotations in self.items:
ann_count = len(annotations)
point_dist_in_item[item_key] = ann_count
for ann in annotations:
if not 0 <= ann.label < len(label_categories):
label_name = ann.label
valid_attrs = set()
else:
label_name = label_categories[ann.label].name
valid_attrs = base_valid_attrs.union(label_categories[ann.label].attributes)
point_label_stats = dist_by_label.setdefault(
label_name, deepcopy(self.point_template)
)
ann_point_info, _has_error = update_stats_by_label(
item_key, ann, point_label_stats
)
for attr, value in ann.attributes.items():
if attr in valid_attrs:
point_attr_label = dist_by_attr.setdefault(label_name, {})
point_attr_stats = point_attr_label.setdefault(attr, {})
point_val_stats = point_attr_stats.setdefault(
str(value), deepcopy(self.point_template)
)
if not _has_error:
self._update_prop_distributions(ann_point_info, point_val_stats)
def _compute_prop_stats_from_dist(self, dist_by_label, dist_by_attr):
for label_name, stats in dist_by_label.items():
prop_stats_list = list(stats.values())
attr_label = dist_by_attr.get(label_name, {})
for vals in attr_label.values():
for val_stats in vals.values():
prop_stats_list += list(val_stats.values())
for prop_stats in prop_stats_list:
prop_dist = prop_stats.pop("distribution", [])
if len(prop_dist) > 0:
prop_stats["mean"] = np.mean(prop_dist)
prop_stats["stdev"] = np.std(prop_dist)
prop_stats["min"] = np.min(prop_dist)
prop_stats["max"] = np.max(prop_dist)
prop_stats["median"] = np.median(prop_dist)
counts, bins = np.histogram(prop_dist)
prop_stats["histogram"]["bins"] = bins.tolist()
prop_stats["histogram"]["counts"] = counts.tolist()
def _compute_far_from_mean(self, prop_stats, val, item_key, ann):
def _far_from_mean(val, mean, stdev):
thr = self.far_from_mean_thr
return val > mean + (thr * stdev) or val < mean - (thr * stdev)
mean = prop_stats["mean"]
stdev = prop_stats["stdev"]
if _far_from_mean(val, mean, stdev):
items_far_from_mean = prop_stats["items_far_from_mean"]
far_from_mean = items_far_from_mean.setdefault(item_key, {})
far_from_mean[ann.id] = val
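# Worked example: with far_from_mean_thr = 5 (the default), a property with
# mean = 100 and stdev = 10 flags any value outside [50, 150]; e.g. a
# bounding box width of 160 would be recorded in "items_far_from_mean".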
def _check_negative_length(self, stats):
validation_reports = []
items_w_neg_len = stats["items_with_negative_length"]
for item_dets, anns_w_neg_len in items_w_neg_len.items():
item_id, item_subset = item_dets
for ann_id, props in anns_w_neg_len.items():
for prop, val in props.items():
val = round(val, 2)
details = (item_subset, ann_id, f"{self.str_ann_type} {prop}", val)
validation_reports += self._generate_validation_report(
NegativeLength, Severity.error, item_id, *details
)
return validation_reports
def _check_invalid_value(self, stats):
validation_reports = []
items_w_invalid_val = stats["items_with_invalid_value"]
for item_dets, anns_w_invalid_val in items_w_invalid_val.items():
item_id, item_subset = item_dets
for ann_id, props in anns_w_invalid_val.items():
for prop in props:
details = (item_subset, ann_id, f"{self.str_ann_type} {prop}")
validation_reports += self._generate_validation_report(
InvalidValue, Severity.error, item_id, *details
)
return validation_reports
def _check_imbalanced_dist_in_label(self, label_name, label_stats):
validation_reports = []
thr = self.dominance_thr
topk_ratio = self.topk_bins_ratio
for prop, prop_stats in label_stats.items():
value_counts = prop_stats["histogram"]["counts"]
n_bucket = len(value_counts)
if n_bucket < 2:
continue
topk = max(1, int(np.around(n_bucket * topk_ratio)))
if topk > 0:
topk_values = np.sort(value_counts)[-topk:]
ratio = np.sum(topk_values) / np.sum(value_counts)
if ratio >= thr:
details = (label_name, f"{self.str_ann_type} {prop}")
validation_reports += self._generate_validation_report(
ImbalancedDistInLabel, Severity.info, *details
)
return validation_reports
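# Worked example (applies to this check and to the per-attribute variant
# below): with topk_bins = 0.1 and dominance_ratio_thr = 0.8 (the defaults),
# a 10-bin histogram keeps topk = max(1, round(10 * 0.1)) = 1 bin; counts
# [90, 2, 1, 1, 1, 1, 1, 1, 1, 1] give ratio = 90 / 100 = 0.9 >= 0.8, so
# ImbalancedDistInLabel is reported for that property.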
def _check_imbalanced_dist_in_attr(self, label_name, attr_name, attr_stats):
validation_reports = []
thr = self.dominance_thr
topk_ratio = self.topk_bins_ratio
for attr_value, value_stats in attr_stats.items():
for prop, prop_stats in value_stats.items():
value_counts = prop_stats["histogram"]["counts"]
n_bucket = len(value_counts)
if n_bucket < 2:
continue
topk = max(1, int(np.around(n_bucket * topk_ratio)))
if topk > 0:
topk_values = np.sort(value_counts)[-topk:]
ratio = np.sum(topk_values) / np.sum(value_counts)
if ratio >= thr:
details = (label_name, attr_name, attr_value, f"{self.str_ann_type} {prop}")
validation_reports += self._generate_validation_report(
ImbalancedDistInAttribute, Severity.info, *details
)
return validation_reports
def _check_far_from_label_mean(self, label_name, label_stats):
validation_reports = []
for prop, prop_stats in label_stats.items():
items_far_from_mean = prop_stats["items_far_from_mean"]
if prop_stats["mean"] is not None:
mean = round(prop_stats["mean"], 2)
for item_dets, anns_far in items_far_from_mean.items():
item_id, item_subset = item_dets
for ann_id, val in anns_far.items():
val = round(val, 2)
details = (
item_subset,
label_name,
ann_id,
f"{self.str_ann_type} {prop}",
mean,
val,
)
validation_reports += self._generate_validation_report(
FarFromLabelMean, Severity.warning, item_id, *details
)
return validation_reports
def _check_far_from_attr_mean(self, label_name, attr_name, attr_stats):
validation_reports = []
for attr_value, value_stats in attr_stats.items():
for prop, prop_stats in value_stats.items():
items_far_from_mean = prop_stats["items_far_from_mean"]
if prop_stats["mean"] is not None:
mean = round(prop_stats["mean"], 2)
for item_dets, anns_far in items_far_from_mean.items():
item_id, item_subset = item_dets
for ann_id, val in anns_far.items():
val = round(val, 2)
details = (
item_subset,
label_name,
ann_id,
attr_name,
attr_value,
f"{self.str_ann_type} {prop}",
mean,
val,
)
validation_reports += self._generate_validation_report(
FarFromAttrMean, Severity.warning, item_id, *details
)
return validation_reports
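# Example usage (a minimal sketch under the same API assumptions as the
# classification example above; Bbox(x, y, w, h, label=...) is the datumaro
# bounding-box annotation):
#
#   from datumaro.components.annotation import Bbox
#
#   item = DatasetItem(id="a", annotations=[Bbox(0, 0, 10, 0.5, label=0)])
#   dataset = Dataset.from_iterable([item], categories=["cat"])
#   validator = DetectionValidator()
#   stats = validator.compute_statistics(dataset)
#   reports = validator.generate_reports(stats)  # height 0.5 < 1 -> NegativeLength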
class SegmentationValidator(DetectionValidator):
"""
A specific validator class for the (instance) segmentation task.
"""
def __init__(
self,
task_type=TaskType.segmentation,
few_samples_thr=None,
imbalance_ratio_thr=None,
far_from_mean_thr=None,
dominance_ratio_thr=None,
topk_bins=None,
):
super().__init__(
task_type=task_type,
few_samples_thr=few_samples_thr,
imbalance_ratio_thr=imbalance_ratio_thr,
far_from_mean_thr=far_from_mean_thr,
dominance_ratio_thr=dominance_ratio_thr,
topk_bins=topk_bins,
)
self.point_template = {
"area": deepcopy(self.numerical_stat_template),
"width": deepcopy(self.numerical_stat_template),
"height": deepcopy(self.numerical_stat_template),
}
def compute_statistics(self, dataset):
"""
Computes statistics of the dataset for the segmentation task.
Parameters
----------
dataset : IDataset object
Returns
-------
stats (dict): A dict object containing statistics of the dataset.
"""
stats, filtered_items = self._compute_common_statistics(dataset)
stats["items_with_invalid_value"] = {}
stats["point_distribution_in_label"] = {}
stats["point_distribution_in_attribute"] = {}
stats["point_distribution_in_dataset_item"] = {}
self.items = filtered_items
def _generate_ann_mask_info(area, _w, _h):
return {
"area": area,
"width": _w,
"height": _h,
}
def _update_mask_stats_by_label(item_key, ann, mask_label_stats):
mask_has_error = False
_, _, _w, _h = ann.get_bbox()
# Delete the following block when #226 is resolved
# https://github.com/openvinotoolkit/datumaro/issues/226
if ann.type == AnnotationType.mask:
_w += 1
_h += 1
area = ann.get_area()
ann_mask_info = _generate_ann_mask_info(area, _w, _h)
items_w_invalid_val = stats["items_with_invalid_value"]
for prop, val in ann_mask_info.items():
if val == float("inf") or np.isnan(val):
mask_has_error = True
anns_w_invalid_val = items_w_invalid_val.setdefault(item_key, {})
invalid_props = anns_w_invalid_val.setdefault(ann.id, [])
invalid_props.append(prop)
if not mask_has_error:
self._update_prop_distributions(ann_mask_info, mask_label_stats)
return ann_mask_info, mask_has_error
# Collect property distribution
label_categories = dataset.categories().get(AnnotationType.label, LabelCategories())
self._compute_prop_dist(label_categories, stats, _update_mask_stats_by_label)
# Compute property statistics from distribution
dist_by_label = stats["point_distribution_in_label"]
dist_by_attr = stats["point_distribution_in_attribute"]
self._compute_prop_stats_from_dist(dist_by_label, dist_by_attr)
def _is_valid_mask(item_key, ann):
has_defined_label = 0 <= ann.label < len(label_categories)
if not has_defined_label:
return False
mask_has_invalid_val = ann.id in stats["items_with_invalid_value"].get(item_key, {})
return not mask_has_invalid_val
def _update_mask_props_far_from_mean(item_key, ann):
base_valid_attrs = label_categories.attributes
valid_attrs = base_valid_attrs.union(label_categories[ann.label].attributes)
label_name = label_categories[ann.label].name
mask_label_stats = dist_by_label[label_name]
_, _, _w, _h = ann.get_bbox()
# Delete the following block when #226 is resolved
# https://github.com/openvinotoolkit/datumaro/issues/226
if ann.type == AnnotationType.mask:
_w += 1
_h += 1
area = ann.get_area()
ann_mask_info = _generate_ann_mask_info(area, _w, _h)
for prop, val in ann_mask_info.items():
prop_stats = mask_label_stats[prop]
self._compute_far_from_mean(prop_stats, val, item_key, ann)
for attr, value in ann.attributes.items():
if attr in valid_attrs:
mask_attr_stats = dist_by_attr[label_name][attr]
mask_val_stats = mask_attr_stats[str(value)]
for prop, val in ann_mask_info.items():
prop_stats = mask_val_stats[prop]
self._compute_far_from_mean(prop_stats, val, item_key, ann)
for item_key, annotations in self.items:
for ann in annotations:
if _is_valid_mask(item_key, ann):
_update_mask_props_far_from_mean(item_key, ann)
return stats
def generate_reports(self, stats):
"""
Validates the dataset for segmentation tasks based on its statistics.
Parameters
----------
dataset : IDataset object
stats : Dict object
Returns
-------
reports (list): List of validation reports (DatasetValidationError).
"""
reports = self._generate_common_reports(stats)
reports += self._check_invalid_value(stats)
defined_attr_dist = stats["attribute_distribution"]["defined_attributes"]
dist_by_label = stats["point_distribution_in_label"]
dist_by_attr = stats["point_distribution_in_attribute"]
defined_labels = defined_attr_dist.keys()
for label_name in defined_labels:
mask_label_stats = dist_by_label[label_name]
mask_attr_label = dist_by_attr.get(label_name, {})
reports += self._check_far_from_label_mean(label_name, mask_label_stats)
reports += self._check_imbalanced_dist_in_label(label_name, mask_label_stats)
for attr_name, mask_attr_stats in mask_attr_label.items():
reports += self._check_far_from_attr_mean(label_name, attr_name, mask_attr_stats)
reports += self._check_imbalanced_dist_in_attr(
label_name, attr_name, mask_attr_stats
)
return reports
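# Example usage (a sketch under the same API assumptions as above; Polygon
# takes a flat list of x, y coordinates):
#
#   from datumaro.components.annotation import Polygon
#
#   item = DatasetItem(id="a", annotations=[Polygon([0, 0, 4, 0, 4, 3], label=0)])
#   dataset = Dataset.from_iterable([item], categories=["cat"])
#   validator = SegmentationValidator()
#   reports = validator.generate_reports(validator.compute_statistics(dataset))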
@dataclass
class TabularValidationStats:
total_ann_count: int = field(default=0)
items_missing_annotation: List[Any] = field(default_factory=list)
@classmethod
def create_with_dataset(cls, dataset):
instance = cls()
instance.__post_init__(dataset)
return instance
def __post_init__(self, dataset: Optional[Any] = None):
if dataset:
self.label_categories = dataset.categories().get(
AnnotationType.label, LabelCategories()
)
self.tabular_categories = dataset.categories().get(
AnnotationType.caption, TabularCategories()
)
self.label_columns = list({item.parent for item in self.label_categories.items})
self.caption_columns = [cat.name for cat in self.tabular_categories]
self.defined_labels = {cat.name: 0 for cat in self.label_categories}
self.empty_labels = {cat: [] for cat in self.label_columns}
self.defined_captions = {cat: 0 for cat in self.caption_columns}
self.empty_captions = {cat: [] for cat in self.caption_columns}
self.redundancies = {
cat: {"stopword": [], "url": [], "html": [], "emoji": []}
for cat in self.caption_columns
}
def to_dict(self):
empty_labels = self._build_empty_labels_dict(self.empty_labels, "items_with_empty_label")
empty_captions = self._build_empty_labels_dict(
self.empty_captions, "items_with_empty_caption"
)
redundancies = self._build_redundancies_dict(self.redundancies)
return {
"total_ann_count": self.total_ann_count,
"items_missing_annotation": self.items_missing_annotation,
"items_broken_annotation": self._collect_broken_annotations(),
"label_distribution": {
"defined_labels": self.defined_labels,
"empty_labels": empty_labels,
},
"caption_distribution": {
"defined_captions": self.defined_captions,
"empty_captions": empty_captions,
"redundancies": redundancies,
},
}
def _build_empty_labels_dict(self, empty_dict, key_name):
result = defaultdict(dict)
for label, items in empty_dict.items():
result[label]["count"] = len(items)
result[label][key_name] = list(items)
return result
def _build_redundancies_dict(self, redundancies):
result = defaultdict(lambda: defaultdict(dict))
for caption, items in redundancies.items():
for key in ["stopword", "url", "html", "emoji"]:
result[caption][key]["count"] = len(items[key])
result[caption][key]["items_with_redundancies"] = list(items[key])
return result
def _collect_broken_annotations(self):
broken_annotations = set()
for items in [self.empty_labels, self.empty_captions]:
for key, value in items.items():
broken_annotations.update(value)
return list(broken_annotations)
class TabularValidator(_TaskValidator):
"""
A specific validator class for tabular datasets.
"""
def __init__(
self,
task_type=TaskType.tabular,
few_samples_thr=None,
imbalance_ratio_thr=None,
far_from_mean_thr=None,
dominance_ratio_thr=None,
topk_bins=None,
):
super().__init__(
task_type=task_type,
few_samples_thr=few_samples_thr,
imbalance_ratio_thr=imbalance_ratio_thr,
far_from_mean_thr=far_from_mean_thr,
dominance_ratio_thr=dominance_ratio_thr,
topk_bins=topk_bins,
)
self.numerical_stat_template = {
"items_far_from_mean": {},
"mean": None,
"stdev": None,
"min": None,
"max": None,
"median": None,
"histogram": {
"bins": [],
"counts": [],
},
"distribution": [],
"items_outlier": {},
"outlier": None,
}
self.value_template = {"value": deepcopy(self.numerical_stat_template)}
def _compute_common_statistics(self, dataset):
stats = TabularValidationStats.create_with_dataset(dataset=dataset)
filtered_items = []
for item in dataset:
item_key = (item.id, item.subset)
label_check = {cat: 0 for cat in stats.label_columns}
annotations = [ann for ann in item.annotations if ann.type in self.ann_types]
ann_count = len(annotations)
filtered_items.append((item_key, annotations))
if ann_count == 0:
stats.items_missing_annotation.append(item_key)
stats.total_ann_count += ann_count
caption_check = deepcopy(stats.caption_columns)
for ann in annotations:
if ann.type == AnnotationType.caption:
caption_ = ann.caption
for cat in stats.caption_columns:
if caption_.startswith(cat):
stats.defined_captions[cat] += 1
caption_ = caption_.split(cat + ":")[-1]
caption_check.remove(cat)
self._check_contain_redundancies(caption_, stats, cat, item_key)
else:
label_name = stats.label_categories[ann.label].name
stats.defined_labels[label_name] += 1
label_name = label_name.split(":")[0]
label_check[label_name] += 1
for cap in caption_check:
stats.empty_captions[cap].append(item_key)
for label_col, v in label_check.items():
if v == 0:
stats.empty_labels[label_col].append(item_key)
return stats.to_dict(), filtered_items
def _check_contain_redundancies(self, text, stats, column, item_key):
if column not in stats.redundancies:
return
import re
try:
stop_words = set(stopwords.words("english")) # TODO
except LookupError:
import nltk
nltk.download("stopwords")
stop_words = set(stopwords.words("english")) # TODO
def contains_emoji(text):
return bool(emoji_pattern.search(text))
def contains_html_tags(text):
html_pattern = re.compile(r"<.*?>")
return bool(html_pattern.search(text))
def contains_url(text):
url_pattern = re.compile(r"http\S+|www\S+|https\S+")
return bool(url_pattern.search(text))
def contains_stopword(text):
# compare whole words, not single characters, against the stopword list
return any(word in stop_words for word in str(text).lower().split())
redun_stats = stats.redundancies[column]
if contains_emoji(text):
redun_stats["emoji"].append(item_key)
if contains_html_tags(text):
redun_stats["html"].append(item_key)
if contains_url(text):
redun_stats["url"].append(item_key)
if contains_stopword(text):
redun_stats["stopword"].append(item_key)
def _compute_prop_dist(self, caption_columns, stats):
dist_by_caption = stats["distribution_in_caption"]
dist_in_item = stats["distribution_in_dataset_item"]
for item_key, annotations in self.items:
ann_count = len(annotations)
dist_in_item[item_key] = ann_count
for ann in annotations:
if ann.type == AnnotationType.caption:
caption_ = ann.caption
for cat_name, type_ in caption_columns:
if caption_.startswith(cat_name):
caption_value = type_(caption_.split(f"{cat_name}:")[-1])
dist_by_caption[cat_name]["value"]["distribution"].append(caption_value)
def _compute_prop_stats_from_dist(self, dist_by_caption):
for stats in dist_by_caption.values():
prop_stats = list(stats.values())[0]
prop_dist = prop_stats.pop("distribution", [])
if prop_dist:
prop_stats["mean"] = np.mean(prop_dist)
prop_stats["stdev"] = np.std(prop_dist)
prop_stats["min"] = np.min(prop_dist)
prop_stats["max"] = np.max(prop_dist)
prop_stats["median"] = np.median(prop_dist)
counts, bins = np.histogram(prop_dist)
prop_stats["histogram"]["bins"] = bins.tolist()
prop_stats["histogram"]["counts"] = counts.tolist()
# Calculate Q1 (25th percentile) and Q3 (75th percentile)
Q1, Q3 = np.percentile(prop_dist, [25, 75])
IQR = Q3 - Q1
# Calculate the acceptable range
lower_bound = Q1 - 1.5 * IQR
upper_bound = Q3 + 1.5 * IQR
prop_stats["outlier"] = (lower_bound, upper_bound)
def _compute_far_from_mean(self, prop_stats, val, item_key):
def _far_from_mean(val, mean, stdev):
thr = self.far_from_mean_thr
return val > mean + (thr * stdev) or val < mean - (thr * stdev)
mean = prop_stats["mean"]
stdev = prop_stats["stdev"]
if _far_from_mean(val, mean, stdev):
prop_stats["items_far_from_mean"][item_key] = val
def _compute_outlier(self, prop_stats, val, item_key):
lower_bound, upper_bound = prop_stats["outlier"]
if (val < lower_bound) or (val > upper_bound):
prop_stats["items_outlier"][item_key] = val
def compute_statistics(self, dataset):
"""
Computes statistics of the tabular dataset.
Parameters
----------
dataset : IDataset object
Returns
-------
stats (dict): A dict object containing statistics of the dataset.
"""
stats, filtered_items = self._compute_common_statistics(dataset)
self.items = filtered_items
num_caption_columns = [
(cat.name, cat.dtype)
for cat in dataset.categories().get(AnnotationType.caption, TabularCategories())
if cat.dtype in [int, float]
]
stats["distribution_in_caption"] = {
cap: deepcopy(self.value_template) for cap, _ in num_caption_columns
}
stats["distribution_in_dataset_item"] = {}
# Collect property distribution
self._compute_prop_dist(num_caption_columns, stats)
# Compute property statistics from distribution
dist_by_caption = stats["distribution_in_caption"]
self._compute_prop_stats_from_dist(dist_by_caption)
def _update_captions_far_from_mean_outlier(caption_columns, item_key, ann):
for col, type_ in caption_columns:
prop_stats = dist_by_caption[col]["value"]
if ann.caption.startswith(col):
val = type_(ann.caption.split(f"{col}:")[-1])
self._compute_far_from_mean(prop_stats, val, item_key)
self._compute_outlier(prop_stats, val, item_key)
# Compute far_from_mean and outlier from property
for item_key, annotations in self.items:
for ann in annotations:
if ann.type == AnnotationType.caption:
_update_captions_far_from_mean_outlier(num_caption_columns, item_key, ann)
return stats
def generate_reports(self, stats):
"""
Validates the dataset for classification tasks based on its statistics.
Parameters
----------
dataset : IDataset object
stats: Dict object
Returns
-------
reports (list): List of validation reports (DatasetValidationError).
"""
reports = []
# report for dataset
reports += self._check_missing_label_categories(stats)
# report for item
reports += self._check_missing_annotation(stats)
# report for label
reports += self._check_few_samples_in_label(stats)
reports += self._check_imbalanced_labels(stats)
# report for caption
reports += self._check_few_samples_in_caption(stats)
reports += self._check_redundancies_in_caption(stats)
reports += self._check_imbalanced_captions(stats)
# report for missing value
reports += self._check_broken_annotation(stats)
reports += self._check_empty_label(stats)
reports += self._check_empty_caption(stats)
dist_by_caption = stats["distribution_in_caption"]
for caption, caption_stats in dist_by_caption.items():
reports += self._check_far_from_caption_mean(caption, caption_stats)
reports += self._check_caption_outliers(caption, caption_stats)
reports += self._check_imbalanced_dist_in_caption(caption, caption_stats)
return reports
def _check_broken_annotation(self, stats):
validation_reports = []
items_broken = stats["items_broken_annotation"]
for item_id, item_subset in items_broken:
validation_reports += self._generate_validation_report(
BrokenAnnotation, Severity.warning, item_id, item_subset, self.str_ann_type
)
return validation_reports
def _check_empty_label(self, stats):
validation_reports = []
empty_label_dist = stats["label_distribution"]["empty_labels"]
for label_name, label_stats in empty_label_dist.items():
for item_id, item_subset in label_stats["items_with_empty_label"]:
details = (item_subset, label_name)
validation_reports += self._generate_validation_report(
EmptyLabel, Severity.warning, item_id, *details
)
return validation_reports
def _check_empty_caption(self, stats):
validation_reports = []
empty_caption_dist = stats["caption_distribution"]["empty_captions"]
for caption_name, caption_stats in empty_caption_dist.items():
for item_id, item_subset in caption_stats["items_with_empty_caption"]:
details = (item_subset, caption_name)
validation_reports += self._generate_validation_report(
EmptyCaption, Severity.warning, item_id, *details
)
return validation_reports
def _check_few_samples_in_caption(self, stats):
validation_reports = []
thr = self.few_samples_thr
defined_caption_dist = stats["caption_distribution"]["defined_captions"]
captions_with_few_samples = [
(caption_name, count)
for caption_name, count in defined_caption_dist.items()
if 0 < count <= thr
]
for caption_name, count in captions_with_few_samples:
validation_reports += self._generate_validation_report(
FewSamplesInCaption, Severity.info, caption_name, count
)
return validation_reports
def _check_far_from_caption_mean(self, caption_name, caption_stats):
prop_stats = list(caption_stats.values())[0]
if prop_stats["mean"] is not None:
mean = round(prop_stats["mean"], 2)
stdev = prop_stats["stdev"]
upper_bound = mean + (self.far_from_mean_thr * stdev)
lower_bound = mean - (self.far_from_mean_thr * stdev)
validation_reports = []
items_far_from_mean = prop_stats["items_far_from_mean"]
for item_dets, val in items_far_from_mean.items():
item_id, item_subset = item_dets
val = round(val, 2)
details = (
item_subset,
caption_name,
mean,
upper_bound,
lower_bound,
val,
)
validation_reports += self._generate_validation_report(
FarFromCaptionMean, Severity.info, item_id, *details
)
return validation_reports
def _check_caption_outliers(self, caption_name, caption_stats):
prop_stats = list(caption_stats.values())[0]
lower_bound, upper_bound = prop_stats["outlier"]
items_outlier = prop_stats["items_outlier"]
validation_reports = []
for item_dets, val in items_outlier.items():
item_id, item_subset = item_dets
val = round(val, 2)
details = (
item_subset,
caption_name,
lower_bound,
upper_bound,
val,
)
validation_reports += self._generate_validation_report(
OutlierInCaption, Severity.info, item_id, *details
)
return validation_reports
def _check_redundancies_in_caption(self, stats):
validation_reports = []
redundancies_in_caption_dist = stats["caption_distribution"]["redundancies"]
captions_with_redundancies = []
for cap_column, cap_stats in redundancies_in_caption_dist.items():
for redundancy_type, items in cap_stats.items():
if 0 < items["count"]:
captions_with_redundancies.append((cap_column, redundancy_type, items["count"]))
for cap_column, redundancy_type, count in captions_with_redundancies:
validation_reports += self._generate_validation_report(
RedundanciesInCaption, Severity.info, cap_column, redundancy_type, count
)
return validation_reports
def _check_imbalanced_captions(self, stats):
validation_reports = []
thr = self.imbalance_ratio_thr
defined_caption_dist = stats["caption_distribution"]["defined_captions"]
count_by_caption_labels = list(defined_caption_dist.values())
if len(defined_caption_dist) == 0:
return validation_reports
count_max = np.max(count_by_caption_labels)
count_min = np.min(count_by_caption_labels)
balance = count_max / count_min if count_min > 0 else float("inf")
if balance >= thr:
validation_reports += self._generate_validation_report(
ImbalancedCaptions, Severity.info
)
return validation_reports
def _check_imbalanced_dist_in_caption(self, caption_name, caption_stats):
validation_reports = []
thr = self.dominance_thr
topk_ratio = self.topk_bins_ratio
for prop, prop_stats in caption_stats.items():
value_counts = prop_stats["histogram"]["counts"]
n_bucket = len(value_counts)
if n_bucket < 2:
continue
topk = max(1, int(np.around(n_bucket * topk_ratio)))
if topk > 0:
topk_values = np.sort(value_counts)[-topk:]
ratio = np.sum(topk_values) / np.sum(value_counts)
if ratio >= thr:
validation_reports += self._generate_validation_report(
ImbalancedDistInCaption, Severity.info, caption_name
)
return validation_reports
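# Example usage (a minimal sketch; tabular captions are stored as
# "column:value" strings, matching the parsing in compute_statistics above;
# the dataset construction is elided because it depends on the tabular data
# format and its TabularCategories):
#
#   from datumaro.components.annotation import Caption
#
#   item = DatasetItem(id="0", annotations=[Caption("price:42.0")])
#   dataset = ...  # a tabular Dataset whose categories define a float "price" column
#   validator = TabularValidator()
#   reports = validator.generate_reports(validator.compute_statistics(dataset))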