Source code for datumaro.plugins.data_formats.tf_detection_api.base

# Copyright (C) 2019-2023 Intel Corporation
#
# SPDX-License-Identifier: MIT

import os
import os.path as osp
import re
from collections import OrderedDict
from typing import List, Optional

import numpy as np

from datumaro.components.annotation import AnnotationType, Bbox, LabelCategories, Mask
from datumaro.components.dataset_base import DatasetItem, SubsetBase
from datumaro.components.importer import ImportContext, Importer
from datumaro.components.lazy_plugin import extra_deps
from datumaro.components.media import Image
from datumaro.util.image import decode_image, lazy_image
from datumaro.util.tf_util import has_feature
from datumaro.util.tf_util import import_tf as _import_tf

from .format import DetectionApiPath, TfrecordImporterType

# TensorFlow module, obtained through datumaro's import helper
# (datumaro.util.tf_util.import_tf). It is used below for the tfrecord feature
# specs, for record parsing, and for the importer's feature check.
tf = _import_tf()


def clamp(value, _min, _max):
    """Limit *value* to the range [_min, _max].

    The upper bound is applied first and the lower bound second. So if
    ``_min > _max``, the result is ``_min``. If *value* does not compare
    below ``_max`` (e.g. NaN), it is replaced by ``_max`` before the lower
    bound is applied.
    """
    capped = value if value < _max else _max
    return _min if _min > capped else capped
@extra_deps("tensorflow")
class TfDetectionApiBase(SubsetBase):
    """Read one TF Object Detection API ``.tfrecord`` file as a dataset subset.

    Each record becomes a :class:`DatasetItem`. Objects become ``Mask``
    annotations when the record has ``image/object/mask`` entries, and
    ``Bbox`` annotations otherwise.

    Label names and indices come from two places:
    - a labelmap file, i.e. the first file in the record's directory whose
      name contains ``DetectionApiPath.LABELMAP_FILE``, if one exists;
    - the ``image/object/class/text`` / ``image/object/class/label`` fields
      of the records themselves.
    """

    def __init__(
        self,
        path: str,
        *,
        tfrecord_importer_type: TfrecordImporterType = TfrecordImporterType.default,
        subset: Optional[str] = None,
        ctx: Optional[ImportContext] = None,
    ):
        """
        Args:
            path: Path to the ``.tfrecord`` file (must be an existing file).
            tfrecord_importer_type: Record flavour. For ``roboflow`` the
                ``image/source_id`` feature is not requested.
            subset: Subset name. Defaults to the file name without extension.
            ctx: Import context forwarded to :class:`SubsetBase`.
        """
        assert osp.isfile(path), path

        # When the record lives in "<root>/<ANNOTATIONS_DIR>/", image files
        # are looked up in "<root>/<IMAGES_DIR>/" if that directory exists.
        # Otherwise images_dir stays "", and osp.join("", name) == name.
        images_dir = ""
        root_dir = osp.dirname(osp.abspath(path))
        if osp.basename(root_dir) == DetectionApiPath.ANNOTATIONS_DIR:
            root_dir = osp.dirname(root_dir)
            images_dir = osp.join(root_dir, DetectionApiPath.IMAGES_DIR)
            if not osp.isdir(images_dir):
                images_dir = ""

        if not subset:
            subset = osp.splitext(osp.basename(path))[0]

        super().__init__(subset=subset, ctx=ctx)

        self._features = {
            "image/filename": tf.io.FixedLenFeature([], tf.string),
            "image/source_id": tf.io.FixedLenFeature([], tf.string),
            "image/height": tf.io.FixedLenFeature([], tf.int64),
            "image/width": tf.io.FixedLenFeature([], tf.int64),
            "image/encoded": tf.io.FixedLenFeature([], tf.string),
            "image/format": tf.io.FixedLenFeature([], tf.string),
            # use varlen to avoid errors when this field is missing
            "image/key/sha256": tf.io.VarLenFeature(tf.string),
            # Object boxes and classes.
            "image/object/bbox/xmin": tf.io.VarLenFeature(tf.float32),
            "image/object/bbox/xmax": tf.io.VarLenFeature(tf.float32),
            "image/object/bbox/ymin": tf.io.VarLenFeature(tf.float32),
            "image/object/bbox/ymax": tf.io.VarLenFeature(tf.float32),
            "image/object/class/label": tf.io.VarLenFeature(tf.int64),
            "image/object/class/text": tf.io.VarLenFeature(tf.string),
            "image/object/mask": tf.io.VarLenFeature(tf.string),
        }
        if tfrecord_importer_type == TfrecordImporterType.roboflow:
            # A required FixedLenFeature fails when the field is missing, so
            # this flavour does not request "image/source_id" at all.
            del self._features["image/source_id"]

        items, labels = self._parse_tfrecord_file(path, self._subset, images_dir)
        self._items = items
        self._categories = self._load_categories(labels)

    @staticmethod
    def _load_categories(labels):
        """Build label categories from a ``{label_name: label_index}`` mapping.

        Names are added in ascending order of their index.
        NOTE(review): if the indices have gaps (e.g. labelmap ids 1 and 3),
        the category positions presumably no longer equal the indices stored
        in annotations. TODO confirm whether that is expected.
        """
        label_categories = LabelCategories().from_iterable(
            e[0] for e in sorted(labels.items(), key=lambda item: item[1])
        )
        return {AnnotationType.label: label_categories}

    @staticmethod
    def _parse_labelmap(text):
        """Parse labelmap text into a ``{name: id}`` dict (ids as written).

        Only ``{...}`` entries made of exactly two fields, an ``id: <digits>``
        and a ``name: '<text>'`` (either quote style), are recognized. Entries
        with extra fields (e.g. ``display_name``) do not match the pattern.
        """
        id_pattern = r"(?:id\s*:\s*(?P<id>\d+))"
        name_pattern = r"(?:name\s*:\s*[\'\"](?P<name>.*?)[\'\"])"
        entry_pattern = r"(\{(?:[\s\n]*(?:%(id)s|%(name)s)[\s\n]*){2}\})+" % {
            "id": id_pattern,
            "name": name_pattern,
        }
        matches = re.finditer(entry_pattern, text)

        labelmap = {}
        for match in matches:
            label_id = match.group("id")
            label_name = match.group("name")
            if label_id is not None and label_name is not None:
                labelmap[label_name] = int(label_id)

        return labelmap

    @staticmethod
    def _find_labelmap_path(dirpath):
        """Return the first file in *dirpath* whose name contains
        ``DetectionApiPath.LABELMAP_FILE``, or ``None`` if there is none."""
        for filename in os.listdir(dirpath):
            if DetectionApiPath.LABELMAP_FILE in filename:
                return osp.join(dirpath, filename)
        return None

    def _parse_tfrecord_file(self, filepath, subset, images_dir):
        """Parse every record in *filepath*.

        Returns:
            A tuple ``(items, labels)``: the list of :class:`DatasetItem` and an
            ordered ``{label_name: label_index}`` mapping. The mapping is seeded
            from the labelmap file and extended with labels seen in the records.
        """
        dataset = tf.data.TFRecordDataset(filepath)

        # abspath: for a bare "x.tfrecord", dirname() is "" and
        # os.listdir("") raises.
        labelmap_path = self._find_labelmap_path(osp.dirname(osp.abspath(filepath)))

        dataset_labels = OrderedDict()
        if labelmap_path is not None and osp.isfile(labelmap_path):
            with open(labelmap_path, "r", encoding="utf-8") as f:
                labelmap_text = f.read()
            # Labelmap ids are shifted down by one to get label indices
            # (ids <= 0 are treated as invalid when read from records).
            dataset_labels.update(
                {
                    label: label_id - 1
                    for label, label_id in self._parse_labelmap(labelmap_text).items()
                }
            )

        dataset_items = []
        for record in dataset:
            parsed_record = tf.io.parse_single_example(record, self._features)
            dataset_items.append(
                self._parse_record(parsed_record, subset, images_dir, dataset_labels)
            )

        return dataset_items, dataset_labels

    @staticmethod
    def _parse_record(parsed_record, subset, images_dir, dataset_labels):
        """Convert one parsed tfrecord example into a :class:`DatasetItem`.

        Label names found in the record that are not yet in *dataset_labels*
        are added to it in place.
        """
        frame_id = parsed_record.get("image/source_id", None)
        # Compare against None instead of truth-testing the tensor. An empty
        # source id is falsy, and would otherwise be left as an undecoded
        # tensor in the item attributes.
        if frame_id is not None:
            frame_id = frame_id.numpy().decode("utf-8")

        frame_filename = parsed_record["image/filename"].numpy().decode("utf-8")
        frame_height = tf.cast(parsed_record.get("image/height", 0), tf.int64).numpy().item()
        frame_width = tf.cast(parsed_record.get("image/width", 0), tf.int64).numpy().item()
        frame_image = parsed_record["image/encoded"].numpy()

        xmins = tf.sparse.to_dense(parsed_record["image/object/bbox/xmin"]).numpy()
        ymins = tf.sparse.to_dense(parsed_record["image/object/bbox/ymin"]).numpy()
        xmaxs = tf.sparse.to_dense(parsed_record["image/object/bbox/xmax"]).numpy()
        ymaxs = tf.sparse.to_dense(parsed_record["image/object/bbox/ymax"]).numpy()
        label_ids = tf.sparse.to_dense(parsed_record["image/object/class/label"]).numpy()
        labels = tf.sparse.to_dense(
            parsed_record["image/object/class/text"], default_value=b""
        ).numpy()
        masks = tf.sparse.to_dense(
            parsed_record["image/object/mask"], default_value=b""
        ).numpy()

        # Register record labels that are not known yet (e.g. not in the labelmap).
        for label, label_id in zip(labels, label_ids):
            label = label.decode("utf-8")
            if not label or label_id <= 0 or label in dataset_labels:
                continue
            dataset_labels[label] = label_id - 1

        item_id = osp.splitext(frame_filename)[0]

        annotations = []
        for shape_id, shape in enumerate(np.dstack((labels, xmins, ymins, xmaxs, ymaxs))[0]):
            label = shape[0].decode("utf-8")

            mask = masks[shape_id] if len(masks) != 0 else None
            if mask is not None:
                if isinstance(mask, bytes):
                    mask = lazy_image(mask, decode_image)
                annotations.append(Mask(image=mask, label=dataset_labels.get(label)))
            else:
                # Box coordinates are scaled by the frame size and clamped to
                # the frame.
                x = clamp(shape[1] * frame_width, 0, frame_width)
                y = clamp(shape[2] * frame_height, 0, frame_height)
                w = clamp(shape[3] * frame_width, 0, frame_width) - x
                h = clamp(shape[4] * frame_height, 0, frame_height) - y
                annotations.append(Bbox(x, y, w, h, label=dataset_labels.get(label)))

        image_size = None
        if frame_height and frame_width:
            image_size = (frame_height, frame_width)

        # Prefer the embedded encoded image; fall back to a file next to images_dir.
        image = None
        if frame_image:
            if isinstance(frame_image, np.ndarray):
                image = Image.from_numpy(data=frame_image, size=image_size)
            else:
                image = Image.from_bytes(data=frame_image, size=image_size)
        elif frame_filename:
            image = Image.from_file(path=osp.join(images_dir, frame_filename), size=image_size)

        return DatasetItem(
            id=item_id,
            subset=subset,
            media=image,
            annotations=annotations,
            attributes={"source_id": frame_id},
        )
@extra_deps("tensorflow")
class TfDetectionApiImporter(Importer):
    """Locate TF Object Detection API ``.tfrecord`` files for import."""

    _FORMAT_EXT = ".tfrecord"

    @classmethod
    def find_sources(cls, path):
        """Return one ``tf_detection_api`` source per subset found under *path*.

        Only ``.tfrecord`` files that contain the ``image/source_id`` feature
        (checked with ``has_feature``) are accepted. The subset key is the
        file name without its extension. If several files share that key, the
        last one found is kept.
        """
        sources = cls._find_sources_recursive(
            path=path,
            ext=cls._FORMAT_EXT,
            extractor_name="tf_detection_api",
        )
        if not sources:
            return []

        desired_feature = {
            "image/source_id": tf.io.FixedLenFeature([], tf.string),
        }

        subsets = {}
        for source in sources:
            url = source["url"]
            if not has_feature(path=url, feature=desired_feature):
                continue
            # splitext keeps "a.b.tfrecord" -> "a.b". The former split(".")[-2]
            # gave "b", so distinct files could collide and be dropped.
            subset_name = osp.splitext(osp.basename(url))[0]
            subsets[subset_name] = url

        return [{"url": url, "format": "tf_detection_api"} for url in subsets.values()]

    @classmethod
    def get_file_extensions(cls) -> List[str]:
        """Return the file extensions handled by this importer."""
        return [cls._FORMAT_EXT]