Source code for datumaro.plugins.data_formats.roboflow.base_tfrecord

# Copyright (C) 2023 Intel Corporation
#
# SPDX-License-Identifier: MIT

import os
import re
from typing import List, Optional

from datumaro.components.importer import ImportContext, Importer
from datumaro.components.lazy_plugin import extra_deps
from datumaro.plugins.data_formats.tf_detection_api.base import TfDetectionApiBase
from datumaro.plugins.data_formats.tf_detection_api.format import TfrecordImporterType
from datumaro.util.tf_util import has_feature
from datumaro.util.tf_util import import_tf as _import_tf

# Module-level TensorFlow handle obtained through datumaro's `import_tf` helper
# (imported above as `_import_tf`). Code in this module uses it as `tf.io...` /
# `tf.string`.
# NOTE(review): presumably the helper wraps `import tensorflow` with extra
# setup -- confirm in datumaro.util.tf_util.
tf = _import_tf()


@extra_deps("tensorflow")
class RoboflowTfrecordImporter(Importer):
    """Importer that discovers ``.tfrecord`` files for the ``roboflow_tfrecord`` format."""

    _ANNO_EXT = ".tfrecord"

    @classmethod
    def find_sources(cls, path):
        """Return one ``roboflow_tfrecord`` source description per subset found under ``path``.

        Candidate files are gathered with ``_find_sources_recursive`` using the
        ``.tfrecord`` extension. A candidate is skipped when ``has_feature``
        reports the ``image/source_id`` feature for it (presumably that feature
        marks records written by a different exporter -- confirm against
        ``has_feature``).

        The subset name is the name of the directory that directly contains the
        file. When several accepted files share a directory, the one visited
        last replaces the earlier ones.

        Args:
            path: Location to search for annotation files.

        Returns:
            A list of dicts with the keys ``"url"``, ``"format"`` and
            ``"options"`` (holding the ``"subset"`` name); an empty list when
            no candidate file is found.
        """
        candidates = cls._find_sources_recursive(
            path=path,
            ext=cls._ANNO_EXT,
            extractor_name="roboflow_tfrecord",
        )
        if not candidates:
            return []

        undesired_feature = {
            "image/source_id": tf.io.FixedLenFeature([], tf.string),
        }

        # Map "containing directory name" -> file url; insertion order is kept.
        url_by_subset = {}
        for candidate in candidates:
            url = candidate["url"]
            if has_feature(path=url, feature=undesired_feature):
                continue

            parent_dir = os.path.dirname(url)
            # Last path component of the parent directory is the subset name.
            url_by_subset[parent_dir.rsplit(os.sep, 1)[-1]] = url

        found = []
        for subset_name, url in url_by_subset.items():
            found.append(
                {
                    "url": url,
                    "format": "roboflow_tfrecord",
                    "options": {
                        "subset": subset_name,
                    },
                }
            )
        return found

    @classmethod
    def get_file_extensions(cls) -> List[str]:
        """Return the annotation file extensions handled by this importer."""
        extensions = [cls._ANNO_EXT]
        return extensions
class RoboflowTfrecordBase(TfDetectionApiBase):
    """TFRecord dataset base configured for the Roboflow importer type.

    All loading work is delegated to ``TfDetectionApiBase``; this class only
    fixes ``tfrecord_importer_type`` to ``TfrecordImporterType.roboflow`` and
    supplies the label-map parser below.
    """

    # One label-map entry of the form: name: "<label>", id: <digits>
    # Compiled once at class creation instead of on every parse call.
    _LABELMAP_ENTRY_RE = re.compile(r'name:\s*"([^"]+)"\s*,\s*id:\s*(\d+)')

    def __init__(
        self,
        path: str,
        *,
        subset: Optional[str] = None,
        ctx: Optional[ImportContext] = None,
    ):
        """Forward the arguments to ``TfDetectionApiBase`` with the Roboflow importer type.

        Args:
            path: Passed through unchanged as ``path``.
            subset: Optional subset name, passed through unchanged.
            ctx: Optional import context, passed through unchanged.
        """
        super().__init__(
            path=path,
            subset=subset,
            tfrecord_importer_type=TfrecordImporterType.roboflow,
            ctx=ctx,
        )

    @staticmethod
    def _parse_labelmap(text):
        """Extract a ``{label name: integer id}`` mapping from label-map text.

        Every occurrence of ``name: "<label>", id: <digits>`` in ``text`` yields
        one entry. When a name occurs more than once, the last id wins. Text
        without any matching entry produces an empty dict.

        Args:
            text: Label-map contents as a string.

        Returns:
            A dict mapping each label name to its id as ``int``.
        """
        entries = RoboflowTfrecordBase._LABELMAP_ENTRY_RE.findall(text)
        # `label_id` rather than `id`, so the builtin is not shadowed.
        return {name: int(label_id) for name, label_id in entries}