# Copyright (C) 2020-2023 Intel Corporation
#
# SPDX-License-Identifier: MIT
import errno
import os
import os.path as osp
from enum import Enum, auto
from typing import Iterable, List, Optional, Sequence, Tuple, Union
from datumaro.components.annotation import AnnotationType, Label, LabelCategories
from datumaro.components.cli_plugin import CliPlugin
from datumaro.components.dataset_base import DatasetItem, SubsetBase
from datumaro.components.errors import DatasetImportError, InvalidAnnotationError, MediaTypeError
from datumaro.components.exporter import Exporter
from datumaro.components.format_detection import FormatDetectionContext
from datumaro.components.importer import ImportContext, Importer
from datumaro.components.media import Image
from datumaro.util.meta_file_util import has_meta_file, parse_meta_file
class ImagenetTxtPath:
    """Default file-system names used by the ImageNet-txt dataset layout."""

    # Default name of the file that lists label names, one per line.
    LABELS_FILE = "synsets.txt"

    # Default directory (relative to the annotation file) that holds images.
    IMAGE_DIR = "images"
class _LabelsSource(Enum):
file = auto()
generate = auto()
def _parse_annotation_line(line: str) -> Tuple[str, str, Sequence[int]]:
item = line.split('"')
if 1 < len(item):
if len(item) == 3:
item_id = item[1]
item = item[2].split()
image = item_id + item[0]
label_ids = [int(id) for id in item[1:]]
else:
raise InvalidAnnotationError(
"Line %s: unexpected number " "of quotes in filename" % line
)
else:
item = line.split()
item_id = osp.splitext(item[0])[0]
image = item[0]
label_ids = [int(id) for id in item[1:]]
return item_id, image, label_ids
class ImagenetTxtBase(SubsetBase):
    """Loads the items listed in one ImageNet-txt annotation file.

    Every annotation line names an image and the ids of its labels. A label
    id is an index into the label categories, so it must be non-negative and,
    unless labels are generated, smaller than the number of known labels.
    """

    def __init__(
        self,
        path: str,
        *,
        subset: Optional[str] = None,
        ctx: Optional[ImportContext] = None,
        labels: Union[Iterable[str], str] = _LabelsSource.file.name,
        labels_file: str = ImagenetTxtPath.LABELS_FILE,
        image_dir: Optional[str] = None,
    ):
        """Read the label categories and the items of the annotation file.

        Args:
            path: Path to the annotation file.
            subset: Subset name; defaults to the annotation file name without
                its extension.
            ctx: Import context forwarded to the base class.
            labels: Either the name of a ``_LabelsSource`` member (``"file"``
                or ``"generate"``) or an iterable of label names.
            labels_file: Labels file name, relative to the directory of
                ``path``. Used only when ``labels`` is ``"file"`` and that
                directory has no dataset meta file.
            image_dir: Image directory, relative to the directory of ``path``;
                defaults to ``ImagenetTxtPath.IMAGE_DIR``.

        Raises:
            FileNotFoundError: If ``path`` is not an existing file.
            KeyError: If ``labels`` is a string that names no labels source.
            TypeError: If ``labels`` is an iterable with a non-string element.
            DatasetImportError: If an annotation uses an invalid label id.
        """
        if not osp.isfile(path):
            raise FileNotFoundError(errno.ENOENT, "Can't find dataset file", path)

        if not subset:
            subset = osp.splitext(osp.basename(path))[0]

        super().__init__(subset=subset, ctx=ctx)

        root_dir = osp.dirname(path)
        if not image_dir:
            image_dir = ImagenetTxtPath.IMAGE_DIR
        self.image_dir = osp.join(root_dir, image_dir)

        self._generate_labels = False

        if isinstance(labels, str):
            labels_source = _LabelsSource[labels]

            if labels_source == _LabelsSource.generate:
                # Start empty; names are created while the items are loaded.
                labels = ()
                self._generate_labels = True
            elif labels_source == _LabelsSource.file:
                # A dataset meta file takes precedence over the labels file.
                if has_meta_file(root_dir):
                    labels = parse_meta_file(root_dir).keys()
                else:
                    labels = self._parse_labels(osp.join(root_dir, labels_file))
            else:
                raise NotImplementedError("Unhandled labels source %s" % labels_source)
        else:
            # Materialize first: validating a one-shot iterator (for example
            # a generator) would otherwise consume it and leave no labels.
            labels = list(labels)
            if not all(isinstance(e, str) for e in labels):
                raise TypeError("Label names must be strings")

        self._categories = self._load_categories(labels)
        self._items = list(self._load_items(path).values())

    @staticmethod
    def _parse_labels(path):
        """Return the stripped lines of the labels file, one name per line."""
        with open(path, encoding="utf-8") as labels_file:
            return [s.strip() for s in labels_file]

    def _load_categories(self, labels):
        """Build the categories mapping from an iterable of label names."""
        return {AnnotationType.label: LabelCategories.from_iterable(labels)}

    def _load_items(self, path):
        """Parse the annotation file into a mapping of item id to item.

        Blank lines are skipped. If an item id occurs more than once, the
        last occurrence wins.

        Raises:
            DatasetImportError: If a label id is negative, or is not a known
                label while label generation is disabled.
        """
        items = {}
        label_categories = self._categories[AnnotationType.label]

        with open(path, encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    # Tolerate empty lines, such as a trailing newline.
                    continue

                item_id, image, label_ids = _parse_annotation_line(line)

                anno = []
                for label in label_ids:
                    if label < 0:
                        raise DatasetImportError(f"Image '{item_id}': invalid label id '{label}'")
                    if len(label_categories) <= label:
                        if self._generate_labels:
                            # Fill every gap up to and including this id.
                            while len(label_categories) <= label:
                                label_categories.add(f"class-{len(label_categories)}")
                        else:
                            raise DatasetImportError(
                                f"Image '{item_id}': unknown label id '{label}'"
                            )
                    anno.append(Label(label))
                    self._ann_types.add(AnnotationType.label)

                items[item_id] = DatasetItem(
                    id=item_id,
                    subset=self._subset,
                    media=Image.from_file(path=osp.join(self.image_dir, image)),
                    annotations=anno,
                )

        return items
class ImagenetTxtImporter(Importer, CliPlugin):
    """Detects and locates ImageNet-txt annotation files."""

    # Extension of annotation files; the single source of truth for
    # detection, source discovery and the reported file extensions.
    _ANNO_EXT = ".txt"

    @classmethod
    def detect(cls, context: FormatDetectionContext) -> None:
        """Require an annotation file with at least one labelled line.

        The labels file is excluded from the search by name. The first
        matching file is scanned until a line carrying label ids is found.
        """
        annot_path = context.require_file(
            f"*{cls._ANNO_EXT}", exclude_fnames=ImagenetTxtPath.LABELS_FILE
        )

        with context.probe_text_file(
            annot_path,
            "must be an ImageNet-like annotation file",
        ) as f:
            for line in f:
                _, _, label_ids = _parse_annotation_line(line)
                if label_ids:
                    break
            else:
                # If there are no labels in the entire file, it's probably
                # not actually an ImageNet file.
                # NOTE(review): presumably probe_text_file turns this into a
                # detection failure - confirm against its implementation.
                raise Exception

    @classmethod
    def build_cmdline_parser(cls, **kwargs):
        """Extend the base parser with ``--labels`` and ``--labels-file``."""
        parser = super().build_cmdline_parser(**kwargs)

        parser.add_argument(
            "--labels",
            choices=_LabelsSource.__members__,
            default=_LabelsSource.file.name,
            help="Where to get label descriptions from (use "
            "'file' to load from the file specified by --labels-file; "
            "'generate' to create generic ones)",
        )
        parser.add_argument(
            "--labels-file",
            default=ImagenetTxtPath.LABELS_FILE,
            help="Path to the file with label descriptions (synsets.txt)",
        )
        return parser

    @classmethod
    def find_sources_with_params(cls, path, **extra_params):
        """Find annotation files under ``path``.

        When labels are read from a file (the default), any file sharing the
        labels file's base name is skipped so it is not taken for annotations.
        """
        file_filter = None

        if extra_params.get("labels", _LabelsSource.file.name) == _LabelsSource.file.name:
            labels_file_name = osp.basename(
                extra_params.get("labels_file") or ImagenetTxtPath.LABELS_FILE
            )

            def file_filter(p):
                return osp.basename(p) != labels_file_name

        return cls._find_sources_recursive(
            path, cls._ANNO_EXT, "imagenet_txt", file_filter=file_filter
        )

    @classmethod
    def get_file_extensions(cls) -> List[str]:
        """Return the extensions of annotation files handled by this importer."""
        return [cls._ANNO_EXT]
class ImagenetTxtExporter(Exporter):
    """Writes a dataset in the ImageNet-txt layout.

    One ``<subset>.txt`` annotation file is written per subset into the save
    directory, together with either the dataset meta file or a labels file.
    """

    DEFAULT_IMAGE_EXT = ".jpg"

    def _apply_impl(self):
        """Export annotations, optional media and the label names.

        Raises:
            MediaTypeError: If the dataset declares a media type that is not
                an image type.
        """
        extractor = self._extractor

        media_type = extractor.media_type()
        if media_type and not issubclass(media_type, Image):
            raise MediaTypeError("Media type is not an image")

        subset_dir = self._save_dir
        os.makedirs(subset_dir, exist_ok=True)

        for subset_name, subset in extractor.subsets().items():
            annotation_file = osp.join(subset_dir, "%s.txt" % subset_name)

            # Maps the file name written to the annotation file to the set of
            # label ids attached to that item.
            labels = {}
            for item in subset:
                item_id = item.id
                # Names containing whitespace are quoted so the reader can
                # tell the name apart from the label ids that follow it.
                if 1 < len(item_id.split()):
                    item_id = '"' + item_id + '"'
                item_id += self._find_image_ext(item)
                labels[item_id] = set(
                    p.label for p in item.annotations if p.type == AnnotationType.label
                )

                if self._save_media and item.media:
                    self._save_image(item, subdir=ImagenetTxtPath.IMAGE_DIR)

            # Stream the lines instead of growing one big string with "+=".
            with open(annotation_file, "w", encoding="utf-8") as f:
                f.writelines(
                    "%s %s\n" % (item_id, " ".join(str(label) for label in item_labels))
                    for item_id, item_labels in labels.items()
                )

        if self._save_dataset_meta:
            self._save_meta_file(subset_dir)
        else:
            labels_file = osp.join(subset_dir, ImagenetTxtPath.LABELS_FILE)
            with open(labels_file, "w", encoding="utf-8") as f:
                f.writelines(
                    category.name + "\n"
                    for category in extractor.categories().get(
                        AnnotationType.label, LabelCategories()
                    )
                )