Source code for datumaro.plugins.data_formats.kitti.exporter

# Copyright (C) 2021 Intel Corporation
#
# SPDX-License-Identifier: MIT

import logging as log
import os
import os.path as osp
from collections import OrderedDict
from enum import Enum, auto

import numpy as np

from datumaro.components.annotation import AnnotationType, CompiledMask, LabelCategories
from datumaro.components.errors import InvalidAnnotationError, MediaTypeError
from datumaro.components.exporter import Exporter
from datumaro.components.media import Image
from datumaro.util import cast, parse_str_enum_value, str_to_bool
from datumaro.util.annotation_util import make_label_id_mapping
from datumaro.util.image import save_image
from datumaro.util.mask_tools import paint_mask
from datumaro.util.meta_file_util import is_meta_file, parse_meta_file

from .format import (
    KittiLabelMap,
    KittiPath,
    KittiTask,
    make_kitti_categories,
    parse_label_map,
    write_label_map,
)


[docs] class LabelmapType(Enum): kitti = auto() source = auto()
[docs] class KittiExporter(Exporter): DEFAULT_IMAGE_EXT = KittiPath.IMAGE_EXT @staticmethod def _split_tasks_string(s): return [KittiTask[i.strip().lower()] for i in s.split(",")] @staticmethod def _get_labelmap(s): if osp.isfile(s): return s try: return LabelmapType[s.lower()].name except KeyError: import argparse raise argparse.ArgumentTypeError()
[docs] @classmethod def build_cmdline_parser(cls, **kwargs): parser = super().build_cmdline_parser(**kwargs) parser.add_argument( "--apply-colormap", type=str_to_bool, default=True, help="Use colormap for class masks (default: %(default)s)", ) parser.add_argument( "--label-map", type=cls._get_labelmap, default=None, help="Labelmap file path or one of %s" % ", ".join(t.name for t in LabelmapType), ) parser.add_argument( "--tasks", type=cls._split_tasks_string, help="KITTI task filter, comma-separated list of {%s} " "(default: all)" % ", ".join(t.name for t in KittiTask), ) return parser
def __init__( self, extractor, save_dir, tasks=None, apply_colormap=True, allow_attributes=True, label_map=None, **kwargs, ): super().__init__(extractor, save_dir, **kwargs) assert tasks is None or isinstance(tasks, (KittiTask, list, set)) if tasks is None: tasks = set(KittiTask) elif isinstance(tasks, KittiTask): tasks = {tasks} else: tasks = set(parse_str_enum_value(t, KittiTask) for t in tasks) self._tasks = tasks self._apply_colormap = apply_colormap if label_map is None: label_map = LabelmapType.source.name if KittiTask.segmentation in self._tasks: self._load_categories(label_map) elif KittiTask.detection in self._tasks: self._categories = { AnnotationType.label: self._extractor.categories().get( AnnotationType.label, LabelCategories() ) } def _apply_impl(self): if self._extractor.media_type() and not issubclass(self._extractor.media_type(), Image): raise MediaTypeError("Media type is not an image") os.makedirs(self._save_dir, exist_ok=True) for subset_name, subset in self._extractor.subsets().items(): if KittiTask.segmentation in self._tasks: os.makedirs( osp.join(self._save_dir, subset_name, KittiPath.INSTANCES_DIR), exist_ok=True ) for item in subset: if self._save_media: self._save_image(item, subdir=osp.join(subset_name, KittiPath.IMAGES_DIR)) masks = [a for a in item.annotations if a.type == AnnotationType.mask] if masks and KittiTask.segmentation in self._tasks: compiled_class_mask = CompiledMask.from_instance_masks( masks, instance_labels=[self._label_id_mapping(m.label) for m in masks] ) color_mask_path = osp.join( subset_name, KittiPath.SEMANTIC_RGB_DIR, item.id + KittiPath.MASK_EXT ) self.save_mask( osp.join(self._save_dir, color_mask_path), compiled_class_mask.class_mask ) labelids_mask_path = osp.join( subset_name, KittiPath.SEMANTIC_DIR, item.id + KittiPath.MASK_EXT ) self.save_mask( osp.join(self._save_dir, labelids_mask_path), compiled_class_mask.class_mask, apply_colormap=False, dtype=np.int32, ) # TODO: optimize second merging compiled_instance_mask = CompiledMask.from_instance_masks( masks, instance_labels=[ (self._label_id_mapping(m.label) << 8) + m.id for m in masks ], ) inst_path = osp.join( subset_name, KittiPath.INSTANCES_DIR, item.id + KittiPath.MASK_EXT ) self.save_mask( osp.join(self._save_dir, inst_path), compiled_instance_mask.class_mask, apply_colormap=False, dtype=np.int32, ) bboxes = [a for a in item.annotations if a.type == AnnotationType.bbox] if bboxes and KittiTask.detection in self._tasks: labels_file = osp.join( self._save_dir, subset_name, KittiPath.LABELS_DIR, "%s.txt" % item.id ) os.makedirs(osp.dirname(labels_file), exist_ok=True) with open(labels_file, "w", encoding="utf-8") as f: for bbox in bboxes: label_line = [-1] * 16 label_line[0] = self.get_label(bbox.label) label_line[1] = cast( bbox.attributes.get("truncated"), float, KittiPath.DEFAULT_TRUNCATED ) label_line[2] = cast( bbox.attributes.get("occluded"), int, KittiPath.DEFAULT_OCCLUDED ) x, y, h, w = bbox.get_bbox() label_line[4:8] = x, y, x + h, y + w label_line[15] = cast( bbox.attributes.get("score"), float, KittiPath.DEFAULT_SCORE ) label_line = " ".join(str(v) for v in label_line) f.write("%s\n" % label_line) if KittiTask.segmentation in self._tasks: self.save_label_map()
[docs] def get_label(self, label_id): return self._extractor.categories()[AnnotationType.label].items[label_id].name
[docs] def save_label_map(self): if self._save_dataset_meta: self._save_meta_file(self._save_dir) else: path = osp.join(self._save_dir, KittiPath.LABELMAP_FILE) write_label_map(path, self._label_map)
def _load_categories(self, label_map_source): if label_map_source == LabelmapType.kitti.name: # use the default KITTI colormap label_map = KittiLabelMap elif ( label_map_source == LabelmapType.source.name and AnnotationType.mask not in self._extractor.categories() ): # generate colormap for input labels labels = self._extractor.categories().get(AnnotationType.label, LabelCategories()) label_map = OrderedDict((item.name, None) for item in labels.items) elif ( label_map_source == LabelmapType.source.name and AnnotationType.mask in self._extractor.categories() ): # use source colormap labels = self._extractor.categories()[AnnotationType.label] colors = self._extractor.categories()[AnnotationType.mask] label_map = OrderedDict() for idx, item in enumerate(labels.items): color = colors.colormap.get(idx) if color is not None: label_map[item.name] = color elif isinstance(label_map_source, dict): label_map = OrderedDict(sorted(label_map_source.items(), key=lambda e: e[0])) elif isinstance(label_map_source, str) and osp.isfile(label_map_source): if is_meta_file(label_map_source): label_map = parse_meta_file(label_map_source) else: label_map = parse_label_map(label_map_source) else: raise InvalidAnnotationError( "Wrong labelmap specified, " "expected one of %s or a file path" % ", ".join(t.name for t in LabelmapType) ) self._categories = make_kitti_categories(label_map) self._label_map = label_map self._label_id_mapping = self._make_label_id_map() def _make_label_id_map(self): map_id, id_mapping, src_labels, dst_labels = make_label_id_mapping( self._extractor.categories().get(AnnotationType.label), self._categories[AnnotationType.label], ) void_labels = [ src_label for src_id, src_label in src_labels.items() if src_label not in dst_labels ] if void_labels: log.warning( "The following labels are remapped to background: %s" % ", ".join(void_labels) ) log.debug( "Saving segmentations with the following label mapping: \n%s" % "\n".join( [ "#%s '%s' -> #%s '%s'" % ( src_id, src_label, id_mapping[src_id], self._categories[AnnotationType.label].items[id_mapping[src_id]].name, ) for src_id, src_label in src_labels.items() ] ) ) return map_id
[docs] def save_mask(self, path, mask, colormap=None, apply_colormap=True, dtype=np.uint8): if self._apply_colormap and apply_colormap: if colormap is None: colormap = self._categories[AnnotationType.mask].colormap mask = paint_mask(mask, colormap) save_image(path, mask, create_dir=True, dtype=dtype)
[docs] class KittiSegmentationExporter(KittiExporter): def __init__(self, *args, **kwargs): kwargs["tasks"] = KittiTask.segmentation super().__init__(*args, **kwargs)
[docs] class KittiDetectionExporter(KittiExporter): def __init__(self, *args, **kwargs): kwargs["tasks"] = KittiTask.detection super().__init__(*args, **kwargs)