# Copyright (C) 2019-2023 Intel Corporation
#
# SPDX-License-Identifier: MIT
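
"""Implements the "compare" CLI command, which compares two datasets using a
comparison table, an equality check, or a distance (IoU-based) match."""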

import argparse
import logging as log
import os
import os.path as osp
from enum import Enum, auto

from datumaro.components.comparator import DistanceComparator, EqualityComparator, TableComparator
from datumaro.components.errors import ProjectNotFoundError
from datumaro.util.os_util import rmtree
from datumaro.util.scope import on_error_do, scope_add, scoped

from ..util import MultilineFormatter
from ..util.compare import DistanceCompareVisualizer
from ..util.errors import CliException
from ..util.project import generate_next_file_name, load_project, parse_full_revpath


class ComparisonMethod(Enum):
    table = auto()
    equality = auto()
    distance = auto()


eq_default_if = ["id", "group"]  # avoid https://bugs.python.org/issue16399
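
# The three methods map to the comparator classes imported above: "table"
# builds summary comparison tables (TableComparator), "equality" checks
# annotations for equality (EqualityComparator), and "distance" matches
# shapes by IoU (DistanceComparator).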


def build_parser(parser_ctor=argparse.ArgumentParser):
    parser = parser_ctor(
        help="Compares two datasets",
        description="""
        Compares two datasets. This command has multiple forms:|n
        1) %(prog)s <revpath>|n
        2) %(prog)s <revpath> <revpath>|n
        |n
        1 - Compares the current project's main target ('project')
        in the working tree with the specified dataset.|n
        2 - Compares two specified datasets.|n
        |n
        <revpath> - either a dataset path or a revision path. The full
        syntax is:|n
        - Dataset paths:|n
        |s|s- <dataset path>[ :<format> ]|n
        - Revision paths:|n
        |s|s- <project path> [ @<rev> ] [ :<target> ]|n
        |s|s- <rev> [ :<target> ]|n
        |s|s- <target>|n
        |n
        Both forms use the -p/--project option as a context for plugins,
        which can be useful for dataset paths in targets. When not specified,
        the current project's working tree is used.|n
        |n
        Annotations can be matched in 3 ways:|n
        - by comparison table|n
        - by equality checking|n
        - by distance computation|n
        |n
        Examples:|n
        - Compare two projects by distance, match boxes if IoU > 0.7,|n
        |s|s|s|ssave results to Tensorboard:|n
        |s|s%(prog)s other/project -o diff/ -f tensorboard --iou-thresh 0.7|n
        |n
        - Compare two projects for equality, exclude annotation groups|n
        |s|s|s|sand the 'is_crowd' attribute from comparison:|n
        |s|s%(prog)s other/project/ -if group -ia is_crowd -m equality|n
        |n
        - Compare two datasets, specify formats:|n
        |s|s%(prog)s path/to/dataset1:voc path/to/dataset2:coco|n
        |n
        - Compare the current working tree and a dataset:|n
        |s|s%(prog)s path/to/dataset2:coco|n
        |n
        - Compare a source from a previous revision and a dataset:|n
        |s|s%(prog)s HEAD~2:source-2 path/to/dataset2:yolo
        """,
        formatter_class=MultilineFormatter,
    )

    formats = ", ".join(f.name for f in DistanceCompareVisualizer.OutputFormat)
    comp_methods = ", ".join(m.name for m in ComparisonMethod)
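
    # Helpers used as argparse `type=` callables: they convert the raw
    # command-line value to the matching enum member and report unknown
    # values together with the list of supported ones.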
    def _parse_output_format(s):
        try:
            return DistanceCompareVisualizer.OutputFormat[s.lower()]
        except KeyError:
            raise argparse.ArgumentTypeError(
                "Unknown output format '%s', the only available are: %s" % (s, formats)
            )

    def _parse_comparison_method(s):
        try:
            return ComparisonMethod[s.lower()]
        except KeyError:
            raise argparse.ArgumentTypeError(
                "Unknown comparison method '%s', the only available are: %s" % (s, comp_methods)
            )
parser.add_argument("first_target", help="The first dataset revpath to be compared")
parser.add_argument(
"second_target", nargs="?", help="The second dataset revpath to be compared"
)
parser.add_argument(
"-o",
"--output-dir",
dest="dst_dir",
default=None,
help="Directory to save comparison results " "(default: generate automatically)",
)
parser.add_argument(
"-m",
"--method",
type=_parse_comparison_method,
default=ComparisonMethod.table.name,
help="Comparison method, one of {} (default: %(default)s)".format(comp_methods),
)
parser.add_argument(
"--overwrite", action="store_true", help="Overwrite existing files in the save directory"
)
parser.add_argument(
"-p",
"--project",
dest="project_dir",
help="Directory of the current project (default: current dir)",
)
parser.set_defaults(command=compare_command)
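
    # Method-specific options are declared in argument groups so they are
    # listed together in the --help output.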
    distance_parser = parser.add_argument_group("Distance comparison options")
    distance_parser.add_argument(
        "--iou-thresh",
        default=0.5,
        type=float,
        help="IoU match threshold for shapes (default: %(default)s)",
    )
    parser.add_argument(
        "-f",
        "--format",
        type=_parse_output_format,
        default=DistanceCompareVisualizer.DEFAULT_FORMAT.name,
        help="Output format, one of {} (default: %(default)s)".format(formats),
    )

    equality_parser = parser.add_argument_group("Equality comparison options")
    equality_parser.add_argument(
        "-iia", "--ignore-item-attr", action="append", help="Ignore item attribute (repeatable)"
    )
    equality_parser.add_argument(
        "-ia", "--ignore-attr", action="append", help="Ignore annotation attribute (repeatable)"
    )
    equality_parser.add_argument(
        "-if",
        "--ignore-field",
        action="append",
        help="Ignore annotation field (repeatable, default: %s)" % eq_default_if,
    )
    equality_parser.add_argument(
        "--match-images",
        action="store_true",
        help="Match dataset items by image pixels instead of ids",
    )
    equality_parser.add_argument("--all", action="store_true", help="Include matches in the output")

    return parser
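

# A minimal sketch of how this parser might be wired up by a CLI entry point
# (hypothetical wiring; the actual Datumaro CLI registers subcommands centrally):
#
#     top = argparse.ArgumentParser()
#     subcommands = top.add_subparsers()
#     build_parser(lambda **kwargs: subcommands.add_parser("compare", **kwargs))
#     args = top.parse_args(["compare", "path/to/dataset1:voc", "path/to/dataset2:coco"])
#     exit_code = args.command(args)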


def get_sensitive_args():
    return {
        compare_command: [
            "first_target",
            "second_target",
            "dst_dir",
            "project_dir",
        ],
    }
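

# `scoped` gives the command a cleanup scope: objects registered with
# `scope_add()` (e.g. loaded projects) are released when the command returns,
# while `on_error_do()` callbacks run only if it raises.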
@scoped
def compare_command(args):
    dst_dir = args.dst_dir
    if dst_dir:
        if not args.overwrite and osp.isdir(dst_dir) and os.listdir(dst_dir):
            raise CliException(
                "Directory '%s' already exists (pass --overwrite to overwrite)" % dst_dir
            )
    else:
        dst_dir = generate_next_file_name("compare")
    dst_dir = osp.abspath(dst_dir)
    if not osp.exists(dst_dir):
        on_error_do(rmtree, dst_dir, ignore_errors=True)
        os.makedirs(dst_dir)
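
    # The project context is optional: it is only needed when a revpath refers
    # to a project target or revision, so a missing project is ignored unless
    # -p/--project was given explicitly.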
    project = None
    try:
        project = scope_add(load_project(args.project_dir))
    except ProjectNotFoundError:
        if args.project_dir:
            raise
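
    # Resolve the revpaths into datasets. With a single argument, the current
    # project's working tree is compared against it; with two arguments, both
    # revpaths are parsed. Projects opened while parsing are added to the
    # scope so they are cleaned up on exit.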
    try:
        if not args.second_target:
            first_dataset = project.working_tree.make_dataset()
            second_dataset, target_project = parse_full_revpath(args.first_target, project)
            if target_project:
                scope_add(target_project)
        else:
            first_dataset, target_project = parse_full_revpath(args.first_target, project)
            if target_project:
                scope_add(target_project)

            second_dataset, target_project = parse_full_revpath(args.second_target, project)
            if target_project:
                scope_add(target_project)
    except Exception as e:
        raise CliException(str(e))
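
    # Run the selected comparison method.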
    if args.method is ComparisonMethod.table:
        comparator = TableComparator()
        (
            high_level_table,
            mid_level_table,
            low_level_table,
            comparison_dict,
        ) = comparator.compare_datasets(first_dataset, second_dataset)
        if args.dst_dir:
            comparator.save_compare_report(
                high_level_table, mid_level_table, low_level_table, comparison_dict, args.dst_dir
            )
    elif args.method is ComparisonMethod.equality:
        if not args.ignore_field:
            args.ignore_field = eq_default_if
        comparator = EqualityComparator(
            match_images=args.match_images,
            ignored_fields=args.ignore_field,
            ignored_attrs=args.ignore_attr,
            ignored_item_attrs=args.ignore_item_attr,
            all=args.all,
        )
        output = comparator.compare_datasets(first_dataset, second_dataset)
        if args.dst_dir:
            comparator.save_compare_report(output, args.dst_dir)
    elif args.method is ComparisonMethod.distance:
        comparator = DistanceComparator(iou_threshold=args.iou_thresh)
        with DistanceCompareVisualizer(
            save_dir=dst_dir, comparator=comparator, output_format=args.format
        ) as visualizer:
            log.info("Saving compare to '%s'" % dst_dir)
            visualizer.save(first_dataset, second_dataset)

    return 0