Source code for datumaro.plugins.data_formats.mapillary_vistas.importer

# Copyright (C) 2022-2023 Intel Corporation
#
# SPDX-License-Identifier: MIT
import glob
import logging as log
import os.path as osp
from typing import List

from datumaro.components.dataset_base import DEFAULT_SUBSET_NAME
from datumaro.components.errors import DatasetNotFoundError
from datumaro.components.importer import Importer
from datumaro.util import str_to_bool

from .base import MapillaryVistasInstancesBase, MapillaryVistasPanopticBase
from .format import MapillaryVistasPath, MapillaryVistasTask


[docs] class MapillaryVistasImporter(Importer): _TASKS = { MapillaryVistasTask.instances: MapillaryVistasInstancesBase, MapillaryVistasTask.panoptic: MapillaryVistasPanopticBase, }
[docs] @classmethod def build_cmdline_parser(cls, **kwargs): parser = super().build_cmdline_parser(**kwargs) parser.add_argument( "--format-version", default="v2.0", type=str, help="Use original config*.json file for your version of dataset", ) parser.add_argument( "--parse-polygon", type=str_to_bool, default=False, help="Use original config*.json file for your version of dataset", ) parser.add_argument( "--use-original-config", action="store_true", help="Use original config*.json file for your version of dataset", ) parser.add_argument( "--keep-original-category-ids", action="store_true", help="Add dummy label categories so that category indices " "correspond to the category IDs in the original annotation " "file", ) return parser
def __call__(self, path, **extra_params): subsets = self.find_sources(path) if len(subsets) == 0: raise DatasetNotFoundError(path, self.NAME) tasks = list(set(task for subset in subsets.values() for task in subset)) selected_task = tasks[0] if 1 < len(tasks): task_types = ",".join(task.name for task in tasks) log.warning( f"Found potentially conflicting source types: {task_types}" f"Only one one type will be used: {selected_task.name}" ) if selected_task == MapillaryVistasTask.instances: has_config = any( [ osp.isfile(osp.join(path, config)) for config in MapillaryVistasPath.CONFIG_FILES.values() ] ) if not has_config and not extra_params.get("use_original_config"): raise DatasetNotFoundError( path, self.NAME, "Failed to find config*.json at '{path}'. " "See extra args for using original configs.", ) sources = [ {"url": url, "format": self._TASKS[task].NAME, "options": dict(extra_params)} for _, subset_info in subsets.items() for task, url in subset_info.items() if task == selected_task ] return sources
[docs] @classmethod def find_sources(cls, path): subsets = {} suffixes = [ osp.join(ann_dir, subdir) for ann_dir, subdirs in MapillaryVistasPath.ANNOTATION_DIRS.items() for subdir in subdirs ] for suffix in suffixes: task = MapillaryVistasPath.CLASS_BY_DIR[osp.basename(suffix)] if task not in cls._TASKS: continue if osp.isdir(osp.join(path, suffix)): return {DEFAULT_SUBSET_NAME: {task: path}} for ann_path in glob.glob(osp.join(path, "*", suffix)): subset = osp.dirname(osp.dirname(osp.relpath(ann_path, path))) subsets.setdefault(subset, {})[task] = osp.join(path, subset) return subsets
[docs] @classmethod def get_file_extensions(cls) -> List[str]: return [".jpg", ".png", ".json"]
[docs] class MapillaryVistasInstancesImporter(MapillaryVistasImporter): _TASK = MapillaryVistasTask.instances _TASKS = {_TASK: MapillaryVistasImporter._TASKS[_TASK]}
[docs] class MapillaryVistasPanopticImporter(MapillaryVistasImporter): _TASK = MapillaryVistasTask.panoptic _TASKS = {_TASK: MapillaryVistasImporter._TASKS[_TASK]}