Source code for datumaro.plugins.data_formats.arrow.base

# Copyright (C) 2023 Intel Corporation
#
# SPDX-License-Identifier: MIT

import struct
from contextlib import ExitStack
from dataclasses import dataclass
from typing import Any, Dict, Iterator, List, Optional, Set, Type

import pyarrow as pa

from datumaro.components.annotation import AnnotationType, Categories
from datumaro.components.dataset_base import (
    CategoriesInfo,
    DatasetBase,
    DatasetInfo,
    DatasetItem,
    IDataset,
    SubsetBase,
)
from datumaro.components.importer import ImportContext
from datumaro.components.media import Image, MediaElement, MediaType
from datumaro.components.merge.extractor_merger import check_identicalness
from datumaro.plugins.data_formats.arrow.format import DatumaroArrow
from datumaro.plugins.data_formats.datumaro.base import JsonReader
from datumaro.plugins.data_formats.datumaro_binary.mapper.common import DictMapper
from datumaro.util.definitions import DEFAULT_SUBSET_NAME

from .mapper.dataset_item import DatasetItemMapper


class ArrowSubsetBase(SubsetBase):
    """A single-subset view over items that have already been decoded.

    The items live in the in-memory ``lookup`` dictionary (item id -> item);
    this class only exposes them through the ``SubsetBase`` interface.
    """

    __not_plugin__ = True

    def __init__(
        self,
        lookup: Dict[str, DatasetItem],
        infos: Dict[str, Any],
        categories: Dict[AnnotationType, Categories],
        subset: str,
        media_type: Type[MediaElement] = Image,
        ann_types: Optional[Set[AnnotationType]] = None,
    ):
        """
        Args:
            lookup: Mapping from item id to the item of this subset.
            infos: Info dictionary stored on this view.
            categories: Categories stored on this view.
            subset: Name of the subset this view represents.
            media_type: Media type of the items, forwarded to ``SubsetBase``.
            ann_types: Annotation types to report, forwarded to ``SubsetBase``.
        """
        super().__init__(
            length=len(lookup),
            subset=subset,
            media_type=media_type,
            ann_types=ann_types,
            ctx=None,
        )
        self._lookup = lookup
        self._infos = infos
        self._categories = categories

    def __iter__(self) -> Iterator[DatasetItem]:
        yield from self._lookup.values()

    def __len__(self) -> int:
        return len(self._lookup)

    def get(self, item_id: str, subset: Optional[str] = None) -> Optional[DatasetItem]:
        """Return the item with ``item_id`` from this subset, or ``None``.

        ``subset=None`` is treated as ``DEFAULT_SUBSET_NAME``, the same way
        ``ArrowBase.get`` treats it. Asking for any subset other than this
        view's subset returns ``None``.
        """
        if (subset or DEFAULT_SUBSET_NAME) != self._subset:
            return None
        return self._lookup.get(item_id)
@dataclass(frozen=True)
class Metadata:
    """Per-file metadata decoded from an Arrow table's schema metadata.

    Instances are built by ``ArrowBase._load_schema_metadata``.
    """

    # Decoded from the ``b"infos"`` schema-metadata entry.
    infos: DatasetInfo
    # Decoded from the ``b"categories"`` schema-metadata entry.
    categories: CategoriesInfo
    # Media class resolved from the ``b"media_type"`` schema-metadata entry.
    media_type: Type[MediaElement]
class ArrowBase(DatasetBase):
    """Dataset base that reads one or more Datumaro Arrow IPC files.

    The schema metadata (infos, categories, media type) of every file is
    reconciled with ``check_identicalness``. Every row is decoded into a
    ``DatasetItem`` at construction time and cached per subset.
    """

    def __init__(
        self,
        root_path: str,
        *,
        file_paths: List[str],
        ctx: Optional[ImportContext] = None,
    ):
        """
        Args:
            root_path: Dataset root path; stored as ``self._root_path``.
            file_paths: Arrow IPC files to read.
            ctx: Optional import context, forwarded to ``DatasetBase``.
        """
        self._root_path = root_path

        # The memory-mapped tables are only needed to gather metadata, the row
        # count and the subset names, so all of that is done while the maps are
        # open. The maps are then closed instead of being leaked.
        with ExitStack() as stack:
            tables = [
                pa.ipc.open_file(stack.enter_context(pa.memory_map(path, "r"))).read_all()
                for path in file_paths
            ]
            metadatas = [self._load_schema_metadata(table) for table in tables]
            table = pa.concat_tables(tables)
            length = len(table)
            subsets = table.column(DatumaroArrow.SUBSET_FIELD).unique().to_pylist()

        media_type = check_identicalness([metadata.media_type for metadata in metadatas])

        super().__init__(length=length, subsets=subsets, media_type=media_type, ctx=ctx)

        self._infos = check_identicalness([metadata.infos for metadata in metadatas])
        self._categories = check_identicalness(
            [metadata.categories for metadata in metadatas]
        )

        self._init_cache(file_paths, subsets)

    @staticmethod
    def _load_schema_metadata(table: pa.Table) -> Metadata:
        """Decode infos, categories and media type from ``table``'s schema metadata.

        A missing entry falls back to four zero bytes. For ``media_type`` this
        unpacks to code 0. For infos/categories it presumably decodes to an
        empty dictionary via ``DictMapper`` (TODO confirm).
        """
        # pyarrow reports a schema without any metadata as ``None``.
        schema_metadata = table.schema.metadata or {}
        empty = b"\x00\x00\x00\x00"

        _infos, _ = DictMapper.backward(schema_metadata.get(b"infos", empty))
        infos = JsonReader._load_infos({"infos": _infos})

        _categories, _ = DictMapper.backward(schema_metadata.get(b"categories", empty))
        categories = JsonReader._load_categories({"categories": _categories})

        (media_type,) = struct.unpack("<I", schema_metadata.get(b"media_type", empty))
        media_type = MediaType(media_type).media

        return Metadata(infos=infos, categories=categories, media_type=media_type)

    def infos(self) -> DatasetInfo:
        return self._infos

    def categories(self) -> CategoriesInfo:
        return self._categories

    def __iter__(self) -> Iterator[DatasetItem]:
        for lookup in self._lookup.values():
            yield from lookup.values()

    def _init_cache(self, file_paths: List[str], subsets: List[str]):
        """Decode every row of every file and build the per-subset views.

        Populates ``self._lookup`` (subset -> item id -> item),
        ``self._ann_types`` (annotation types seen on any item) and
        ``self._subsets`` (subset -> ``ArrowSubsetBase``).
        """
        self._lookup: Dict[str, Dict[str, DatasetItem]] = {subset: {} for subset in subsets}

        pbar = self._ctx.progress_reporter
        pbar.start(total=len(self), desc="Importing")

        ann_types = set()
        cnt = 0
        for table_path in file_paths:
            with pa.OSFile(table_path, "r") as source:
                with pa.ipc.open_file(source) as reader:
                    table = reader.read_all()
                    # The mapper API addresses rows by index within the table.
                    for idx in range(len(table)):
                        item = DatasetItemMapper.backward(idx, table, table_path)
                        self._lookup[item.subset][item.id] = item
                        ann_types.update(ann.type for ann in item.annotations)
                        pbar.report_status(cnt)
                        cnt += 1

        self._ann_types = ann_types
        self._subsets = {
            subset: ArrowSubsetBase(
                lookup=lookup,
                infos=self._infos,
                categories=self._categories,
                # Each view carries its own subset name, not the whole collection.
                subset=subset,
                media_type=self._media_type,
                ann_types=self._ann_types,
            )
            for subset, lookup in self._lookup.items()
        }
        pbar.finish()

    def get(self, item_id: str, subset: Optional[str] = None) -> Optional[DatasetItem]:
        """Return the cached item, or ``None`` if absent; ``None`` subset means default."""
        subset = subset or DEFAULT_SUBSET_NAME
        try:
            return self._lookup[subset][item_id]
        except KeyError:
            return None

    @property
    def lookup(self) -> Dict[str, Dict[str, DatasetItem]]:
        """Cached items, keyed by subset name and then by item id."""
        return self._lookup

    def subsets(self) -> Dict[str, IDataset]:
        return self._subsets

    def get_subset(self, name: str) -> IDataset:
        return self._subsets[name]