Source code for datumaro.plugins.data_formats.segment_anything.exporter

# Copyright (C) 2023 Intel Corporation
#
# SPDX-License-Identifier: MIT


import logging as log
import os
import os.path as osp
from itertools import chain
from typing import List, Union

from pycocotools import mask as mask_utils

from datumaro.components.annotation import AnnotationType, Ellipse, Polygon
from datumaro.components.errors import MediaTypeError
from datumaro.components.exporter import Exporter
from datumaro.components.media import Image
from datumaro.util import annotation_util as anno_tools
from datumaro.util import dump_json_file, mask_tools


class SegmentAnythingExporter(Exporter):
    """Export a dataset as per-image JSON files with instance segmentations.

    For every dataset item one ``<file_name>.json`` file is written into the
    save directory. It holds an ``"image"`` record (id, file name, height,
    width) and an ``"annotations"`` list with one entry per annotation group,
    where the segmentation is stored as a compressed RLE. Images themselves
    are saved next to the JSON files when ``self._save_media`` is set.
    """

    DEFAULT_IMAGE_EXT = ".jpg"

    # Shape annotations that are rasterized through their polygon form.
    _polygon_types = {AnnotationType.polygon, AnnotationType.ellipse}

    # Annotation types that can take part in an exported instance.
    _allowed_types = {
        AnnotationType.bbox,
        AnnotationType.polygon,
        AnnotationType.mask,
        AnnotationType.ellipse,
    }

    def __init__(
        self,
        extractor,
        save_dir,
        **kwargs,
    ):
        """Forward all arguments unchanged to the base ``Exporter``."""
        super().__init__(extractor, save_dir, **kwargs)

    @staticmethod
    def find_instance_anns(annotations):
        """Return the annotations whose type is in ``_allowed_types``."""
        return [a for a in annotations if a.type in SegmentAnythingExporter._allowed_types]

    @classmethod
    def find_instances(cls, annotations):
        """Group the exportable annotations into instances.

        The annotations are filtered with :meth:`find_instance_anns` and then
        handed to ``anno_tools.find_instances`` for grouping.
        """
        return anno_tools.find_instances(cls.find_instance_anns(annotations))

    def get_annotation_info(self, group, img_width, img_height):
        """Build the JSON record for one annotation group.

        Args:
            group: Annotations that belong to a single instance.
            img_width: Width of the image the annotations refer to.
            img_height: Height of the image the annotations refer to.

        Returns:
            A dict with the keys ``id``, ``segmentation``, ``bbox``, ``area``,
            ``predicted_iou``, ``stability_score``, ``crop_box`` and
            ``point_coords``, or ``None`` when the group contains neither a
            polygon/ellipse nor a mask (i.e. nothing to rasterize).
        """
        boxes = [a for a in group if a.type == AnnotationType.bbox]
        polygons: List[Union[Polygon, Ellipse]] = [
            a for a in group if a.type in self._polygon_types
        ]
        masks = [a for a in group if a.type == AnnotationType.mask]

        anns = boxes + polygons + masks
        leader = anno_tools.find_group_leader(anns)

        # Explicit boxes take precedence; otherwise the box is derived from
        # every annotation in the group.
        if len(boxes) > 0:
            bbox = anno_tools.max_bbox(boxes)
        else:
            bbox = anno_tools.max_bbox(anns)

        polygons = [p.as_polygon() for p in polygons]

        # Rasterize the polygons first, then merge the result with any
        # explicit masks into a single instance mask.
        mask = None
        if polygons:
            mask = mask_tools.rles_to_mask(polygons, img_width, img_height)

        if masks:
            masks = (m.image for m in masks)
            if mask is not None:
                masks = chain(masks, [mask])
            mask = mask_tools.merge_masks(masks)

        if mask is None:
            return None

        # Convert to an uncompressed RLE with plain Python ints and let
        # pycocotools produce the compressed form.
        mask = mask_tools.mask_to_rle(mask)
        segmentation = {
            "counts": [int(c) for c in mask["counts"]],
            "size": [int(c) for c in mask["size"]],
        }
        rles = mask_utils.frPyObjects(segmentation, img_height, img_width)
        if isinstance(rles["counts"], bytes):
            # Bytes are decoded to text so the value can be stored in JSON.
            rles["counts"] = rles["counts"].decode()
        area = mask_utils.area(rles)

        annotation_data = {
            "id": leader.group,
            "segmentation": rles,
            "bbox": bbox,
            "area": area,
            "predicted_iou": max(ann.attributes.get("predicted_iou", 0.0) for ann in anns),
            "stability_score": max(ann.attributes.get("stability_score", 0.0) for ann in anns),
            "crop_box": anno_tools.max_bbox([ann.attributes.get("crop_box", []) for ann in anns]),
            # De-duplicate the points gathered from all annotations of the group.
            "point_coords": list(
                {
                    tuple(point_coord)
                    for ann in anns
                    for point_coord in ann.attributes.get("point_coords", [[]])
                }
            ),
        }
        return annotation_data

    def _apply_impl(self):
        """Write one JSON file (and optionally the image) per dataset item.

        Items without image size information or without any exportable
        annotation are skipped with a warning. Any exception raised while
        handling an item is passed to the context's error policy.

        Raises:
            MediaTypeError: If the extractor declares a media type that is
                not a subclass of ``Image``.
        """
        media_type = self._extractor.media_type()
        if media_type and not issubclass(media_type, Image):
            raise MediaTypeError("Media type is not an image")

        os.makedirs(self._save_dir, exist_ok=True)

        subsets = self._extractor.subsets()
        pbars = self._ctx.progress_reporter.split(len(subsets))

        # Id handed out to the next item that has no usable "id" attribute.
        next_image_id = 1
        for pbar, (subset_name, subset) in zip(pbars, subsets.items()):
            for item in pbar.iter(subset, desc=f"Exporting {subset_name}"):
                try:
                    # make sure file_name is flat
                    file_name = self._make_image_filename(item).replace("/", "__")

                    try:
                        image_id = int(item.attributes.get("id", next_image_id))
                    except (TypeError, ValueError):
                        # TypeError covers values such as None or a list,
                        # which int() rejects without raising ValueError.
                        image_id = next_image_id

                    # Move past the largest id seen so far, so that a later
                    # automatic id cannot repeat an earlier explicit one.
                    next_image_id = max(next_image_id, image_id) + 1

                    if not item.media or not item.media.size:
                        log.warning(
                            f"Item '{item.id}': skipping writing instances since no image info available"
                        )
                        continue

                    height, width = item.media.size
                    json_data = {
                        "image": {
                            "image_id": image_id,
                            "file_name": file_name,
                            "height": height,
                            "width": width,
                        },
                        "annotations": [],
                    }

                    instances = self.find_instances(item.annotations)
                    annotations = [self.get_annotation_info(i, width, height) for i in instances]
                    annotations = [i for i in annotations if i is not None]
                    if not annotations:
                        log.warning(
                            f"Item '{item.id}': skipping writing instances since no annotation available"
                        )
                        continue
                    json_data["annotations"] = annotations

                    dump_json_file(
                        osp.join(self._save_dir, osp.splitext(file_name)[0] + ".json"),
                        json_data,
                    )

                    if self._save_media:
                        self._save_image(
                            item,
                            path=osp.abspath(osp.join(self._save_dir, file_name)),
                        )
                except Exception as e:
                    self._ctx.error_policy.report_item_error(e, item_id=(item.id, item.subset))