# Source code for datumaro.plugins.data_formats.segment_anything.base

# Copyright (C) 2023 Intel Corporation
#
# SPDX-License-Identifier: MIT

import os.path as osp
from glob import glob
from inspect import isclass
from typing import Any, Dict, Optional, Tuple, Type, TypeVar, Union

from datumaro.components.annotation import Bbox, RleMask
from datumaro.components.dataset_base import DatasetItem, SubsetBase
from datumaro.components.errors import (
    DatasetImportError,
    InvalidAnnotationError,
    InvalidFieldTypeError,
    MissingFieldError,
)
from datumaro.components.importer import ImportContext
from datumaro.components.media import Image
from datumaro.util import NOTSET, parse_json_file

T = TypeVar("T")


def parse_field(
    ann: Dict[str, Any],
    key: str,
    cls: Union[Type[T], Tuple[Type, ...]],
    default: Any = NOTSET,
) -> Any:
    """Read ``ann[key]`` and verify that it is an instance of ``cls``.

    Args:
        ann: Mapping to read the field from.
        key: Name of the field to read.
        cls: A type, or a tuple of types, as accepted by ``isinstance``.
        default: Value returned when ``key`` is absent from ``ann``. When left
            as ``NOTSET`` the field is mandatory.

    Returns:
        The stored value, or ``default`` when the key is absent and a default
        was supplied.

    Raises:
        MissingFieldError: ``key`` is absent and no ``default`` was supplied.
        InvalidFieldTypeError: the stored value is not an instance of ``cls``.
            An explicitly stored ``None`` is type-checked like any other value,
            so it does not fall back to ``default``.
    """
    found = ann.get(key, NOTSET)

    if found is NOTSET:
        if default is NOTSET:
            raise MissingFieldError(key)
        return default

    if isinstance(found, cls):
        return found

    # Normalize a single class to a 1-tuple so the error lists every accepted type.
    accepted = (cls,) if isclass(cls) else cls
    raise InvalidFieldTypeError(
        key, actual=str(type(found)), expected=tuple(map(str, accepted))
    )
class SegmentAnythingBase(SubsetBase):
    """Dataset base that reads a directory of Segment Anything JSON annotation files.

    Every ``*.json`` file directly inside ``path`` yields one :class:`DatasetItem`.
    A file holds an ``"image"`` object (``image_id``, ``file_name`` and optional
    ``height``/``width``) and an ``"annotations"`` list. Each annotation record
    becomes an :class:`RleMask` and, when a box is given or can be derived from
    the mask, a :class:`Bbox`; both use the record ``id`` as ``id`` and ``group``.
    """

    def __init__(
        self,
        path: str,
        *,
        subset: Optional[str] = None,
        ctx: Optional[ImportContext] = None,
    ):
        """
        Args:
            path: Directory holding the ``*.json`` annotation files; image
                ``file_name`` values are resolved relative to it.
            subset: Subset name forwarded to ``SubsetBase``; loaded items are
                tagged with ``self._subset``.
            ctx: Import context supplying the progress reporter and error policy.

        Raises:
            DatasetImportError: ``path`` is not a directory.
        """
        if not osp.isdir(path):
            raise DatasetImportError(f"path {path} must be directory.")
        self._path = path
        super().__init__(subset=subset, ctx=ctx)
        self._items = self._load_items()

    def _load_items(self):
        """Parse every ``*.json`` file in ``self._path`` into a list of items.

        Errors are routed through ``self._ctx.error_policy``. When the policy
        does not raise, a file whose image record fails to parse is skipped,
        and an annotation record that fails to parse is dropped while the
        rest of that item's annotations are kept.
        """
        # Escape the directory so glob metacharacters in it (e.g. "data[1]")
        # are matched literally; only the "*.json" part is a pattern.
        from glob import escape as glob_escape

        pattern = osp.join(glob_escape(self._path), "*.json")
        # glob() result order is filesystem-dependent; sort for a stable item order.
        annotation_files = sorted(glob(pattern))

        pbar = self._ctx.progress_reporter
        items = []
        for annotation_file in pbar.iter(
            annotation_files,
            desc=f"Parsing data in {osp.basename(self._path)}",
        ):
            item = self._load_item(annotation_file)
            if item is not None:
                items.append(item)
        return items

    def _load_item(self, annotation_file: str) -> Optional[DatasetItem]:
        """Build the item described by one JSON file, or return None to skip it."""
        image_id = None
        try:
            contents = parse_json_file(annotation_file)
            image_info = parse_field(contents, "image", dict)
            annotations = parse_field(contents, "annotations", list)

            image_id = parse_field(image_info, "image_id", int)

            image_size = (
                parse_field(image_info, "height", int, default=None),
                parse_field(image_info, "width", int, default=None),
            )
            # Only pass a size when both dimensions are known.
            if any(dim is None for dim in image_size):
                image_size = None

            file_name = parse_field(image_info, "file_name", str)
            media = Image.from_file(path=osp.join(self._path, file_name), size=image_size)
        except Exception as e:
            self._ctx.error_policy.report_item_error(e, item_id=(image_id, self._subset))
            # The item cannot be built reliably from a broken image record; skip
            # it rather than emitting a placeholder item with a missing id/media.
            return None

        item_annotations = []
        for annotation in annotations:
            # Report failures per record so one bad record does not discard the rest.
            try:
                item_annotations.extend(self._parse_annotation(annotation))
            except Exception as e:
                self._ctx.error_policy.report_annotation_error(
                    e, item_id=(image_id, self._subset)
                )

        try:
            return DatasetItem(
                id=osp.splitext(file_name)[0],
                subset=self._subset,
                media=media,
                annotations=item_annotations,
                attributes={"id": image_id},
            )
        except Exception as e:
            self._ctx.error_policy.report_item_error(e, item_id=(image_id, self._subset))
            return None

    @staticmethod
    def _parse_annotation(annotation) -> list:
        """Convert one annotation record into ``[RleMask]`` or ``[RleMask, Bbox]``.

        The whole record is rejected (nothing is returned) if any part of it is
        invalid, so a mask is never kept without its validated box.

        Raises:
            InvalidAnnotationError: ``segmentation`` is missing, or ``bbox`` does
                not hold exactly 4 values.
            MissingFieldError, InvalidFieldTypeError: raised by ``parse_field``.
        """
        anno_id = parse_field(annotation, "id", int)
        attributes = {
            "predicted_iou": parse_field(annotation, "predicted_iou", float, 0.0),
            "stability_score": parse_field(annotation, "stability_score", float, 0.0),
            "point_coords": parse_field(annotation, "point_coords", list, [[]]),
            "crop_box": parse_field(annotation, "crop_box", list, []),
        }
        # The mask and the box built from one record share this group (the record id).
        group = anno_id

        segmentation = parse_field(annotation, "segmentation", dict, None)
        if segmentation is None:
            raise InvalidAnnotationError("'segmentation' label is not found.")
        mask = RleMask(
            rle=segmentation,
            label=None,
            id=anno_id,
            attributes=attributes,
            group=group,
        )

        bbox = parse_field(annotation, "bbox", list, None)
        if bbox is None:
            # No explicit box: derive one from the mask itself.
            bbox = mask.get_bbox().tolist()

        if len(bbox) == 0:
            return [mask]
        if len(bbox) != 4:
            raise InvalidAnnotationError(
                f"Bbox has wrong value count {len(bbox)}. Expected 4 values."
            )

        x, y, w, h = bbox
        box = Bbox(
            x,
            y,
            w,
            h,
            label=None,
            id=anno_id,
            # Separate dict so mutating one annotation's attributes can't affect the other.
            attributes=dict(attributes),
            group=group,
        )
        return [mask, box]