Source code for datumaro.plugins.data_formats.datumaro.base

# Copyright (C) 2023 Intel Corporation
#
# SPDX-License-Identifier: MIT

import os.path as osp
import re
from typing import Dict, List, Optional, Set, Type

from datumaro.components.annotation import (
    NO_OBJECT_ID,
    AnnotationType,
    Bbox,
    Caption,
    Cuboid3d,
    Ellipse,
    GroupType,
    Label,
    LabelCategories,
    MaskCategories,
    Points,
    PointsCategories,
    Polygon,
    PolyLine,
    RleMask,
)
from datumaro.components.dataset_base import DatasetItem, SubsetBase
from datumaro.components.errors import DatasetImportError, MediaTypeError
from datumaro.components.importer import ImportContext
from datumaro.components.media import Image, MediaElement, MediaType, PointCloud, Video, VideoFrame
from datumaro.plugins.data_formats.datumaro.page_mapper import DatumPageMapper
from datumaro.util import parse_json_file
from datumaro.version import __version__

from .format import DATUMARO_FORMAT_VERSION, DatumaroPath

__all__ = ["DatumaroBase"]


[docs] class JsonReader: def __init__( self, path: str, subset: str, rootpath: str, images_dir: str, pcd_dir: str, video_dir: str, ctx: ImportContext, ) -> None: self._subset = subset self._rootpath = rootpath self._images_dir = images_dir self._pcd_dir = pcd_dir self._video_dir = video_dir self._videos = {} self._ctx = ctx self.ann_types = {} self._reader = self._init_reader(path) self.media_type = self._load_media_type(self._reader) self.infos = self._load_infos(self._reader) self.categories = self._load_categories(self._reader) self.items = self._load_items(self._reader) def _init_reader(self, path: str): return parse_json_file(path) @staticmethod def _load_media_type(parsed) -> Type[MediaElement]: media_type = parsed.get("media_type", MediaType.IMAGE) return MediaType(media_type).media @staticmethod def _load_ann_types(parsed) -> Set[AnnotationType]: return parsed.get("ann_types", set()) @staticmethod def _load_infos(parsed) -> Dict: return parsed.get("infos", {}) @staticmethod def _load_categories(parsed) -> Dict: categories = {} parsed_label_cat = parsed["categories"].get(AnnotationType.label.name) if parsed_label_cat: label_categories = LabelCategories(attributes=parsed_label_cat.get("attributes", [])) for item in parsed_label_cat["labels"]: label_categories.add( item["name"], parent=item["parent"], attributes=item.get("attributes", []), ) for item in parsed_label_cat.get("label_groups", []): label_categories.add_label_group( item["name"], labels=item["labels"], group_type=GroupType.from_str(item["group_type"]), ) categories[AnnotationType.label] = label_categories parsed_mask_cat = parsed["categories"].get(AnnotationType.mask.name) if parsed_mask_cat: colormap = {} for item in parsed_mask_cat["colormap"]: colormap[int(item["label_id"])] = (item["r"], item["g"], item["b"]) mask_categories = MaskCategories(colormap=colormap) categories[AnnotationType.mask] = mask_categories parsed_points_cat = parsed["categories"].get(AnnotationType.points.name) if parsed_points_cat: point_categories = PointsCategories() for item in parsed_points_cat["items"]: point_categories.add(int(item["label_id"]), item["labels"], joints=item["joints"]) categories[AnnotationType.points] = point_categories return categories def _load_items(self, parsed) -> List: item_descs: List = parsed["items"] pbar = self._ctx.progress_reporter def _gen(): # Reverse the list to pop from the front of it item_descs.reverse() while item_descs: yield item_descs.pop() items = [] ann_types = set() for item_desc in pbar.iter( _gen(), desc=f"Importing '{self._subset}'", total=len(item_descs) ): item = self._parse_item(item_desc) if item is not None: items.append(item) for ann in item.annotations: ann_types.add(ann.type) self.ann_types = ann_types return items def _parse_item(self, item_desc: Dict) -> Optional[DatasetItem]: try: item_id = item_desc["id"] media = None image_info = item_desc.get("image") if image_info: image_filename = image_info.get("path") or item_id + DatumaroPath.IMAGE_EXT image_path = osp.join(self._images_dir, self._subset, image_filename) if not osp.isfile(image_path): # backward compatibility old_image_path = osp.join(self._images_dir, image_filename) if osp.isfile(old_image_path): image_path = old_image_path media = Image.from_file(path=image_path, size=image_info.get("size")) if self.media_type == MediaElement: self.media_type = Image pcd_info = item_desc.get("point_cloud") if media and pcd_info: raise MediaTypeError("Dataset cannot contain multiple media types") if pcd_info: pcd_path = pcd_info.get("path") point_cloud = osp.join(self._pcd_dir, self._subset, pcd_path) related_images = None ri_info = item_desc.get("related_images") if ri_info: related_images = [ Image.from_file( size=ri.get("size"), path=osp.join(self._images_dir, self._subset, ri.get("path")), ) for ri in ri_info ] media = PointCloud.from_file(path=point_cloud, extra_images=related_images) if self.media_type == MediaElement: self.media_type = PointCloud video_frame_info = item_desc.get("video_frame") if media and video_frame_info: raise MediaTypeError("Dataset cannot contain multiple media types") if video_frame_info: video_path = osp.join( self._video_dir, self._subset, video_frame_info.get("video_path") ) if video_path not in self._videos: self._videos[video_path] = Video(video_path) video = self._videos[video_path] frame_index = video_frame_info.get("frame_index") media = VideoFrame(video, frame_index) media_desc = item_desc.get("media") if not media and media_desc and media_desc.get("path"): media = MediaElement(path=media_desc.get("path")) except Exception as e: self._ctx.error_policy.report_item_error( e, item_id=(item_desc.get("id", None), self._subset) ) return None annotations = self._load_annotations(item_desc) return DatasetItem( id=item_id, subset=self._subset, annotations=annotations, media=media, attributes=item_desc.get("attr"), ) def _load_annotations(self, item: Dict): loaded = [] for ann in item.get("annotations", []): try: ann_id = ann.get("id") ann_type = AnnotationType[ann["type"]] attributes = ann.get("attributes") group = ann.get("group") object_id = ann.get("object_id", NO_OBJECT_ID) label_id = ann.get("label_id") z_order = ann.get("z_order") points = ann.get("points") if ann_type == AnnotationType.label: loaded.append( Label( label=label_id, id=ann_id, attributes=attributes, group=group, object_id=object_id, ) ) elif ann_type == AnnotationType.mask: rle = ann["rle"] rle["counts"] = rle["counts"].encode("ascii") loaded.append( RleMask( rle=rle, label=label_id, id=ann_id, attributes=attributes, group=group, object_id=object_id, z_order=z_order, ) ) elif ann_type == AnnotationType.polyline: loaded.append( PolyLine( points, label=label_id, id=ann_id, attributes=attributes, group=group, object_id=object_id, z_order=z_order, ) ) elif ann_type == AnnotationType.polygon: loaded.append( Polygon( points, label=label_id, id=ann_id, attributes=attributes, group=group, object_id=object_id, z_order=z_order, ) ) elif ann_type == AnnotationType.bbox: x, y, w, h = ann["bbox"] loaded.append( Bbox( x, y, w, h, label=label_id, id=ann_id, attributes=attributes, group=group, object_id=object_id, z_order=z_order, ) ) elif ann_type == AnnotationType.points: loaded.append( Points( points, label=label_id, id=ann_id, attributes=attributes, group=group, object_id=object_id, z_order=z_order, ) ) elif ann_type == AnnotationType.caption: caption = ann.get("caption") loaded.append(Caption(caption, id=ann_id, attributes=attributes, group=group)) elif ann_type == AnnotationType.cuboid_3d: loaded.append( Cuboid3d( ann.get("position"), ann.get("rotation"), ann.get("scale"), label=label_id, id=ann_id, attributes=attributes, group=group, object_id=object_id, ) ) elif ann_type == AnnotationType.ellipse: loaded.append( Ellipse( *points, label=label_id, id=ann_id, attributes=attributes, group=group, object_id=object_id, z_order=z_order, ) ) elif ann_type == AnnotationType.hash_key: continue else: raise NotImplementedError() except Exception as e: self._ctx.error_policy.report_annotation_error( e, item_id=(ann.get("id", None), self._subset) ) return loaded def __len__(self): return len(self.items) def __iter__(self): yield from self.items
[docs] class StreamJsonReader(JsonReader): def __init__( self, path: str, subset: str, rootpath: str, images_dir: str, pcd_dir: str, video_dir: str, ctx: ImportContext, ) -> None: super().__init__(path, subset, rootpath, images_dir, pcd_dir, video_dir, ctx) self._length = None def __len__(self): return len(self._reader) def __iter__(self): ann_types = set() pbar = self._ctx.progress_reporter for item_desc in pbar.iter( self._reader, desc=f"Importing '{self._subset}'", ): item = self._parse_item(item_desc) yield item if item is not None: for ann in item.annotations: ann_types.add(ann.type) self.ann_types = ann_types def _init_reader(self, path: str) -> DatumPageMapper: return DatumPageMapper(path) @staticmethod def _load_media_type(page_mapper: DatumPageMapper) -> Type[MediaElement]: media_type = page_mapper.media_type if media_type is None: return MediaType.IMAGE.media return media_type.media @staticmethod def _load_ann_types(page_mapper: DatumPageMapper) -> Set[AnnotationType]: ann_types = page_mapper.ann_types if ann_types is None: return set() return ann_types @staticmethod def _load_infos(page_mapper: DatumPageMapper) -> Dict: return page_mapper.infos @staticmethod def _load_categories(page_mapper: DatumPageMapper) -> Dict: return JsonReader._load_categories({"categories": page_mapper.categories}) def _load_items(self, parsed) -> List: return []
[docs] class DatumaroBase(SubsetBase): LEGACY_VERSION = "legacy" CURRENT_DATUMARO_FORMAT_VERSION = DATUMARO_FORMAT_VERSION # If Datumaro format version goes up, it will be # ALLOWED_VERSIONS = {LEGACY_VERSION, 1.0, ..., CURRENT_DATUMARO_FORMAT_VERSION} ALLOWED_VERSIONS = {LEGACY_VERSION, CURRENT_DATUMARO_FORMAT_VERSION} def __init__( self, path: str, *, subset: Optional[str] = None, stream: bool = False, ctx: Optional[ImportContext] = None, ): assert osp.isfile(path), path subset = osp.splitext(osp.basename(path))[0] if subset is None else subset self._stream = stream super().__init__(subset=subset, ctx=ctx) self._init_path(path) dm_version = self._get_dm_format_version(path) # when backward compatibility happen, we should implement version specific readers if dm_version not in self.ALLOWED_VERSIONS: raise DatasetImportError( f"Datumaro format version of the given dataset is {dm_version}, " f"but not supported by this Datumaro version: {__version__}. " f"The allowed datumaro format versions are {self.ALLOWED_VERSIONS}. " "Please install the latest Datumaro." ) self._load_impl(path) def _init_path(self, path: str): """Initialize path variables""" rootpath = "" if path.endswith(osp.join(DatumaroPath.ANNOTATIONS_DIR, osp.basename(path))): rootpath = path.rsplit(DatumaroPath.ANNOTATIONS_DIR, maxsplit=1)[0] self._rootpath = rootpath images_dir = "" if rootpath and osp.isdir(osp.join(rootpath, DatumaroPath.IMAGES_DIR)): images_dir = osp.join(rootpath, DatumaroPath.IMAGES_DIR) self._images_dir = images_dir pcd_dir = "" if rootpath and osp.isdir(osp.join(rootpath, DatumaroPath.PCD_DIR)): pcd_dir = osp.join(rootpath, DatumaroPath.PCD_DIR) self._pcd_dir = pcd_dir video_dir = "" if rootpath and osp.isdir(osp.join(rootpath, DatumaroPath.VIDEO_DIR)): video_dir = osp.join(rootpath, DatumaroPath.VIDEO_DIR) self._video_dir = video_dir @property def is_stream(self) -> bool: return self._stream
[docs] def infos(self): return self._reader.infos
[docs] def categories(self): return self._reader.categories
[docs] def media_type(self): return self._reader.media_type
[docs] def ann_types(self): return self._reader.ann_types
def _load_impl(self, path: str) -> None: """Actual implementation of loading Datumaro format.""" self._reader = ( JsonReader( path, self._subset, self._rootpath, self._images_dir, self._pcd_dir, self._video_dir, self._ctx, ) if not self._stream else StreamJsonReader( path, self._subset, self._rootpath, self._images_dir, self._pcd_dir, self._video_dir, self._ctx, ) ) return self._reader def _get_dm_format_version(self, path) -> str: """ Get Datumaro format at exporting the dataset Note that the legacy Datumaro doesn't store the version into exported dataset. Thus it returns DatumaroBase.LEGACY_VERSION """ # We can assume that the version information will be within the first 1 KB of the file. # This is because we are the producer of Datumaro format. search_size = 1024 # 1 KB pattern = '"dm_format_version"\s*:\s*"(\w+)"' with open(path, "r", encoding="utf-8") as fp: out = fp.read(search_size) found = re.search(pattern, out) if not found: return self.LEGACY_VERSION version_str = found.group(1) if re.match("\d+\.d+", version_str) is None: raise DatasetImportError(f"Invalid version string: {version_str} ") return version_str def __len__(self) -> int: return len(self._reader) def __iter__(self) -> DatasetItem: yield from self._reader