# Copyright (C) 2021 Intel Corporation
#
# SPDX-License-Identifier: MIT
from copy import deepcopy
import numpy as np
from datumaro.components.annotation import AnnotationType, GroupType, LabelCategories
from datumaro.components.cli_plugin import CliPlugin
from datumaro.components.errors import (
AttributeDefinedButNotFound,
FarFromAttrMean,
FarFromLabelMean,
FewSamplesInAttribute,
FewSamplesInLabel,
ImbalancedAttribute,
ImbalancedDistInAttribute,
ImbalancedDistInLabel,
ImbalancedLabels,
InvalidValue,
LabelDefinedButNotFound,
MissingAnnotation,
MissingAttribute,
MissingLabelCategories,
MultiLabelAnnotations,
NegativeLength,
OnlyOneAttributeValue,
OnlyOneLabel,
UndefinedAttribute,
UndefinedLabel,
)
from datumaro.components.validator import Severity, TaskType, Validator
from datumaro.util import parse_str_enum_value
# Name of the implicit label group used for labels that belong to no explicit group.
DEFAULT_LABEL_GROUP = "default"
class _TaskValidator(Validator, CliPlugin):
    """
    A base class for task-specific validators.

    Attributes
    ----------
    task_type : str or TaskType
        task type (ie. classification, detection, segmentation)
    """

    DEFAULT_FEW_SAMPLES_THR = 1
    DEFAULT_IMBALANCE_RATIO_THR = 50
    DEFAULT_FAR_FROM_MEAN_THR = 5
    DEFAULT_DOMINANCE_RATIO_THR = 0.8
    DEFAULT_TOPK_BINS = 0.1

    # statistics templates
    numerical_stat_template = {
        "items_far_from_mean": {},
        "mean": None,
        "stdev": None,
        "min": None,
        "max": None,
        "median": None,
        "histogram": {
            "bins": [],
            "counts": [],
        },
        "distribution": [],
    }

    @classmethod
    def build_cmdline_parser(cls, **kwargs):
        parser = super().build_cmdline_parser(**kwargs)
        parser.add_argument(
            "-fs",
            "--few-samples-thr",
            default=cls.DEFAULT_FEW_SAMPLES_THR,
            type=int,
            help="Threshold for giving a warning for minimum number of "
            "samples per class (default: %(default)s)",
        )
        parser.add_argument(
            "-ir",
            "--imbalance-ratio-thr",
            default=cls.DEFAULT_IMBALANCE_RATIO_THR,
            type=int,
            help="Threshold for giving data imbalance warning. "
            "IR(imbalance ratio) = majority/minority "
            "(default: %(default)s)",
        )
        parser.add_argument(
            "-m",
            "--far-from-mean-thr",
            default=cls.DEFAULT_FAR_FROM_MEAN_THR,
            type=float,
            help="Threshold for giving a warning that data is far from mean. "
            "A constant used to define mean +/- k * standard deviation "
            "(default: %(default)s)",
        )
        parser.add_argument(
            "-dr",
            "--dominance-ratio-thr",
            default=cls.DEFAULT_DOMINANCE_RATIO_THR,
            type=float,
            help="Threshold for giving a warning for bounding box imbalance. "
            "Dominance_ratio = ratio of Top-k bin to total in histogram "
            "(default: %(default)s)",
        )
        parser.add_argument(
            "-k",
            "--topk-bins",
            default=cls.DEFAULT_TOPK_BINS,
            type=float,
            help="Ratio of bins with the highest number of data "
            "to total bins in the histogram. A value in the range [0, 1] "
            "(default: %(default)s)",
        )
        return parser

    def __init__(
        self,
        task_type,
        few_samples_thr=None,
        imbalance_ratio_thr=None,
        far_from_mean_thr=None,
        dominance_ratio_thr=None,
        topk_bins=None,
    ):
        """
        Validator

        Parameters
        ---------------
        few_samples_thr: int
            minimum number of samples per class
            warn user when samples per class is less than threshold
        imbalance_ratio_thr: int
            ratio of majority attribute to minority attribute
            warn user when annotations are unevenly distributed
        far_from_mean_thr: float
            constant used to define mean +/- m * stddev
            warn user when there are too big or small values
        dominance_ratio_thr: float
            ratio of Top-k bin to total
            warn user when dominance ratio is over threshold
        topk_bins: float
            ratio of selected bins with most item number to total bins
            warn user when values are not evenly distributed
        """
        # An unknown task_type string falls back to classification.
        self.task_type = parse_str_enum_value(task_type, TaskType, default=TaskType.classification)

        # Select which annotation types this validator inspects.
        if self.task_type == TaskType.classification:
            self.ann_types = {AnnotationType.label}
            self.str_ann_type = "label"
        elif self.task_type == TaskType.detection:
            self.ann_types = {AnnotationType.bbox}
            self.str_ann_type = "bounding box"
        elif self.task_type == TaskType.segmentation:
            self.ann_types = {AnnotationType.mask, AnnotationType.polygon, AnnotationType.ellipse}
            self.str_ann_type = "mask or polygon or ellipse"

        if few_samples_thr is None:
            few_samples_thr = self.DEFAULT_FEW_SAMPLES_THR
        if imbalance_ratio_thr is None:
            imbalance_ratio_thr = self.DEFAULT_IMBALANCE_RATIO_THR
        if far_from_mean_thr is None:
            far_from_mean_thr = self.DEFAULT_FAR_FROM_MEAN_THR
        if dominance_ratio_thr is None:
            dominance_ratio_thr = self.DEFAULT_DOMINANCE_RATIO_THR
        if topk_bins is None:
            topk_bins = self.DEFAULT_TOPK_BINS

        self.few_samples_thr = few_samples_thr
        self.imbalance_ratio_thr = imbalance_ratio_thr
        self.far_from_mean_thr = far_from_mean_thr
        self.dominance_thr = dominance_ratio_thr
        self.topk_bins_ratio = topk_bins

    def _compute_common_statistics(self, dataset):
        """
        Computes statistics shared by all task types: label/attribute
        distributions, undefined labels and attributes, missing attributes,
        and items without annotations.

        Returns
        -------
        (stats, filtered_anns): the statistics dict and a list of
            ((item_id, item_subset), annotations) pairs, where annotations
            are filtered down to this validator's annotation types.
        """
        defined_attr_template = {"items_missing_attribute": [], "distribution": {}}
        undefined_attr_template = {"items_with_undefined_attr": [], "distribution": {}}
        undefined_label_template = {
            "count": 0,
            "items_with_undefined_label": [],
        }

        stats = {
            "label_distribution": {
                "defined_labels": {},
                "undefined_labels": {},
            },
            "attribute_distribution": {"defined_attributes": {}, "undefined_attributes": {}},
        }
        stats["total_ann_count"] = 0
        stats["items_missing_annotation"] = []

        label_dist = stats["label_distribution"]
        defined_label_dist = label_dist["defined_labels"]
        undefined_label_dist = label_dist["undefined_labels"]

        attr_dist = stats["attribute_distribution"]
        defined_attr_dist = attr_dist["defined_attributes"]
        undefined_attr_dist = attr_dist["undefined_attributes"]

        label_categories = dataset.categories().get(AnnotationType.label, LabelCategories())
        base_valid_attrs = label_categories.attributes

        # Pre-seed every defined label with a zero count so that labels that
        # never occur can be reported as "defined but not found".
        for category in label_categories:
            defined_label_dist[category.name] = 0

        filtered_anns = []
        for item in dataset:
            item_key = (item.id, item.subset)
            annotations = []
            for ann in item.annotations:
                if ann.type in self.ann_types:
                    annotations.append(ann)
            ann_count = len(annotations)
            filtered_anns.append((item_key, annotations))

            if ann_count == 0:
                stats["items_missing_annotation"].append(item_key)
            stats["total_ann_count"] += ann_count

            for ann in annotations:
                if not 0 <= ann.label < len(label_categories):
                    # Label index outside the defined categories: record it
                    # under the raw index.
                    label_name = ann.label
                    label_stats = undefined_label_dist.setdefault(
                        ann.label, deepcopy(undefined_label_template)
                    )
                    label_stats["items_with_undefined_label"].append(item_key)
                    label_stats["count"] += 1
                    valid_attrs = set()
                    missing_attrs = set()
                else:
                    label_name = label_categories[ann.label].name
                    defined_label_dist[label_name] += 1

                    defined_attr_stats = defined_attr_dist.setdefault(label_name, {})
                    valid_attrs = base_valid_attrs.union(label_categories[ann.label].attributes)
                    ann_attrs = getattr(ann, "attributes", {}).keys()
                    missing_attrs = valid_attrs.difference(ann_attrs)

                    for attr in valid_attrs:
                        defined_attr_stats.setdefault(attr, deepcopy(defined_attr_template))

                for attr in missing_attrs:
                    attr_dets = defined_attr_stats[attr]
                    attr_dets["items_missing_attribute"].append(item_key)

                for attr, value in ann.attributes.items():
                    if attr not in valid_attrs:
                        undefined_attr_stats = undefined_attr_dist.setdefault(label_name, {})
                        attr_dets = undefined_attr_stats.setdefault(
                            attr, deepcopy(undefined_attr_template)
                        )
                        attr_dets["items_with_undefined_attr"].append(item_key)
                    else:
                        attr_dets = defined_attr_stats[attr]
                    # Attribute values are keyed by their string form.
                    attr_dets["distribution"].setdefault(str(value), 0)
                    attr_dets["distribution"][str(value)] += 1

        return stats, filtered_anns

    def _generate_common_reports(self, stats):
        """
        Generates validation reports common to all task types from
        the computed statistics.

        Parameters
        ----------
        stats: Dict object

        Returns
        -------
        reports (list): List of validation reports (DatasetValidationError).
        """
        reports = []

        # report for dataset
        reports += self._check_missing_label_categories(stats)

        # report for item
        reports += self._check_missing_annotation(stats)

        # report for label
        reports += self._check_undefined_label(stats)
        reports += self._check_label_defined_but_not_found(stats)
        reports += self._check_only_one_label(stats)
        reports += self._check_few_samples_in_label(stats)
        reports += self._check_imbalanced_labels(stats)

        # report for attributes
        attr_dist = stats["attribute_distribution"]
        defined_attr_dist = attr_dist["defined_attributes"]
        undefined_attr_dist = attr_dist["undefined_attributes"]

        defined_labels = defined_attr_dist.keys()
        for label_name in defined_labels:
            attr_stats = defined_attr_dist[label_name]

            reports += self._check_attribute_defined_but_not_found(label_name, attr_stats)

            for attr_name, attr_dets in attr_stats.items():
                reports += self._check_missing_attribute(label_name, attr_name, attr_dets)
                reports += self._check_only_one_attribute(label_name, attr_name, attr_dets)
                reports += self._check_few_samples_in_attribute(label_name, attr_name, attr_dets)
                reports += self._check_imbalanced_attribute(label_name, attr_name, attr_dets)

        for label_name, attr_stats in undefined_attr_dist.items():
            for attr_name, attr_dets in attr_stats.items():
                reports += self._check_undefined_attribute(label_name, attr_name, attr_dets)

        return reports

    def _generate_validation_report(self, error, *args, **kwargs):
        # Returns a one-element list so callers can accumulate with `+=`.
        return [error(*args, **kwargs)]

    def _check_missing_label_categories(self, stats):
        validation_reports = []

        if len(stats["label_distribution"]["defined_labels"]) == 0:
            validation_reports += self._generate_validation_report(
                MissingLabelCategories, Severity.error
            )

        return validation_reports

    def _check_missing_annotation(self, stats):
        validation_reports = []

        items_missing = stats["items_missing_annotation"]
        for item_id, item_subset in items_missing:
            validation_reports += self._generate_validation_report(
                MissingAnnotation, Severity.warning, item_id, item_subset, self.str_ann_type
            )

        return validation_reports

    def _check_missing_attribute(self, label_name, attr_name, attr_dets):
        validation_reports = []

        items_missing_attr = attr_dets["items_missing_attribute"]
        for item_id, item_subset in items_missing_attr:
            details = (item_subset, label_name, attr_name)
            validation_reports += self._generate_validation_report(
                MissingAttribute, Severity.warning, item_id, *details
            )

        return validation_reports

    def _check_undefined_label(self, stats):
        validation_reports = []

        undefined_label_dist = stats["label_distribution"]["undefined_labels"]
        for label_name, label_stats in undefined_label_dist.items():
            for item_id, item_subset in label_stats["items_with_undefined_label"]:
                details = (item_subset, label_name)
                validation_reports += self._generate_validation_report(
                    UndefinedLabel, Severity.error, item_id, *details
                )

        return validation_reports

    def _check_undefined_attribute(self, label_name, attr_name, attr_dets):
        validation_reports = []

        items_with_undefined_attr = attr_dets["items_with_undefined_attr"]
        for item_id, item_subset in items_with_undefined_attr:
            details = (item_subset, label_name, attr_name)
            validation_reports += self._generate_validation_report(
                UndefinedAttribute, Severity.error, item_id, *details
            )

        return validation_reports

    def _check_label_defined_but_not_found(self, stats):
        validation_reports = []

        count_by_defined_labels = stats["label_distribution"]["defined_labels"]
        labels_not_found = [
            label_name for label_name, count in count_by_defined_labels.items() if count == 0
        ]

        for label_name in labels_not_found:
            validation_reports += self._generate_validation_report(
                LabelDefinedButNotFound, Severity.warning, label_name
            )

        return validation_reports

    def _check_attribute_defined_but_not_found(self, label_name, attr_stats):
        validation_reports = []

        attrs_not_found = [
            attr_name
            for attr_name, attr_dets in attr_stats.items()
            if len(attr_dets["distribution"]) == 0
        ]

        for attr_name in attrs_not_found:
            details = (label_name, attr_name)
            validation_reports += self._generate_validation_report(
                AttributeDefinedButNotFound, Severity.warning, *details
            )

        return validation_reports

    def _check_only_one_label(self, stats):
        validation_reports = []

        count_by_defined_labels = stats["label_distribution"]["defined_labels"]
        labels_found = [
            label_name for label_name, count in count_by_defined_labels.items() if count > 0
        ]

        if len(labels_found) == 1:
            validation_reports += self._generate_validation_report(
                OnlyOneLabel, Severity.info, labels_found[0]
            )

        return validation_reports

    def _check_only_one_attribute(self, label_name, attr_name, attr_dets):
        validation_reports = []

        values = list(attr_dets["distribution"].keys())
        if len(values) == 1:
            details = (label_name, attr_name, values[0])
            validation_reports += self._generate_validation_report(
                OnlyOneAttributeValue, Severity.info, *details
            )

        return validation_reports

    def _check_few_samples_in_label(self, stats):
        validation_reports = []
        thr = self.few_samples_thr

        defined_label_dist = stats["label_distribution"]["defined_labels"]
        # `count == 0` is reported separately by "defined but not found".
        labels_with_few_samples = [
            (label_name, count)
            for label_name, count in defined_label_dist.items()
            if 0 < count <= thr
        ]

        for label_name, count in labels_with_few_samples:
            validation_reports += self._generate_validation_report(
                FewSamplesInLabel, Severity.info, label_name, count
            )

        return validation_reports

    def _check_few_samples_in_attribute(self, label_name, attr_name, attr_dets):
        validation_reports = []
        thr = self.few_samples_thr

        attr_values_with_few_samples = [
            (attr_value, count)
            for attr_value, count in attr_dets["distribution"].items()
            if count <= thr
        ]

        for attr_value, count in attr_values_with_few_samples:
            details = (label_name, attr_name, attr_value, count)
            validation_reports += self._generate_validation_report(
                FewSamplesInAttribute, Severity.info, *details
            )

        return validation_reports

    def _check_imbalanced_labels(self, stats):
        validation_reports = []
        thr = self.imbalance_ratio_thr

        defined_label_dist = stats["label_distribution"]["defined_labels"]
        count_by_defined_labels = [count for label, count in defined_label_dist.items()]
        if len(count_by_defined_labels) == 0:
            return validation_reports

        count_max = np.max(count_by_defined_labels)
        count_min = np.min(count_by_defined_labels)
        # A zero minority count means infinite imbalance.
        balance = count_max / count_min if count_min > 0 else float("inf")
        if balance >= thr:
            validation_reports += self._generate_validation_report(ImbalancedLabels, Severity.info)

        return validation_reports

    def _check_imbalanced_attribute(self, label_name, attr_name, attr_dets):
        validation_reports = []
        thr = self.imbalance_ratio_thr

        count_by_defined_attr = list(attr_dets["distribution"].values())
        if len(count_by_defined_attr) == 0:
            return validation_reports

        count_max = np.max(count_by_defined_attr)
        count_min = np.min(count_by_defined_attr)
        # A zero minority count means infinite imbalance.
        balance = count_max / count_min if count_min > 0 else float("inf")
        if balance >= thr:
            validation_reports += self._generate_validation_report(
                ImbalancedAttribute, Severity.info, label_name, attr_name
            )

        return validation_reports
class ClassificationValidator(_TaskValidator):
    """
    A specific validator class for classification task.
    """

    def __init__(
        self,
        task_type=TaskType.classification,
        few_samples_thr=None,
        imbalance_ratio_thr=None,
        far_from_mean_thr=None,
        dominance_ratio_thr=None,
        topk_bins=None,
    ):
        super().__init__(
            task_type=task_type,
            few_samples_thr=few_samples_thr,
            imbalance_ratio_thr=imbalance_ratio_thr,
            far_from_mean_thr=far_from_mean_thr,
            dominance_ratio_thr=dominance_ratio_thr,
            topk_bins=topk_bins,
        )

    def compute_statistics(self, dataset):
        """
        Computes statistics of the dataset for the classification task.

        Parameters
        ----------
        dataset : IDataset object

        Returns
        -------
        stats (dict): A dict object containing statistics of the dataset.
        """
        stats, filtered_anns = self._compute_common_statistics(dataset)

        label_cat = dataset.categories()[AnnotationType.label]
        label_groups = label_cat.label_groups

        # Map each label name to the group it may occupy. For non-exclusive
        # groups each label gets its own pseudo-group, so several such labels
        # can co-exist on one item.
        label_name_to_group = {}
        for label_group in label_groups:
            for idx, label_name in enumerate(label_group.labels):
                if label_group.group_type == GroupType.EXCLUSIVE:
                    label_name_to_group[label_name] = label_group.name
                else:
                    label_name_to_group[label_name] = label_group.name + f"_{idx}"

        undefined_label_name = list(stats["label_distribution"]["undefined_labels"].keys())

        stats["items_with_multiple_labels"] = []
        for item_key, anns in filtered_anns:
            occupied_groups = set()
            for ann in anns:
                if ann.label in undefined_label_name:
                    continue
                label_name = label_cat[ann.label].name
                label_group = label_name_to_group.get(label_name, DEFAULT_LABEL_GROUP)
                if label_group in occupied_groups:
                    # Two labels from the same (exclusive) group on one item.
                    stats["items_with_multiple_labels"].append(item_key)
                    break
                occupied_groups.add(label_group)

        return stats

    def generate_reports(self, stats):
        """
        Validates the dataset for classification tasks based on its statistics.

        Parameters
        ----------
        stats: Dict object

        Returns
        -------
        reports (list): List of validation reports (DatasetValidationError).
        """
        reports = self._generate_common_reports(stats)
        reports += self._check_multi_label_annotations(stats)
        return reports

    def _check_multi_label_annotations(self, stats):
        validation_reports = []

        items_with_multiple_labels = stats["items_with_multiple_labels"]
        for item_id, item_subset in items_with_multiple_labels:
            validation_reports += self._generate_validation_report(
                MultiLabelAnnotations, Severity.error, item_id, item_subset
            )

        return validation_reports
class DetectionValidator(_TaskValidator):
    """
    A specific validator class for detection task.
    """

    def __init__(
        self,
        task_type=TaskType.detection,
        few_samples_thr=None,
        imbalance_ratio_thr=None,
        far_from_mean_thr=None,
        dominance_ratio_thr=None,
        topk_bins=None,
    ):
        super().__init__(
            task_type=task_type,
            few_samples_thr=few_samples_thr,
            imbalance_ratio_thr=imbalance_ratio_thr,
            far_from_mean_thr=far_from_mean_thr,
            dominance_ratio_thr=dominance_ratio_thr,
            topk_bins=topk_bins,
        )
        # Numerical properties tracked per bounding box.
        self.point_template = {
            "width": deepcopy(self.numerical_stat_template),
            "height": deepcopy(self.numerical_stat_template),
            "area(wxh)": deepcopy(self.numerical_stat_template),
            "ratio(w/h)": deepcopy(self.numerical_stat_template),
            "short": deepcopy(self.numerical_stat_template),
            "long": deepcopy(self.numerical_stat_template),
        }

    def compute_statistics(self, dataset):
        """
        Computes statistics of the dataset for the detection task.

        Parameters
        ----------
        dataset : IDataset object

        Returns
        -------
        stats (dict): A dict object containing statistics of the dataset.
        """
        stats, filtered_items = self._compute_common_statistics(dataset)

        # detection-specific
        stats["items_with_negative_length"] = {}
        stats["items_with_invalid_value"] = {}
        stats["point_distribution_in_label"] = {}
        stats["point_distribution_in_attribute"] = {}
        stats["point_distribution_in_dataset_item"] = {}

        self.items = filtered_items

        def _generate_ann_bbox_info(_x, _y, _w, _h, area, ratio, _short, _long):
            return {
                "x": _x,
                "y": _y,
                "width": _w,
                "height": _h,
                "area(wxh)": area,
                "ratio(w/h)": ratio,
                "short": _short,
                "long": _long,
            }

        def _update_bbox_stats_by_label(item_key, ann, bbox_label_stats):
            bbox_has_error = False

            _x, _y, _w, _h = ann.get_bbox()
            area = ann.get_area()

            if _h != 0 and _h != float("inf"):
                ratio = _w / _h
            else:
                ratio = float("nan")

            _short = _w if _w < _h else _h
            _long = _w if _w > _h else _h

            ann_bbox_info = _generate_ann_bbox_info(_x, _y, _w, _h, area, ratio, _short, _long)

            # Flag inf/NaN in any property.
            items_w_invalid_val = stats["items_with_invalid_value"]
            for prop, val in ann_bbox_info.items():
                if val == float("inf") or np.isnan(val):
                    bbox_has_error = True
                    anns_w_invalid_val = items_w_invalid_val.setdefault(item_key, {})
                    invalid_props = anns_w_invalid_val.setdefault(ann.id, [])
                    invalid_props.append(prop)

            # Flag degenerate (sub-pixel) sides.
            items_w_neg_len = stats["items_with_negative_length"]
            for prop in ["width", "height"]:
                val = ann_bbox_info[prop]
                if val < 1:
                    bbox_has_error = True
                    anns_w_neg_len = items_w_neg_len.setdefault(item_key, {})
                    neg_props = anns_w_neg_len.setdefault(ann.id, {})
                    neg_props[prop] = val

            # Only error-free boxes contribute to the distributions;
            # x/y positions are not part of the tracked properties.
            if not bbox_has_error:
                ann_bbox_info.pop("x")
                ann_bbox_info.pop("y")
                self._update_prop_distributions(ann_bbox_info, bbox_label_stats)

            return ann_bbox_info, bbox_has_error

        # Collect property distribution
        label_categories = dataset.categories().get(AnnotationType.label, LabelCategories())
        self._compute_prop_dist(label_categories, stats, _update_bbox_stats_by_label)

        # Compute property statistics from distribution
        dist_by_label = stats["point_distribution_in_label"]
        dist_by_attr = stats["point_distribution_in_attribute"]
        self._compute_prop_stats_from_dist(dist_by_label, dist_by_attr)

        def _is_valid_bbox(item_key, ann):
            has_defined_label = 0 <= ann.label < len(label_categories)
            if not has_defined_label:
                return False

            bbox_has_neg_len = ann.id in stats["items_with_negative_length"].get(item_key, {})
            bbox_has_invalid_val = ann.id in stats["items_with_invalid_value"].get(item_key, {})
            return not (bbox_has_neg_len or bbox_has_invalid_val)

        def _update_bbox_props_far_from_mean(item_key, ann):
            base_valid_attrs = label_categories.attributes
            valid_attrs = base_valid_attrs.union(label_categories[ann.label].attributes)
            label_name = label_categories[ann.label].name
            bbox_label_stats = dist_by_label[label_name]

            _x, _y, _w, _h = ann.get_bbox()
            area = ann.get_area()
            # Safe: only called for valid boxes (no zero/invalid height).
            ratio = _w / _h
            _short = _w if _w < _h else _h
            _long = _w if _w > _h else _h

            ann_bbox_info = _generate_ann_bbox_info(_x, _y, _w, _h, area, ratio, _short, _long)
            ann_bbox_info.pop("x")
            ann_bbox_info.pop("y")

            for prop, val in ann_bbox_info.items():
                prop_stats = bbox_label_stats[prop]
                self._compute_far_from_mean(prop_stats, val, item_key, ann)

            for attr, value in ann.attributes.items():
                if attr in valid_attrs:
                    bbox_attr_stats = dist_by_attr[label_name][attr]
                    bbox_val_stats = bbox_attr_stats[str(value)]

                    for prop, val in ann_bbox_info.items():
                        prop_stats = bbox_val_stats[prop]
                        self._compute_far_from_mean(prop_stats, val, item_key, ann)

        # Compute far_from_mean from property
        for item_key, annotations in self.items:
            for ann in annotations:
                if _is_valid_bbox(item_key, ann):
                    _update_bbox_props_far_from_mean(item_key, ann)

        return stats

    def generate_reports(self, stats):
        """
        Validates the dataset for detection tasks based on its statistics.

        Parameters
        ----------
        stats : Dict object

        Returns
        -------
        reports (list): List of validation reports (DatasetValidationError).
        """
        reports = self._generate_common_reports(stats)

        reports += self._check_negative_length(stats)
        reports += self._check_invalid_value(stats)

        defined_attr_dist = stats["attribute_distribution"]["defined_attributes"]
        dist_by_label = stats["point_distribution_in_label"]
        dist_by_attr = stats["point_distribution_in_attribute"]

        defined_labels = defined_attr_dist.keys()
        for label_name in defined_labels:
            bbox_label_stats = dist_by_label[label_name]
            bbox_attr_label = dist_by_attr.get(label_name, {})

            reports += self._check_far_from_label_mean(label_name, bbox_label_stats)
            reports += self._check_imbalanced_dist_in_label(label_name, bbox_label_stats)

            for attr_name, bbox_attr_stats in bbox_attr_label.items():
                reports += self._check_far_from_attr_mean(label_name, attr_name, bbox_attr_stats)
                reports += self._check_imbalanced_dist_in_attr(
                    label_name, attr_name, bbox_attr_stats
                )

        return reports

    def _update_prop_distributions(self, curr_stats, target_stats):
        # Appends each property value to the matching distribution list.
        for prop, val in curr_stats.items():
            prop_stats = target_stats[prop]
            prop_stats["distribution"].append(val)

    def _compute_prop_dist(self, label_categories, stats, update_stats_by_label):
        """
        Walks all filtered annotations and accumulates per-label and
        per-attribute-value property distributions using the task-specific
        `update_stats_by_label` callback.
        """
        dist_by_label = stats["point_distribution_in_label"]
        dist_by_attr = stats["point_distribution_in_attribute"]
        point_dist_in_item = stats["point_distribution_in_dataset_item"]

        base_valid_attrs = label_categories.attributes

        for item_key, annotations in self.items:
            ann_count = len(annotations)
            point_dist_in_item[item_key] = ann_count

            for ann in annotations:
                if not 0 <= ann.label < len(label_categories):
                    label_name = ann.label
                    valid_attrs = set()
                else:
                    label_name = label_categories[ann.label].name
                    valid_attrs = base_valid_attrs.union(label_categories[ann.label].attributes)

                point_label_stats = dist_by_label.setdefault(
                    label_name, deepcopy(self.point_template)
                )
                ann_point_info, _has_error = update_stats_by_label(
                    item_key, ann, point_label_stats
                )

                for attr, value in ann.attributes.items():
                    if attr in valid_attrs:
                        point_attr_label = dist_by_attr.setdefault(label_name, {})
                        point_attr_stats = point_attr_label.setdefault(attr, {})
                        point_val_stats = point_attr_stats.setdefault(
                            str(value), deepcopy(self.point_template)
                        )

                        if not _has_error:
                            self._update_prop_distributions(ann_point_info, point_val_stats)

    def _compute_prop_stats_from_dist(self, dist_by_label, dist_by_attr):
        # Folds each raw distribution into summary statistics and a histogram;
        # the "distribution" key is consumed (popped) in the process.
        for label_name, stats in dist_by_label.items():
            prop_stats_list = list(stats.values())
            attr_label = dist_by_attr.get(label_name, {})
            for vals in attr_label.values():
                for val_stats in vals.values():
                    prop_stats_list += list(val_stats.values())

            for prop_stats in prop_stats_list:
                prop_dist = prop_stats.pop("distribution", [])
                if len(prop_dist) > 0:
                    prop_stats["mean"] = np.mean(prop_dist)
                    prop_stats["stdev"] = np.std(prop_dist)
                    prop_stats["min"] = np.min(prop_dist)
                    prop_stats["max"] = np.max(prop_dist)
                    prop_stats["median"] = np.median(prop_dist)

                    counts, bins = np.histogram(prop_dist)
                    prop_stats["histogram"]["bins"] = bins.tolist()
                    prop_stats["histogram"]["counts"] = counts.tolist()

    def _compute_far_from_mean(self, prop_stats, val, item_key, ann):
        def _far_from_mean(val, mean, stdev):
            thr = self.far_from_mean_thr
            return val > mean + (thr * stdev) or val < mean - (thr * stdev)

        mean = prop_stats["mean"]
        stdev = prop_stats["stdev"]

        if _far_from_mean(val, mean, stdev):
            items_far_from_mean = prop_stats["items_far_from_mean"]
            far_from_mean = items_far_from_mean.setdefault(item_key, {})
            far_from_mean[ann.id] = val

    def _check_negative_length(self, stats):
        validation_reports = []

        items_w_neg_len = stats["items_with_negative_length"]
        for item_dets, anns_w_neg_len in items_w_neg_len.items():
            item_id, item_subset = item_dets
            for ann_id, props in anns_w_neg_len.items():
                for prop, val in props.items():
                    val = round(val, 2)
                    details = (item_subset, ann_id, f"{self.str_ann_type} {prop}", val)
                    validation_reports += self._generate_validation_report(
                        NegativeLength, Severity.error, item_id, *details
                    )

        return validation_reports

    def _check_invalid_value(self, stats):
        validation_reports = []

        items_w_invalid_val = stats["items_with_invalid_value"]
        for item_dets, anns_w_invalid_val in items_w_invalid_val.items():
            item_id, item_subset = item_dets
            for ann_id, props in anns_w_invalid_val.items():
                for prop in props:
                    details = (item_subset, ann_id, f"{self.str_ann_type} {prop}")
                    validation_reports += self._generate_validation_report(
                        InvalidValue, Severity.error, item_id, *details
                    )

        return validation_reports

    def _check_imbalanced_dist_in_label(self, label_name, label_stats):
        validation_reports = []
        thr = self.dominance_thr
        topk_ratio = self.topk_bins_ratio

        for prop, prop_stats in label_stats.items():
            value_counts = prop_stats["histogram"]["counts"]
            n_bucket = len(value_counts)
            if n_bucket < 2:
                continue
            topk = max(1, int(np.around(n_bucket * topk_ratio)))

            if topk > 0:
                topk_values = np.sort(value_counts)[-topk:]
                ratio = np.sum(topk_values) / np.sum(value_counts)
                if ratio >= thr:
                    details = (label_name, f"{self.str_ann_type} {prop}")
                    validation_reports += self._generate_validation_report(
                        ImbalancedDistInLabel, Severity.info, *details
                    )

        return validation_reports

    def _check_imbalanced_dist_in_attr(self, label_name, attr_name, attr_stats):
        validation_reports = []
        thr = self.dominance_thr
        topk_ratio = self.topk_bins_ratio

        for attr_value, value_stats in attr_stats.items():
            for prop, prop_stats in value_stats.items():
                value_counts = prop_stats["histogram"]["counts"]
                n_bucket = len(value_counts)
                if n_bucket < 2:
                    continue
                topk = max(1, int(np.around(n_bucket * topk_ratio)))

                if topk > 0:
                    topk_values = np.sort(value_counts)[-topk:]
                    ratio = np.sum(topk_values) / np.sum(value_counts)
                    if ratio >= thr:
                        details = (label_name, attr_name, attr_value, f"{self.str_ann_type} {prop}")
                        validation_reports += self._generate_validation_report(
                            ImbalancedDistInAttribute, Severity.info, *details
                        )

        return validation_reports

    def _check_far_from_label_mean(self, label_name, label_stats):
        validation_reports = []

        for prop, prop_stats in label_stats.items():
            items_far_from_mean = prop_stats["items_far_from_mean"]
            if prop_stats["mean"] is not None:
                mean = round(prop_stats["mean"], 2)

                for item_dets, anns_far in items_far_from_mean.items():
                    item_id, item_subset = item_dets
                    for ann_id, val in anns_far.items():
                        val = round(val, 2)
                        details = (
                            item_subset,
                            label_name,
                            ann_id,
                            f"{self.str_ann_type} {prop}",
                            mean,
                            val,
                        )
                        validation_reports += self._generate_validation_report(
                            FarFromLabelMean, Severity.warning, item_id, *details
                        )

        return validation_reports

    def _check_far_from_attr_mean(self, label_name, attr_name, attr_stats):
        validation_reports = []

        for attr_value, value_stats in attr_stats.items():
            for prop, prop_stats in value_stats.items():
                items_far_from_mean = prop_stats["items_far_from_mean"]
                if prop_stats["mean"] is not None:
                    mean = round(prop_stats["mean"], 2)

                    for item_dets, anns_far in items_far_from_mean.items():
                        item_id, item_subset = item_dets
                        for ann_id, val in anns_far.items():
                            val = round(val, 2)
                            details = (
                                item_subset,
                                label_name,
                                ann_id,
                                attr_name,
                                attr_value,
                                f"{self.str_ann_type} {prop}",
                                mean,
                                val,
                            )
                            validation_reports += self._generate_validation_report(
                                FarFromAttrMean, Severity.warning, item_id, *details
                            )

        return validation_reports
class SegmentationValidator(DetectionValidator):
    """
    A specific validator class for (instance) segmentation task.
    """

    def __init__(
        self,
        task_type=TaskType.segmentation,
        few_samples_thr=None,
        imbalance_ratio_thr=None,
        far_from_mean_thr=None,
        dominance_ratio_thr=None,
        topk_bins=None,
    ):
        super().__init__(
            task_type=task_type,
            few_samples_thr=few_samples_thr,
            imbalance_ratio_thr=imbalance_ratio_thr,
            far_from_mean_thr=far_from_mean_thr,
            dominance_ratio_thr=dominance_ratio_thr,
            topk_bins=topk_bins,
        )
        # Numerical properties tracked per mask/polygon/ellipse.
        self.point_template = {
            "area": deepcopy(self.numerical_stat_template),
            "width": deepcopy(self.numerical_stat_template),
            "height": deepcopy(self.numerical_stat_template),
        }

    def compute_statistics(self, dataset):
        """
        Computes statistics of the dataset for the segmentation task.

        Parameters
        ----------
        dataset : IDataset object

        Returns
        -------
        stats (dict): A dict object containing statistics of the dataset.
        """
        stats, filtered_items = self._compute_common_statistics(dataset)

        # segmentation-specific
        stats["items_with_invalid_value"] = {}
        stats["point_distribution_in_label"] = {}
        stats["point_distribution_in_attribute"] = {}
        stats["point_distribution_in_dataset_item"] = {}

        self.items = filtered_items

        def _generate_ann_mask_info(area, _w, _h):
            return {
                "area": area,
                "width": _w,
                "height": _h,
            }

        def _update_mask_stats_by_label(item_key, ann, mask_label_stats):
            mask_has_error = False

            _, _, _w, _h = ann.get_bbox()

            # Delete the following block when #226 is resolved
            # https://github.com/openvinotoolkit/datumaro/issues/226
            if ann.type == AnnotationType.mask:
                _w += 1
                _h += 1

            area = ann.get_area()

            ann_mask_info = _generate_ann_mask_info(area, _w, _h)

            # Flag inf/NaN in any property.
            items_w_invalid_val = stats["items_with_invalid_value"]
            for prop, val in ann_mask_info.items():
                if val == float("inf") or np.isnan(val):
                    mask_has_error = True
                    anns_w_invalid_val = items_w_invalid_val.setdefault(item_key, {})
                    invalid_props = anns_w_invalid_val.setdefault(ann.id, [])
                    invalid_props.append(prop)

            if not mask_has_error:
                self._update_prop_distributions(ann_mask_info, mask_label_stats)

            return ann_mask_info, mask_has_error

        # Collect property distribution
        label_categories = dataset.categories().get(AnnotationType.label, LabelCategories())
        self._compute_prop_dist(label_categories, stats, _update_mask_stats_by_label)

        # Compute property statistics from distribution
        dist_by_label = stats["point_distribution_in_label"]
        dist_by_attr = stats["point_distribution_in_attribute"]
        self._compute_prop_stats_from_dist(dist_by_label, dist_by_attr)

        def _is_valid_mask(item_key, ann):
            has_defined_label = 0 <= ann.label < len(label_categories)
            if not has_defined_label:
                return False

            mask_has_invalid_val = ann.id in stats["items_with_invalid_value"].get(item_key, {})
            return not mask_has_invalid_val

        def _update_mask_props_far_from_mean(item_key, ann):
            base_valid_attrs = label_categories.attributes
            valid_attrs = base_valid_attrs.union(label_categories[ann.label].attributes)
            label_name = label_categories[ann.label].name
            mask_label_stats = dist_by_label[label_name]

            _, _, _w, _h = ann.get_bbox()

            # Delete the following block when #226 is resolved
            # https://github.com/openvinotoolkit/datumaro/issues/226
            if ann.type == AnnotationType.mask:
                _w += 1
                _h += 1

            area = ann.get_area()

            ann_mask_info = _generate_ann_mask_info(area, _w, _h)

            for prop, val in ann_mask_info.items():
                prop_stats = mask_label_stats[prop]
                self._compute_far_from_mean(prop_stats, val, item_key, ann)

            for attr, value in ann.attributes.items():
                if attr in valid_attrs:
                    mask_attr_stats = dist_by_attr[label_name][attr]
                    mask_val_stats = mask_attr_stats[str(value)]

                    for prop, val in ann_mask_info.items():
                        prop_stats = mask_val_stats[prop]
                        self._compute_far_from_mean(prop_stats, val, item_key, ann)

        # Compute far_from_mean from property
        for item_key, annotations in self.items:
            for ann in annotations:
                if _is_valid_mask(item_key, ann):
                    _update_mask_props_far_from_mean(item_key, ann)

        return stats

    def generate_reports(self, stats):
        """
        Validates the dataset for segmentation tasks based on its statistics.

        Parameters
        ----------
        stats : Dict object

        Returns
        -------
        reports (list): List of validation reports (DatasetValidationError).
        """
        reports = self._generate_common_reports(stats)

        reports += self._check_invalid_value(stats)

        defined_attr_dist = stats["attribute_distribution"]["defined_attributes"]
        dist_by_label = stats["point_distribution_in_label"]
        dist_by_attr = stats["point_distribution_in_attribute"]

        defined_labels = defined_attr_dist.keys()
        for label_name in defined_labels:
            mask_label_stats = dist_by_label[label_name]
            mask_attr_label = dist_by_attr.get(label_name, {})

            reports += self._check_far_from_label_mean(label_name, mask_label_stats)
            reports += self._check_imbalanced_dist_in_label(label_name, mask_label_stats)

            for attr_name, mask_attr_stats in mask_attr_label.items():
                reports += self._check_far_from_attr_mean(label_name, attr_name, mask_attr_stats)
                reports += self._check_imbalanced_dist_in_attr(
                    label_name, attr_name, mask_attr_stats
                )

        return reports