Source code for datumaro.plugins.data_formats.ade20k2020

# Copyright (C) 2020-2023 Intel Corporation
#
# SPDX-License-Identifier: MIT

import errno
import glob
import logging as log
import os
import os.path as osp
import re
from functools import partial
from typing import List, Optional

import numpy as np

from datumaro.components.annotation import (
    AnnotationType,
    ExtractedMask,
    LabelCategories,
    Mask,
    Polygon,
)
from datumaro.components.dataset_base import DatasetBase, DatasetItem
from datumaro.components.format_detection import FormatDetectionContext
from datumaro.components.importer import ImportContext, Importer
from datumaro.components.media import Image
from datumaro.rust_api import JsonSectionPageMapper
from datumaro.util import parse_json
from datumaro.util.image import IMAGE_EXTENSIONS, find_images, lazy_image, load_image
from datumaro.util.meta_file_util import has_meta_file, parse_meta_file


class Ade20k2020Path:
    """Naming constants for files found in an ADE20K (2020) dataset."""

    # Matched against an image base name (no directory, no extension) to tell
    # mask images apart from the actual dataset images. The pattern is compiled
    # in verbose mode, so the whitespace around the alternatives is ignored.
    # Alternatives: "<name>_seg", "<name>_parts_<digits>", "instance_<name>".
    MASK_PATTERN = re.compile(
        r""".+_seg | .+_parts_\d+ | instance_.+ """,
        flags=re.VERBOSE,
    )
class Ade20k2020Base(DatasetBase):
    """Reader for a dataset directory in the ADE20K (2020) layout.

    Every entry of the dataset directory, except ``*.json`` files (the dataset
    meta file), is treated as a subset. Inside a subset, each image that is
    not itself a mask (see ``Ade20k2020Path.MASK_PATTERN``) becomes one
    ``DatasetItem``. Its annotations are read from files next to the image:

    - ``<image>.json`` -- per-object descriptions (required);
    - ``<image>_seg.png`` -- class mask for part level 0;
    - ``<image>_parts_<N>.png`` -- class mask for part level ``N`` (N >= 1);
    - the file named by each object's ``instance_mask`` field, resolved
      relative to the image directory.

    All items are loaded eagerly in the constructor.
    """

    def __init__(self, path: str, *, ctx: Optional[ImportContext] = None):
        """Scan ``path`` and load the items of every subset found there.

        Args:
            path: Root directory of the dataset.
            ctx: Optional import context forwarded to ``DatasetBase``.

        Raises:
            NotADirectoryError: ``path`` is not a directory.
            FileNotFoundError: ``path`` has no entries other than ``*.json``
                files, or an image has no annotation file next to it.
        """
        if not osp.isdir(path):
            raise NotADirectoryError(errno.ENOTDIR, "Can't find dataset directory", path)

        # exclude dataset meta file
        subsets = [name for name in os.listdir(path) if osp.splitext(name)[-1] != ".json"]
        if not subsets:
            raise FileNotFoundError(errno.ENOENT, "Can't find subsets in directory", path)

        super().__init__(subsets=sorted(subsets), ctx=ctx)
        self._path = path

        self._items = []
        self._categories = {}

        if has_meta_file(self._path):
            # Pre-populate the label list from the meta file; labels that are
            # only seen in annotations are appended while items are loaded.
            self._categories = {
                AnnotationType.label: LabelCategories.from_iterable(
                    parse_meta_file(self._path).keys()
                )
            }

        for subset in self._subsets:
            self._load_items(subset)

    def __iter__(self):
        """Iterate over the loaded dataset items."""
        return iter(self._items)

    def categories(self):
        """Return the categories mapping (annotation type -> categories)."""
        return self._categories

    def _load_items(self, subset):
        """Load every image of ``subset`` and append the items to ``self._items``.

        Labels that are not registered yet are added to the label categories,
        and the types of all produced annotations are recorded in
        ``self._ann_types``.
        """
        labels = self._categories.setdefault(AnnotationType.label, LabelCategories())
        subset_dir = osp.join(self._path, subset)

        for image_path in sorted(find_images(subset_dir, recursive=True)):
            item_id = osp.splitext(osp.relpath(image_path, subset_dir))[0]

            # Mask images live next to the real images; they are not items.
            if Ade20k2020Path.MASK_PATTERN.fullmatch(osp.basename(item_id)):
                continue

            item_info = self._load_item_info(image_path)

            # Register every label before annotations are built, so that the
            # label lookups in the helpers below always succeed.
            for info in item_info:
                if labels.find(info["label_name"])[0] is None:
                    labels.add(info["label_name"])

            item_annotations = self._build_class_masks(image_path, item_info, labels)
            item_annotations.extend(
                self._build_instance_annotations(image_path, item_info, labels)
            )

            self._items.append(
                DatasetItem(
                    item_id,
                    subset=subset,
                    media=Image.from_file(path=image_path),
                    annotations=item_annotations,
                )
            )

            for ann in item_annotations:
                self._ann_types.add(ann.type)

    def _build_class_masks(self, image_path, item_info, labels):
        """Build one ``ExtractedMask`` per (class, part level) found in ``item_info``."""
        annotations = []
        base_path = osp.splitext(image_path)[0]

        # An annotation file may describe no objects at all; ``default=-1``
        # makes the range below empty instead of letting ``max`` raise.
        max_part_level = max((info["part_level"] for info in item_info), default=-1)

        for part_level in range(max_part_level + 1):
            # The mask path is derived from the level on every iteration, so a
            # missing file at one level does not hide the files of the next ones.
            if part_level == 0:
                mask_path = base_path + "_seg.png"
            else:
                mask_path = base_path + "_parts_%s.png" % part_level

            if not osp.exists(mask_path):
                log.warning("Can`t find part level %s mask for %s", part_level, image_path)
                continue

            mask = lazy_image(mask_path, loader=self._load_class_mask)

            # A set: several objects of one class share a single class mask.
            classes = {
                (info["class_idx"], info["label_name"])
                for info in item_info
                if info["part_level"] == part_level
            }

            for class_idx, label_name in classes:
                annotations.append(
                    ExtractedMask(
                        index_mask=mask,
                        index=class_idx,
                        label=labels.find(label_name)[0],
                        id=class_idx,
                        group=class_idx,
                        z_order=part_level,
                    )
                )

        return annotations

    def _build_instance_annotations(self, image_path, item_info, labels):
        """Build a ``Mask`` (and, when possible, a ``Polygon``) per described object.

        Objects whose instance mask file is missing are skipped with a warning.
        """
        annotations = []
        image_dir = osp.dirname(image_path)

        for info in item_info:
            instance_path = osp.join(image_dir, info["instance_mask"])
            if not osp.isfile(instance_path):
                log.warning("Can`t find instance mask: %s", instance_path)
                continue

            mask = lazy_image(instance_path, loader=self._load_instance_mask)

            label_id = labels.find(info["label_name"])[0]
            instance_id = info["id"]
            attributes = {name: True for name in info["attributes"]}
            polygon_points = info["polygon_points"]

            annotations.append(
                Mask(
                    label=label_id,
                    image=partial(self._get_instance_mask, mask),
                    id=instance_id,
                    attributes=attributes,
                    z_order=info["part_level"],
                    group=instance_id,
                )
            )

            # A polygon needs complete (x, y) pairs and at least three points.
            if len(polygon_points) % 2 == 0 and len(polygon_points) // 2 >= 3:
                annotations.append(
                    Polygon(
                        polygon_points,
                        label=label_id,
                        attributes=attributes,
                        id=instance_id,
                        z_order=info["part_level"],
                        group=instance_id,
                    )
                )

        return annotations

    def _load_item_info(self, path):
        """Read the object descriptions stored next to the image ``path``.

        Returns:
            A list with one dict per annotated object, with the keys ``id``,
            ``class_idx``, ``part_level``, ``occluded``, ``crop``,
            ``label_name``, ``attributes``, ``instance_mask`` and
            ``polygon_points`` (a flat ``[x0, y0, x1, y1, ...]`` list).

        Raises:
            FileNotFoundError: there is no ``.json`` file for the image.
        """
        json_path = osp.splitext(path)[0] + ".json"
        if not osp.isfile(json_path):
            raise FileNotFoundError(
                errno.ENOENT, "Can't find annotation file for image %s" % path, json_path
            )

        with open(json_path, "r", encoding="latin-1") as f:
            item_objects = parse_json(f.read())["annotation"]["object"]

        item_info = []
        for obj in item_objects:
            # Interleave the separate x and y coordinate lists.
            polygon_points = []
            for x, y in zip(obj["polygon"]["x"], obj["polygon"]["y"]):
                polygon_points.append(x)
                polygon_points.append(y)

            # A single attribute is stored as a bare string; normalize to a list.
            attributes = obj["attributes"]
            if isinstance(attributes, str):
                attributes = [attributes]

            item_info.append(
                {
                    "id": obj["id"],
                    "class_idx": obj["name_ndx"],
                    "part_level": obj["parts"]["part_level"],
                    "occluded": int(obj["occluded"] == "yes"),
                    "crop": obj["crop"],
                    "label_name": obj["raw_name"],
                    "attributes": attributes,
                    "instance_mask": obj["instance_mask"],
                    "polygon_points": polygon_points,
                }
            )

        return item_info

    @staticmethod
    def _load_instance_mask(path):
        """Load an instance mask image as an array of value ranks.

        Every element is replaced by the position of its value among the
        sorted unique values of the image; the input shape is kept.
        """
        mask = load_image(path)
        _, instance_mask = np.unique(mask, return_inverse=True)
        instance_mask = instance_mask.reshape(mask.shape)
        return instance_mask

    @staticmethod
    def _load_class_mask(path):
        """Load a class mask image and decode it into a 2D class index array.

        The index is ``(channel 2 / 10) << 8`` plus ``channel 1``.
        NOTE(review): which colour each channel holds depends on the channel
        order returned by ``load_image`` -- confirm before relying on it.
        """
        mask = load_image(path)
        mask = ((mask[:, :, 2] / 10).astype(np.int32) << 8) + mask[:, :, 1].astype(np.int32)
        return mask

    @staticmethod
    def _get_instance_mask(mask: lazy_image) -> np.ndarray:
        """Return a boolean array selecting the elements of ``mask()`` equal to 1."""
        return mask() == 1
class Ade20k2020Importer(Importer):
    """Importer that detects and locates ADE20K (2020) datasets."""

    _ANNO_EXT = ".json"

    @classmethod
    def detect(cls, context: FormatDetectionContext) -> None:
        """Require a nested annotation file whose JSON has an "annotation" section."""
        annotation_file = context.require_file(f"*/**/*{cls._ANNO_EXT}")

        requirement = 'must be a JSON object with an "annotation" key'
        with context.probe_text_file(annotation_file, requirement):
            full_path = osp.join(context.root_path, annotation_file)
            section_names = JsonSectionPageMapper(full_path).sections().keys()

            # Raising inside the probe block signals that the requirement is not met.
            if "annotation" not in section_names:
                raise Exception

    @classmethod
    def find_sources(cls, path):
        """Return a single source rooted at ``path`` if an image is found nearby.

        ``path`` itself and the entries up to four directory levels below it
        are checked; the first path with an image extension is enough. An
        empty list is returned when no image is found.
        """
        for depth in range(5):
            # depth 0 -> path, depth 1 -> path/*, depth 2 -> path/*/*, ...
            pattern = osp.join(path, *(["*"] * depth))
            found_image = any(
                osp.splitext(candidate)[1].lower() in IMAGE_EXTENSIONS
                for candidate in glob.iglob(pattern)
            )
            if found_image:
                return [
                    {
                        "url": path,
                        "format": Ade20k2020Base.NAME,
                    }
                ]

        return []

    @classmethod
    def get_file_extensions(cls) -> List[str]:
        """Return the annotation file extensions used by this format."""
        return [cls._ANNO_EXT]