# Source code for datumaro.plugins.data_formats.arrow.importer
# Copyright (C) 2023 Intel Corporation
#
# SPDX-License-Identifier: MIT
import os
from typing import Dict, List, Optional
import pyarrow as pa
from datumaro.components.errors import DatasetImportError
from datumaro.components.format_detection import FormatDetectionConfidence, FormatDetectionContext
from datumaro.components.importer import Importer
from .format import DatumaroArrow
# Only the importer class is part of this module's public API.
__all__ = ["ArrowImporter"]
class ArrowImporter(Importer):
    """Importer for datasets stored as Datumaro Arrow IPC files.

    The dataset path may be either a single ``.arrow`` file or a directory
    containing ``.arrow`` files; all valid files found are merged into one
    source (see ``find_sources_with_params``).
    """

    _FORMAT_EXT = ".arrow"

    @classmethod
    def detect(
        cls,
        context: FormatDetectionContext,
    ) -> Optional[FormatDetectionConfidence]:
        """Check whether ``context.root_path`` holds a Datumaro Arrow dataset.

        If the root path itself ends with the Arrow extension, that file is
        verified directly. Otherwise every ``*.arrow`` file required from the
        context is probed and verified.

        Returns:
            ``None``; no explicit confidence level is reported.

        Raises:
            DatasetImportError: if a root path ending in ``.arrow`` is not a
                valid Datumaro Arrow file (raised by
                ``_verify_datumaro_arrow_format``).
        """
        if context.root_path.endswith(cls._FORMAT_EXT):
            cls._verify_datumaro_arrow_format(context.root_path)
            return None

        for arrow_file in context.require_files(f"*{cls._FORMAT_EXT}"):
            with context.probe_text_file(
                arrow_file,
                f"{arrow_file} is not Datumaro arrow format.",
                is_binary_file=True,
            ) as f:
                # Release the probe's handle: the file is re-opened below via
                # pyarrow's memory map. Verification is kept inside the probe,
                # presumably so a failure is attributed to this requirement.
                # NOTE(review): confirm probe_text_file's error handling.
                f.close()
                cls._verify_datumaro_arrow_format(os.path.join(context.root_path, arrow_file))
        return None

    @classmethod
    def find_sources(cls, path: str) -> List[Dict]:
        """Find Datumaro Arrow files at ``path``.

        Files with the Arrow extension that fail verification are skipped.
        ``max_depth=0`` presumably restricts the search to the top level of
        ``path`` (assumption about ``_find_sources_recursive`` - verify).
        """

        def _filter(path: str) -> bool:
            # _verify_datumaro_arrow_format reports every failure (unreadable
            # file, non-Arrow content, signature/version/schema mismatch) as
            # DatasetImportError, so this one except clause covers them all.
            try:
                cls._verify_datumaro_arrow_format(path)
            except DatasetImportError:
                return False
            return True

        return cls._find_sources_recursive(
            path=path,
            ext=cls._FORMAT_EXT,
            extractor_name=cls.NAME,
            file_filter=_filter,
            max_depth=0,
        )

    @classmethod
    def find_sources_with_params(cls, path: str, **extra_params) -> List[Dict]:
        """Return one merged source covering every Arrow file found at ``path``.

        Args:
            path: file or directory to search.
            **extra_params: accepted for interface compatibility; unused.

        Returns:
            An empty list if no valid Arrow file was found. Otherwise a
            single-element list whose ``options["file_paths"]`` lists the URL
            of every discovered file.
        """
        sources = cls.find_sources(path)
        # Report "nothing found" explicitly instead of a source whose
        # file_paths list is empty.
        if not sources:
            return []
        # Merge sources into one config but multiple file_paths
        return [
            {
                "url": path,
                "format": cls.NAME,
                "options": {"file_paths": [source["url"] for source in sources]},
            }
        ]

    @staticmethod
    def _verify_datumaro_arrow_format(file: str) -> None:
        """Validate that ``file`` is an Arrow IPC file written by Datumaro.

        Raises:
            DatasetImportError: if the file cannot be opened or parsed as an
                Arrow IPC file. The ``DatumaroArrow.check_*`` calls are expected
                to raise the same error on a signature, version or schema
                mismatch (``find_sources`` relies on this).
        """
        try:
            with pa.memory_map(file, "r") as mm_file:
                with pa.ipc.open_file(mm_file) as reader:
                    schema = reader.schema
        except (pa.ArrowException, OSError) as err:
            raise DatasetImportError(f"{file} is not a readable Arrow IPC file: {err}") from err

        # A schema without key/value metadata exposes ``metadata`` as None.
        metadata = schema.metadata or {}

        def _read_field(key: bytes) -> str:
            # errors="replace": undecodable bytes then fail the check below
            # rather than escaping as UnicodeDecodeError.
            return metadata.get(key, b"").decode("utf-8", errors="replace")

        DatumaroArrow.check_signature(_read_field(b"signature"))
        DatumaroArrow.check_version(_read_field(b"version"))
        DatumaroArrow.check_schema(schema)

    @classmethod
    def get_file_extensions(cls) -> List[str]:
        """Return the file extensions handled by this importer."""
        return [cls._FORMAT_EXT]