Source code for datumaro.plugins.data_formats.datumaro_binary.exporter

# Copyright (C) 2024 Intel Corporation
#
# SPDX-License-Identifier: MIT

# pylint: disable=no-self-use

import argparse
import logging as log
import os.path as osp
import struct
import warnings
from io import BufferedWriter
from multiprocessing.pool import ApplyResult, Pool
from typing import Any, List, Optional, Union

from datumaro.components.crypter import NULL_CRYPTER, Crypter
from datumaro.components.dataset_base import DatasetItem, IDataset
from datumaro.components.errors import DatumaroError, PathSeparatorInSubsetNameError
from datumaro.components.exporter import ExportContext, ExportContextComponent, Exporter
from datumaro.plugins.data_formats.datumaro.exporter import DatumaroExporter
from datumaro.plugins.data_formats.datumaro.exporter import _SubsetWriter as __SubsetWriter
from datumaro.plugins.data_formats.datumaro.format import DATUMARO_FORMAT_VERSION

from .format import DatumaroBinaryPath
from .mapper import DictMapper
from .mapper.common import IntListMapper
from .mapper.dataset_item import DatasetItemMapper


class _SubsetWriter(__SubsetWriter):
    """"""

    def __init__(
        self,
        context: Exporter,
        subset: str,
        ann_file: str,
        export_context: ExportContextComponent,
        secret_key_file: str,
        no_media_encryption: bool = False,
        max_blob_size: int = DatumaroBinaryPath.MAX_BLOB_SIZE,
    ):
        super().__init__(context, subset, ann_file, export_context)
        self._crypter = self.export_context.crypter
        self.secret_key_file = secret_key_file

        self._fp: Optional[BufferedWriter] = None
        self._data["items"]: List[Union[bytes, ApplyResult]] = []
        self._bytes: List[Union[bytes, ApplyResult]] = self._data["items"]
        self._item_cnt = 0
        media_type = context._extractor.media_type()
        self._media_type = {"media_type": media_type._type}

        self._media_encryption = not no_media_encryption

        if max_blob_size != DatumaroBinaryPath.MAX_BLOB_SIZE:
            warnings.warn(
                f"You provide max_blob_size={max_blob_size}, "
                "but it is not recommended to provide an arbitrary max_blob_size."
            )

        self._max_blob_size = max_blob_size

    def _sign(self):
        self._fp.write(DatumaroBinaryPath.SIGNATURE.encode())

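    # Note: the encryption field below stores the secret key encrypted with itself
    # (or an empty message for a null crypter); presumably this lets an importer
    # verify that the key it was given matches the one used at export time.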
    def _dump_encryption_field(self) -> int:
        if self._crypter.key is None:
            msg = b""
        else:
            msg = self._crypter.key
            msg = self._crypter.encrypt(msg)

        return self._fp.write(struct.pack(f"I{len(msg)}s", len(msg), msg))

    def _dump_header(self, header: Any, use_crypter: bool = True):
        msg = DictMapper.forward(header)

        if use_crypter and self._crypter.key is not None:
            msg = self._crypter.encrypt(msg)

        length = struct.pack("I", len(msg))
        return self._fp.write(length + msg)

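    # The version header is written with use_crypter=False, so an importer can read
    # the format version and the media_encryption flag before applying any key.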
    def _dump_version(self):
        self._dump_header(
            {
                "dm_format_version": DATUMARO_FORMAT_VERSION,
                "media_encryption": self._media_encryption,
            },
            use_crypter=False,
        )

    def _dump_info(self):
        self._dump_header(self.infos)

    def _dump_categories(self):
        self._dump_header(self.categories)

    def _dump_media_type(self):
        self._dump_header(self._media_type)

    def add_item(self, item: DatasetItem, pool: Optional[Pool] = None, *args, **kwargs):
        if pool is not None:
            self._bytes.append(
                pool.apply_async(
                    self.add_item_impl,
                    (
                        item,
                        self.export_context,
                        self._media_encryption,
                    ),
                )
            )
        else:
            self._bytes.append(
                self.add_item_impl(item, self.export_context, self._media_encryption)
            )

        self._item_cnt += 1

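    # A staticmethod so it can be handed to pool.apply_async for multi-processing
    # export without having to pickle the writer instance itself.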
    @staticmethod
    def add_item_impl(
        item: DatasetItem, context: ExportContextComponent, media_encryption: bool
    ) -> bytes:
        with _SubsetWriter.context_save_media(item, context=context, encryption=media_encryption):
            return DatasetItemMapper.forward(item)

    def _dump_items(self, pool: Optional[Pool] = None):
        # Await async results
        if pool is not None:
            self._bytes = [
                result.get(timeout=DatumaroBinaryPath.MP_TIMEOUT)
                for result in self._bytes
                if isinstance(result, ApplyResult)
            ]

        # Divide items to blobs
        blobs = [bytearray()]
        cur_blob = blobs[-1]
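        # Append each serialized item to the current blob and start a new blob once
        # the current one exceeds max_blob_size, so a blob may overshoot the limit
        # by at most one item.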
        for _bytes in self._bytes:
            cur_blob += _bytes

            if len(cur_blob) > self._max_blob_size:
                blobs += [bytearray()]
                cur_blob = blobs[-1]

        # Encrypt blobs
        blobs = [self._crypter.encrypt(bytes(blob)) for blob in blobs if len(blob) > 0]

        # Dump blob sizes first
        blob_sizes = IntListMapper.forward([len(blob) for blob in blobs])
        blob_sizes = self._crypter.encrypt(blob_sizes)
        n_blob_sizes = len(blob_sizes)
        self._fp.write(struct.pack(f"<I{n_blob_sizes}s", n_blob_sizes, blob_sizes))

        # Dump blobs
        for blob in blobs:
            items_bytes = blob
            n_items_bytes = len(items_bytes)
            self._fp.write(struct.pack(f"<{n_items_bytes}s", items_bytes))

    def write(self, pool: Optional[Pool] = None, *args, **kwargs):
        try:
            if not self._crypter.is_null_crypter:
                log.info(
                    "Please see the generated encryption secret key file in the following path.\n"
                    f"{self.secret_key_file}\n"
                    "It must be kept it separate from the dataset to protect your dataset safely. "
                    "You also need it to import the encrpted dataset in later, so that be careful not to lose."
                )

                with open(self.secret_key_file, "w") as fp:
                    fp.write(self._crypter.key.decode())

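            # The sections below are written in a fixed order: signature, version
            # header, encryption field, infos, categories, media type, and finally
            # the item blobs.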
            with open(self.ann_file, "wb") as fp:
                self._fp = fp
                self._sign()
                self._dump_version()
                self._dump_encryption_field()
                self._dump_info()
                self._dump_categories()
                self._dump_media_type()
                self._dump_items(pool)
        finally:
            self._fp = None


class EncryptionAction(argparse.Action):
    def __init__(self, option_strings, dest, **kwargs):
        super().__init__(option_strings, dest, nargs=0, **kwargs)

    def __call__(self, parser, namespace, values, option_string=None):
        encryption = True if option_string in self.option_strings else False

        if encryption:
            key = Crypter.gen_key()
        else:
            key = None

        setattr(namespace, "encryption_key", key)
        delattr(namespace, self.dest)


class DatumaroBinaryExporter(DatumaroExporter):
    DEFAULT_IMAGE_EXT = DatumaroBinaryPath.IMAGE_EXT
    PATH_CLS = DatumaroBinaryPath

    @classmethod
    def build_cmdline_parser(cls, **kwargs):
        parser = super().build_cmdline_parser(**kwargs)
        parser.add_argument(
            "--encryption",
            action=EncryptionAction,
            default=False,
            help="Encrypt your dataset with the auto-generated secret key.",
        )
        parser.add_argument(
            "--no-media-encryption",
            action="store_true",
            help="Only encrypt the annotation file, not media files. "
            'This option is effective only if "--encryption" is enabled.',
        )
        parser.add_argument(
            "--num-workers",
            type=int,
            default=0,
            help="The number of multi-processing workers for export. "
            "If num_workers = 0, do not use multiprocessing (default: %(default)s).",
        )
        return parser

    def __init__(
        self,
        extractor: IDataset,
        save_dir: str,
        *,
        save_media: Optional[bool] = None,
        image_ext: Optional[str] = None,
        default_image_ext: Optional[str] = None,
        save_dataset_meta: bool = False,
        ctx: Optional[ExportContext] = None,
        encryption_key: Optional[bytes] = None,
        no_media_encryption: bool = False,
        encryption: bool = False,
        num_workers: int = 0,
        max_blob_size: int = DatumaroBinaryPath.MAX_BLOB_SIZE,
        **kwargs,
    ):
        """
        Parameters
        ----------
        encryption_key
            If provided, the dataset is encrypted with this key for export.
        no_media_encryption
            If true and encryption is enabled, do not encrypt media files and
            only encrypt annotation files.
        encryption
            If true and encryption_key is None, generate a random secret key.
        num_workers
            The number of multi-processing workers for export.
            If num_workers = 0, do not use multiprocessing.
        max_blob_size
            The maximum size of DatasetItem serialization blob.
            Changing from the default is not recommended.
        """
        if encryption and encryption_key is None:
            encryption_key = Crypter.gen_key()

        self._encryption_key = encryption_key

        if not save_media:
            no_media_encryption = True

        self._no_media_encryption = no_media_encryption

        if num_workers < 0:
            raise DatumaroError(
                f"num_workers should be non-negative but num_workers={num_workers}."
            )
        self._num_workers = num_workers

        self._max_blob_size = max_blob_size

        self._crypter = Crypter(encryption_key) if encryption_key is not None else NULL_CRYPTER

        super().__init__(
            extractor,
            save_dir,
            save_media=save_media,
            image_ext=image_ext,
            default_image_ext=default_image_ext,
            save_dataset_meta=save_dataset_meta,
            ctx=ctx,
        )

    def create_writer(
        self, subset: str, images_dir: str, pcd_dir: str, video_dir: str
    ) -> _SubsetWriter:
        export_context = ExportContextComponent(
            save_dir=self._save_dir,
            save_media=self._save_media,
            images_dir=images_dir,
            pcd_dir=pcd_dir,
            video_dir=video_dir,
            crypter=self._crypter,
            image_ext=self._image_ext,
            default_image_ext=self._default_image_ext,
        )

        if osp.sep in subset:
            raise PathSeparatorInSubsetNameError(subset)

        return _SubsetWriter(
            context=self,
            subset=subset,
            ann_file=osp.join(self._annotations_dir, subset + self.PATH_CLS.ANNOTATION_EXT),
            export_context=export_context,
            secret_key_file=osp.join(self._save_dir, self.PATH_CLS.SECRET_KEY_FILE),
            no_media_encryption=self._no_media_encryption,
            max_blob_size=self._max_blob_size,
        )

    def _apply_impl(self, *args, **kwargs):
        if self._num_workers == 0:
            return super()._apply_impl()

        with Pool(processes=self._num_workers) as pool:
            return super()._apply_impl(pool)
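

# A minimal usage sketch (not part of the original module), assuming a dataset has
# already been loaded with Datumaro's high-level API; the input path and source
# format below are hypothetical:
#
#   from datumaro.components.dataset import Dataset
#
#   dataset = Dataset.import_from("path/to/voc_dataset", "voc")
#   dataset.export(
#       "path/to/output",
#       "datumaro_binary",
#       save_media=True,
#       encryption=True,  # writes the auto-generated secret key file into the save dir
#       num_workers=2,  # export with a multiprocessing pool of 2 workers
#   )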