Source code for datumaro.plugins.data_formats.vott_csv

# Copyright (C) 2021-2023 Intel Corporation
#
# SPDX-License-Identifier: MIT

import csv
import errno
import os.path as osp
from typing import List, Optional

from datumaro.components.annotation import AnnotationType, Bbox, LabelCategories
from datumaro.components.dataset_base import DatasetItem, SubsetBase
from datumaro.components.format_detection import FormatDetectionContext
from datumaro.components.importer import ImportContext, Importer
from datumaro.components.media import Image
from datumaro.util.meta_file_util import has_meta_file, parse_meta_file


class VottCsvPath:
    """File-naming conventions of the VoTT CSV export format."""

    # Format detection requires annotation file names to end with this suffix.
    ANNO_FILE_SUFFIX: str = "-export.csv"
class VottCsvBase(SubsetBase):
    """Load a single VoTT CSV export file as one dataset subset.

    Each CSV row names an image (``image`` column) and may carry one bounding
    box given by the ``xmin``, ``ymin``, ``xmax``, ``ymax`` and ``label``
    columns. Rows referring to the same image are merged into a single
    ``DatasetItem`` whose id is the image file name without its extension.
    """

    def __init__(
        self,
        path: str,
        *,
        subset: Optional[str] = None,
        ctx: Optional[ImportContext] = None,
    ):
        """
        Args:
            path: Path to the CSV annotation file.
            subset: Subset name. When empty, it is derived from the file name
                by dropping the extension and the last ``-``-separated part
                (e.g. ``train-export.csv`` -> ``train``).
            ctx: Import context, forwarded to ``SubsetBase``.

        Raises:
            FileNotFoundError: If ``path`` is not an existing file.
        """
        if not osp.isfile(path):
            raise FileNotFoundError(errno.ENOENT, "Can't find annotations file", path)

        if not subset:
            subset = osp.splitext(osp.basename(path))[0].rsplit("-", maxsplit=1)[0]

        super().__init__(subset=subset, ctx=ctx)

        # When a meta file is available for this path, its keys pre-seed the
        # label list; otherwise labels are discovered while reading rows.
        if has_meta_file(path):
            self._categories = {
                AnnotationType.label: LabelCategories.from_iterable(parse_meta_file(path).keys())
            }
        else:
            self._categories = {AnnotationType.label: LabelCategories()}

        self._items = list(self._load_items(path).values())

    def _load_items(self, path):
        """Parse the CSV file at ``path``.

        Returns:
            A dict mapping item id to ``DatasetItem``, in first-seen order.
            Label names not yet present in the label categories are appended
            to them as a side effect.
        """
        items = {}
        label_categories = self._categories[AnnotationType.label]
        # Image paths in the CSV are resolved relative to the CSV's directory.
        image_dir = osp.dirname(path)

        # newline="" is required by the csv module so that quoted fields
        # containing line breaks are parsed correctly. "utf-8-sig" decodes
        # plain UTF-8 identically, but also strips a leading byte-order mark
        # that would otherwise corrupt the first header name ("\ufeffimage")
        # and make row["image"] raise KeyError.
        with open(path, encoding="utf-8-sig", newline="") as content:
            for row in csv.DictReader(content):
                image_name = row["image"]
                item_id = osp.splitext(image_name)[0]
                if item_id not in items:
                    items[item_id] = DatasetItem(
                        id=item_id,
                        subset=self._subset,
                        media=Image.from_file(path=osp.join(image_dir, image_name)),
                    )
                annotations = items[item_id].annotations

                label_name = row.get("label")
                x_min = row.get("xmin")
                y_min = row.get("ymin")
                x_max = row.get("xmax")
                y_max = row.get("ymax")

                # A row without a complete box (any field missing or empty)
                # only registers the image, with no annotation.
                if label_name and x_min and y_min and x_max and y_max:
                    label_idx = label_categories.find(label_name)[0]
                    if label_idx is None:
                        label_idx = label_categories.add(label_name)

                    # The CSV stores corner coordinates; Bbox takes x, y, w, h.
                    x_min = float(x_min)
                    y_min = float(y_min)
                    w = float(x_max) - x_min
                    h = float(y_max) - y_min
                    annotations.append(Bbox(x_min, y_min, w, h, label=label_idx))
                    self._ann_types.add(AnnotationType.bbox)

        return items
class VottCsvImporter(Importer):
    """Importer that locates and detects VoTT CSV annotation exports."""

    @classmethod
    def find_sources(cls, path):
        """Collect ``.csv`` files under ``path`` as ``vott_csv`` sources."""
        extension, extractor_name = ".csv", "vott_csv"
        return cls._find_sources_recursive(path, extension, extractor_name)

    @classmethod
    def detect(cls, context: FormatDetectionContext) -> None:
        """Require at least one file whose name ends with the export suffix."""
        suffix_pattern = f"*{VottCsvPath.ANNO_FILE_SUFFIX}"
        context.require_file(suffix_pattern)

    @classmethod
    def get_file_extensions(cls) -> List[str]:
        """Return the file extension of annotation files (the suffix's extension)."""
        _, extension = osp.splitext(VottCsvPath.ANNO_FILE_SUFFIX)
        return [extension]