Source code for datumaro.plugins.data_formats.kitti.importer

# Copyright (C) 2021 Intel Corporation
#
# SPDX-License-Identifier: MIT

import logging as log
import os.path as osp
from glob import glob
from typing import List

from datumaro.components.errors import DatasetNotFoundError
from datumaro.components.format_detection import FormatDetectionConfidence, FormatDetectionContext
from datumaro.components.importer import Importer

from .format import KittiPath, KittiTask


class KittiImporter(Importer):
    """Importer for KITTI datasets that may contain several task layouts.

    Subsets are discovered by searching the dataset root for each task's
    annotation directory (see ``_TASKS``). The parent directory of every match
    is treated as a subset and becomes a source for the corresponding
    task-specific extractor.
    """

    DETECT_CONFIDENCE = FormatDetectionConfidence.MEDIUM

    # task -> (extractor format name, annotation directory name searched for)
    _TASKS = {
        KittiTask.segmentation: ("kitti_segmentation", KittiPath.INSTANCES_DIR),
        KittiTask.detection: ("kitti_detection", KittiPath.LABELS_DIR),
    }

    def __call__(self, path, **extra_params):
        """Build source descriptions for the dataset located at ``path``.

        Args:
            path: Root directory of the dataset.
            **extra_params: Options copied into every returned source.

        Returns:
            A list of dicts with ``"url"``, ``"format"`` and ``"options"`` keys.

        Raises:
            DatasetNotFoundError: If no subsets are found under ``path``.
        """
        subsets = self.find_sources(path)
        if not subsets:
            raise DatasetNotFoundError(path, self.NAME)

        # TODO: should be removed when proper label merging is implemented
        # Only one of the conflicting source types is kept; the alphabetically
        # first one is chosen so the choice is deterministic.
        conflicting_types = {"kitti_segmentation", "kitti_detection"}
        ann_types = {t for s in subsets.values() for t in s} & conflicting_types
        selected_ann_type = min(ann_types) if ann_types else None
        if len(ann_types) > 1:
            log.warning(
                "Not implemented: "
                "Found potentially conflicting source types with labels: %s. "
                "Only one type will be used: %s",
                ", ".join(sorted(ann_types)),
                selected_ann_type,
            )

        sources = []
        for ann_files in subsets.values():
            for ann_type, ann_file in ann_files.items():
                # Compare by value: equal strings need not be the same object.
                if ann_type in conflicting_types and ann_type != selected_ann_type:
                    log.warning(
                        "Not implemented: conflicting source '%s' is skipped.", ann_file
                    )
                    continue
                log.info("Found a dataset at '%s'", ann_file)
                sources.append(
                    {
                        "url": ann_file,
                        "format": ann_type,
                        "options": dict(extra_params),
                    }
                )
        return sources

    @classmethod
    def find_sources(cls, path):
        """Map subset names to ``{extractor format: subset directory}``.

        Every directory named like a task's annotation directory, at any depth
        below ``path``, marks its parent directory as a subset. The subset
        name is the base name of that parent directory (extension stripped).
        """
        subsets = {}
        for extractor_type, task_dir in cls._TASKS.values():
            # ``path`` must remain the search root for every task, so the
            # per-match directory gets its own name instead of rebinding it.
            for task_path in glob(osp.join(path, "**", task_dir), recursive=True):
                subset_path = osp.normpath(osp.join(task_path, ".."))
                subset_name = osp.splitext(osp.basename(subset_path))[0]
                subsets.setdefault(subset_name, {})[extractor_type] = subset_path
        return subsets

    @classmethod
    def detect(
        cls,
        context: FormatDetectionContext,
    ) -> FormatDetectionConfidence:
        """Succeed when the layout of at least one KITTI task is detected."""
        sub_importers = [KittiDetectionImporter, KittiSegmentationImporter]
        with context.require_any():
            for importer_cls in sub_importers:
                with context.alternative():
                    importer_cls.detect(context)
        # Return explicitly, matching the annotation and the sub-importers.
        return cls.DETECT_CONFIDENCE

    @classmethod
    def get_file_extensions(cls) -> List[str]:
        """Return the union of the sub-importers' file extensions, sorted."""
        sub_importers = [KittiDetectionImporter, KittiSegmentationImporter]
        # Sorted so the result does not depend on set iteration order.
        return sorted(
            {ext for importer in sub_importers for ext in importer.get_file_extensions()}
        )
class KittiDetectionImporter(KittiImporter):
    """KITTI importer restricted to the object-detection task."""

    _TASK = KittiTask.detection
    # Keep only the detection entry of the parent's task table.
    _TASKS = {_TASK: KittiImporter._TASKS[_TASK]}
    # Extension of the annotation (label) files.
    _ANNO_EXT = ".txt"

    @classmethod
    def detect(cls, context: FormatDetectionContext) -> FormatDetectionConfidence:
        """Require at least one left color camera label file in the tree."""
        label_pattern = f"**/label_2/*_*{cls._ANNO_EXT}"
        context.require_file(label_pattern)
        return cls.DETECT_CONFIDENCE

    @classmethod
    def get_file_extensions(cls) -> List[str]:
        """Return the annotation file extension as a single-item list."""
        extensions = [cls._ANNO_EXT]
        return extensions
class KittiSegmentationImporter(KittiImporter):
    """KITTI importer restricted to the segmentation task."""

    _TASK = KittiTask.segmentation
    # Keep only the segmentation entry of the parent's task table.
    _TASKS = {_TASK: KittiImporter._TASKS[_TASK]}
    # Extension of the instance segmentation mask files.
    _FORMAT_EXT = ".png"

    @classmethod
    def detect(cls, context: FormatDetectionContext) -> FormatDetectionConfidence:
        """Require at least one instance segmentation mask in the tree."""
        mask_pattern = f"**/instance/*{cls._FORMAT_EXT}"
        context.require_file(mask_pattern)
        return cls.DETECT_CONFIDENCE

    @classmethod
    def get_file_extensions(cls) -> List[str]:
        """Return the mask file extension as a single-item list."""
        extensions = [cls._FORMAT_EXT]
        return extensions