Source code for datumaro.plugins.data_formats.segment_anything.importer
# Copyright (C) 2023 Intel Corporation
#
# SPDX-License-Identifier: MIT
import os
from typing import Dict, List, Optional
from datumaro.components.format_detection import FormatDetectionConfidence, FormatDetectionContext
from datumaro.components.importer import Importer
from datumaro.errors import DatasetImportError
from datumaro.rust_api import JsonSectionPageMapper
from datumaro.util import parse_json
[docs]
class SegmentAnythingImporter(Importer):
    """Importer for datasets stored as Segment Anything style JSON annotation files.

    Each annotation file is expected to be a JSON document with exactly two
    top-level sections, ``"image"`` and ``"annotations"``.
    """

    # Upper bound on the number of annotation files inspected by `detect`.
    _N_JSON_TO_TEST = 10
    # Largest "annotations" section `detect` is willing to read and parse.
    _MAX_ANNOTATION_SECTION_BYTES = 100 * 1024 * 1024  # 100 MiB
    # File extension of the annotation files.
    _ANNO_EXT = ".json"

    @classmethod
    def detect(
        cls,
        context: FormatDetectionContext,
    ) -> Optional[FormatDetectionConfidence]:
        """Check whether the dataset under ``context`` looks like Segment Anything.

        At most ``_N_JSON_TO_TEST`` files matching ``*.json`` are probed. For each
        probed file the following must hold, otherwise an error is raised inside
        the probe:

        * the top-level sections are exactly ``"annotations"`` and ``"image"``;
        * the ``"image"`` section has exactly the keys ``"image_id"``,
          ``"width"``, ``"height"`` and ``"file_name"``;
        * the ``"annotations"`` section is not larger than
          ``_MAX_ANNOTATION_SECTION_BYTES``;
        * the first annotation contains at least ``"id"``, ``"segmentation"``
          and ``"bbox"``.

        Args:
            context: Format detection context used to enumerate and probe files.

        Returns:
            Always ``None`` when every probed file passes the checks.
            NOTE(review): presumably the detection framework maps ``None`` to its
            default confidence -- confirm against ``Importer.detect``.

        Raises:
            DatasetImportError: Raised inside the probe when a file does not
                match the expected layout.
                NOTE(review): presumably ``probe_text_file`` converts exceptions
                raised in its body into a format rejection -- confirm against
                ``FormatDetectionContext.probe_text_file``.
        """
        # Loop-invariant key sets, built once instead of once per file.
        expected_sections = {"annotations", "image"}
        expected_image_keys = {"image_id", "width", "height", "file_name"}
        required_annotation_keys = {"id", "segmentation", "bbox"}

        for n_tested, file in enumerate(
            context.require_files_iter(f"*{cls._ANNO_EXT}"), start=1
        ):
            with context.probe_text_file(
                file, "Annotation format is not Segment-Anything format", is_binary_file=True
            ) as f:
                fpath = os.path.join(context.root_path, file)
                # The page mapper reports offset/size of each top-level JSON
                # section, so only the needed byte ranges are read and parsed.
                sections = JsonSectionPageMapper(fpath).sections()
                if set(sections.keys()) != expected_sections:
                    raise DatasetImportError

                offset, size = sections["image"]["offset"], sections["image"]["size"]
                f.seek(offset, 0)
                img_contents = parse_json(f.read(size))
                if set(img_contents.keys()) != expected_image_keys:
                    raise DatasetImportError

                offset, size = sections["annotations"]["offset"], sections["annotations"]["size"]
                # Refuse to load an oversized section just for format detection.
                if size > cls._MAX_ANNOTATION_SECTION_BYTES:
                    msg = f"Annotation section is too huge. It exceeded {cls._MAX_ANNOTATION_SECTION_BYTES} bytes."
                    raise DatasetImportError(msg)

                f.seek(offset, 0)
                ann_contents = parse_json(f.read(size))
                # Only the first annotation is inspected; an empty annotation
                # list raises IndexError here, inside the probe.
                if not required_annotation_keys.issubset(set(ann_contents[0])):
                    raise DatasetImportError

            # Stop once exactly _N_JSON_TO_TEST files have been checked. The
            # previous `ctr > N` comparison let one extra file through.
            if n_tested >= cls._N_JSON_TO_TEST:
                break

        return None

    @classmethod
    def find_sources(cls, path) -> List[Dict]:
        """Describe the dataset source located at ``path``.

        Args:
            path: Path to the dataset root.

        Returns:
            An empty list when ``path`` is not a directory; otherwise a
            single-element list with the directory as ``"url"`` and ``cls.NAME``
            as ``"format"``.
        """
        if not os.path.isdir(path):
            return []
        return [{"url": path, "format": cls.NAME}]

    @classmethod
    def get_file_extensions(cls) -> List[str]:
        """Return the list of annotation file extensions used by this format."""
        return [cls._ANNO_EXT]