Source code for datumaro.plugins.data_formats.datumaro_binary.base

# Copyright (C) 2024 Intel Corporation
#
# SPDX-License-Identifier: MIT

import os.path as osp
import struct
from io import BufferedReader
from multiprocessing.pool import AsyncResult, Pool
from typing import Any, Dict, Iterator, List, Optional

from datumaro.components.crypter import NULL_CRYPTER, Crypter
from datumaro.components.dataset_base import DatasetItem
from datumaro.components.errors import DatasetImportError
from datumaro.components.importer import ImportContext
from datumaro.components.media import Image, MediaElement, MediaType, PointCloud, Video, VideoFrame
from datumaro.plugins.data_formats.datumaro_binary.format import DatumaroBinaryPath
from datumaro.plugins.data_formats.datumaro_binary.mapper import DictMapper
from datumaro.plugins.data_formats.datumaro_binary.mapper.common import IntListMapper
from datumaro.plugins.data_formats.datumaro_binary.mapper.dataset_item import DatasetItemMapper

from ..datumaro.base import DatumaroBase, JsonReader
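
# On-disk layout parsed by this reader, in the order `_load_impl` consumes it:
# signature -> version header (plaintext) -> encryption handshake field ->
# infos -> categories -> media type -> item blobs. Each header is a 4-byte
# length prefix followed by an (optionally encrypted) DictMapper payload.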


class DatumaroBinaryBase(DatumaroBase):
    """Base class for reading the DatumaroBinary format."""

    def __init__(
        self,
        path: str,
        *,
        encryption_key: Optional[bytes] = None,
        num_workers: int = 0,
        subset: Optional[str] = None,
        ctx: Optional[ImportContext] = None,
    ):
        """
        Parameters
        ----------
        path
            Directory path to import the DatumaroBinary format dataset from
        encryption_key
            If the dataset is encrypted, the secret key is required to import it.
        num_workers
            The number of multiprocessing workers used for import.
            If num_workers = 0, multiprocessing is not used.
        """
        self._fp: Optional[BufferedReader] = None
        self._crypter = Crypter(encryption_key) if encryption_key is not None else NULL_CRYPTER
        self._media_encryption = False
        self._num_workers = num_workers
        super().__init__(path, subset=subset, ctx=ctx)

    def _get_dm_format_version(self, path: str) -> str:
        with open(path, "rb") as fp:
            self._fp = fp
            self._check_signature()
            dm_format_version = self._read_version()
        return dm_format_version

    def _load_impl(self, path: str) -> None:
        """Actual implementation of loading the Datumaro binary format."""
        try:
            with open(path, "rb") as fp:
                self._fp = fp
                self._check_signature()
                self._read_version()
                self._check_encryption_field()
                self._read_info()
                self._read_categories()
                self._read_media_type()
                self._read_items()
        finally:
            self._fp = None

    def _check_signature(self):
        signature = self._fp.read(DatumaroBinaryPath.SIGNATURE_LEN).decode()
        DatumaroBinaryPath.check_signature(signature)

    def _check_encryption_field(self):
        len_byte = self._fp.read(4)
        _bytes = self._fp.read(struct.unpack("I", len_byte)[0])

        if not self._crypter.handshake(_bytes):
            raise DatasetImportError("Encryption key handshake failed. The given key is wrong.")

    def _read_header(self, use_crypter: bool = True):
        len_byte = self._fp.read(4)
        _bytes = self._fp.read(struct.unpack("I", len_byte)[0])
        if use_crypter:
            _bytes = self._crypter.decrypt(_bytes)
        header, _ = DictMapper.backward(_bytes)
        return header

    def _read_version(self) -> str:
        # The version header is never encrypted, so it is read without the crypter.
        version_header = self._read_header(use_crypter=False)
        self._media_encryption = version_header["media_encryption"]
        return version_header["dm_format_version"]

    def _read_info(self):
        self._infos = self._read_header()

    def _read_categories(self):
        categories = self._read_header()
        self._categories = JsonReader._load_categories({"categories": categories})

    def _read_media_type(self):
        media_type = self._read_header()["media_type"]
        if media_type == MediaType.IMAGE:
            self._media_type = Image
        elif media_type == MediaType.POINT_CLOUD:
            self._media_type = PointCloud
        elif media_type == MediaType.VIDEO:
            self._media_type = Video
        elif media_type == MediaType.VIDEO_FRAME:
            self._media_type = VideoFrame
        elif media_type == MediaType.MEDIA_ELEMENT:
            self._media_type = MediaElement
        else:
            raise NotImplementedError(f"media_type={media_type} is currently not supported.")

    def _read_items(self) -> None:
        # The sizes of the item blobs are stored as an encrypted, length-prefixed
        # integer list ahead of the blobs themselves.
        (n_blob_sizes_bytes,) = struct.unpack("<I", self._fp.read(4))
        blob_sizes_bytes = self._crypter.decrypt(self._fp.read(n_blob_sizes_bytes))
        blob_sizes, _ = IntListMapper.backward(blob_sizes_bytes, 0)

        media_path_prefix = {
            MediaType.IMAGE: osp.join(self._images_dir, self._subset),
            MediaType.POINT_CLOUD: osp.join(self._pcd_dir, self._subset),
            MediaType.VIDEO: osp.join(self._video_dir, self._subset),
            MediaType.VIDEO_FRAME: osp.join(self._video_dir, self._subset),
        }

        if self._num_workers > 0:
            self._items = self._read_items_mp(blob_sizes, media_path_prefix)
        else:
            self._items = self._read_items_sp(blob_sizes, media_path_prefix)

        for item in self._items:
            if item.media is not None and self._media_encryption:
                item.media.set_crypter(self._crypter)
            for ann in item.annotations:
                self._ann_types.add(ann.type)

    def _read_items_mp(
        self, blob_sizes: List[int], media_path_prefix: Dict[MediaType, str]
    ) -> List[DatasetItem]:
        async_results: List[AsyncResult] = []

        with Pool(processes=self._num_workers) as pool:
            for blob_size in blob_sizes:
                blob_bytes = self._fp.read(blob_size)
                async_results += [
                    pool.apply_async(
                        self._read_blob,
                        (
                            blob_bytes,
                            self._crypter,
                            media_path_prefix,
                        ),
                    )
                ]

            return [
                item
                for async_result in async_results
                for item in async_result.get(timeout=DatumaroBinaryPath.MP_TIMEOUT)
            ]

    def _read_items_sp(
        self, blob_sizes: List[int], media_path_prefix: Dict[MediaType, str]
    ) -> List[DatasetItem]:
        items_list = [
            self._read_blob(self._fp.read(blob_size), self._crypter, media_path_prefix)
            for blob_size in blob_sizes
        ]
        return [item for items in items_list for item in items]

    @staticmethod
    def _read_blob(
        blob_bytes: bytes, crypter: Crypter, media_path_prefix: Dict[MediaType, str]
    ) -> List[DatasetItem]:
        items = []
        offset = 0

        # Decrypt bytes first
        blob_bytes = crypter.decrypt(blob_bytes)

        # Extract items
        while offset < len(blob_bytes):
            item, offset = DatasetItemMapper.backward(blob_bytes, offset, media_path_prefix)
            items.append(item)

        assert offset == len(blob_bytes)

        return items

    @property
    def is_stream(self) -> bool:
        return False
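
    # A minimal sketch of decoding one header frame by hand, assuming `fp` is a
    # binary file object positioned at a frame and `key` is the export-time
    # secret key (hypothetical names; this mirrors `_read_header` above):
    #
    #   (length,) = struct.unpack("I", fp.read(4))
    #   payload = Crypter(key).decrypt(fp.read(length))
    #   header, _ = DictMapper.backward(payload)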

    def infos(self):
        return self._infos

    def categories(self):
        return self._categories

    def media_type(self):
        return self._media_type

    def ann_types(self):
        return self._ann_types

    def __len__(self) -> int:
        return len(self._items)

    def __iter__(self) -> Iterator[DatasetItem]:
        yield from self._items
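
# A minimal usage sketch (not part of the module); the path and key below are
# hypothetical placeholders:
#
#   base = DatumaroBinaryBase(
#       "path/to/train.datumaro",  # hypothetical dataset path
#       encryption_key=b"...",     # secret key created at export time
#       num_workers=4,             # decode item blobs with 4 worker processes
#   )
#   for item in base:
#       print(item.id, item.annotations)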