# Copyright (C) 2023 Intel Corporation
#
# SPDX-License-Identifier: MIT
from typing import Any, Callable, Dict, Optional, Union
import numpy as np
import pyarrow as pa
from datumaro.components.errors import DatumaroError
from datumaro.components.media import Image, MediaElement, MediaType, PointCloud
from datumaro.plugins.data_formats.datumaro_binary.mapper.common import Mapper
from datumaro.util.image import decode_image, encode_image, load_image
[docs]
class ImageMapper(MediaElementMapper):
    """Convert :class:`Image` media to and from its ``"image"`` record.

    ``forward`` produces a dictionary with an ``"image"`` entry holding the
    encoded bytes (or a path when no bytes are stored) plus the image size;
    ``backward`` rebuilds an :class:`Image` from such a record.
    """

    MEDIA_TYPE = MediaType.IMAGE
    AVAILABLE_SCHEMES = ("AS-IS", "PNG", "TIFF", "JPEG/95", "JPEG/75", "NONE")

    @classmethod
    def encode(cls, obj: Image, scheme: Optional[str] = "PNG") -> Optional[bytes]:
        """Encode the pixels of ``obj`` according to ``scheme``.

        Args:
            obj: Image to encode.
            scheme: ``None`` or ``"NONE"`` to store nothing, ``"AS-IS"`` to
                reuse ``obj.bytes`` when present (falling back to PNG
                otherwise), ``"PNG"``, ``"TIFF"``, or ``"JPEG/<quality>"``.

        Returns:
            The encoded bytes, or ``None`` when nothing is to be stored or
            ``obj.data`` is ``None``.

        Raises:
            NotImplementedError: If ``scheme`` is not one of the handled forms.
            ValueError: If a ``JPEG`` scheme does not end in an integer quality.
        """
        if scheme is None or scheme == "NONE":
            return None

        if scheme == "AS-IS":
            _bytes = obj.bytes
            if _bytes is not None:
                return _bytes
            # Nothing to pass through unchanged, so re-encode as PNG instead.
            scheme = "PNG"

        options = {}
        if scheme.startswith("JPEG"):
            # The quality is the text after the last "/", e.g. "JPEG/95" -> 95.
            options["ext"] = "JPEG"
            options["jpeg_quality"] = int(scheme.split("/")[-1])
        elif scheme == "PNG":
            options["ext"] = "PNG"
        elif scheme == "TIFF":
            options["ext"] = "TIFF"
        else:
            raise NotImplementedError(f"Unsupported image encoding scheme: {scheme!r}")

        # Read obj.data once and reuse it instead of accessing the attribute twice.
        data = obj.data
        if data is None:
            return None
        return encode_image(data, **options)

    @classmethod
    def decode(
        cls, path: Optional[str] = None, data: Optional[bytes] = None
    ) -> Optional[np.ndarray]:
        """Load an image from in-memory bytes or from a file path.

        ``data`` takes precedence over ``path`` when both are given.

        Args:
            path: Path handed to ``load_image`` when ``data`` is ``None``.
            data: Encoded bytes handed to ``decode_image``.

        Returns:
            The result of ``decode_image`` / ``load_image`` (both called with
            ``np.uint8``), or ``None`` when neither argument is provided.
        """
        if data is not None:
            return decode_image(data, np.uint8)
        if path is not None:
            return load_image(path, np.uint8)
        return None

    @classmethod
    def forward(
        cls, obj: Image, encoder: Union[str, Callable[[Image], bytes]] = "PNG"
    ) -> Dict[str, Any]:
        """Build the serializable record for ``obj``.

        Args:
            obj: Image to serialize.
            encoder: Either a scheme name understood by :meth:`encode` or a
                callable that receives the image and returns its bytes.

        Returns:
            The dictionary from the parent ``forward`` with an added
            ``"image"`` entry (``has_bytes``, ``bytes``, ``path``, ``size``).
            ``path`` is only recorded when no bytes were produced.
        """
        out = super().forward(obj)

        _bytes = encoder(obj) if callable(encoder) else cls.encode(obj, scheme=encoder)
        # The path is a fallback reference; it is dropped once bytes are embedded.
        path = None if _bytes is not None else getattr(obj, "path", None)

        out["image"] = {
            "has_bytes": _bytes is not None,
            "bytes": _bytes,
            "path": path,
            "size": obj.size,
        }
        return out

    @classmethod
    def backward(
        cls,
        media_struct: pa.StructScalar,
        idx: int,
        table: pa.Table,
        table_path: str,
    ) -> Image:
        """Rebuild an :class:`Image` from its stored record.

        Args:
            media_struct: Struct scalar containing the ``"image"`` field.
            idx: Row index used to look the bytes up in the ``"media"`` column.
            table: Not used by this method.
            table_path: Path of the Arrow IPC file that is re-opened to read
                the image bytes.

        Returns:
            An image created from the stored path when one is present;
            otherwise an image whose ``data`` is a callable that opens
            ``table_path`` and reads the bytes when invoked.
        """
        image_struct = media_struct.get("image")

        if path := image_struct.get("path").as_py():
            return Image.from_file(
                path=path,
                size=image_struct.get("size").as_py(),
            )

        return Image.from_bytes(
            # Passed as a callable so the file is only opened when it is called.
            data=lambda: pa.ipc.open_file(pa.memory_map(table_path, "r"))
            .read_all()
            .column("media")[idx]
            .get("image")
            .get("bytes")
            .as_py(),
            size=image_struct.get("size").as_py(),
        )

    @classmethod
    def backward_extra_image(
        cls,
        image_struct: pa.StructScalar,
        idx: int,
        table: pa.Table,
        extra_image_idx: int,
    ) -> Image:
        """Rebuild one extra image stored alongside a point cloud.

        ``image_struct`` has the same fields as the ``"image"`` record written
        by :meth:`forward` (``path``, ``bytes``, ``size``).

        Args:
            image_struct: Struct scalar of the extra image itself.
            idx: Row index of the owning media element; unused here because
                ``image_struct`` already holds the data.
            table: Table holding the media column; unused for the same reason.
            extra_image_idx: Position of the image in the extra-image list;
                unused for the same reason.

        Returns:
            An image created from the stored path when one is present;
            otherwise an image whose ``data`` is a callable that reads the
            bytes from ``image_struct`` when invoked.
        """
        size = image_struct.get("size").as_py()

        if path := image_struct.get("path").as_py():
            return Image.from_file(path=path, size=size)

        return Image.from_bytes(
            # Deferred so the bytes are only copied out when the callable runs.
            data=lambda: image_struct.get("bytes").as_py(),
            size=size,
        )
# TODO: share binary for extra images
[docs]
class PointCloudMapper(MediaElementMapper):
    """Convert :class:`PointCloud` media to and from its ``"point_cloud"`` record."""

    MEDIA_TYPE = MediaType.POINT_CLOUD
    B64_PREFIX = "//B64_ENCODED//"

    @classmethod
    def forward(
        cls, obj: PointCloud, encoder: Union[str, Callable[[PointCloud], bytes]] = "PNG"
    ) -> Dict[str, Any]:
        """Build the serializable record for ``obj``.

        Args:
            obj: Point cloud to serialize.
            encoder: A callable that returns the bytes to store, or a string.
                The string ``"NONE"`` stores no bytes; any other string stores
                ``obj.data`` unchanged. The same value is forwarded to
                ``ImageMapper.forward`` for every extra image.

        Returns:
            The dictionary from the parent ``forward`` with an added
            ``"point_cloud"`` entry (``has_bytes``, ``bytes``, ``path``,
            ``extra_images``).
        """
        out = super().forward(obj)

        if callable(encoder):
            payload = encoder(obj)
        else:
            payload = None if encoder == "NONE" else obj.data

        # A path is only recorded when there are no bytes to embed.
        if payload is None:
            location = getattr(obj, "path", None)
        else:
            location = None

        extras = []
        for extra in obj.extra_images:
            extras.append(ImageMapper.forward(extra, encoder=encoder)["image"])

        out["point_cloud"] = {
            "has_bytes": payload is not None,
            "bytes": payload,
            "path": location,
            "extra_images": extras,
        }
        return out

    @classmethod
    def backward(
        cls,
        media_struct: pa.StructScalar,
        idx: int,
        table: pa.Table,
        table_path: str,
    ) -> PointCloud:
        """Rebuild a :class:`PointCloud` from its stored record.

        Args:
            media_struct: Struct scalar containing the ``"point_cloud"`` field.
            idx: Row index used to look the bytes up in the ``"media"`` column.
            table: Passed through to ``ImageMapper.backward_extra_image``.
            table_path: Path of the Arrow IPC file that is opened to read the
                point cloud bytes when no path is stored.

        Returns:
            A point cloud created from the stored path when one is present,
            otherwise one created from the bytes read from ``table_path``.
        """
        pc_struct = media_struct.get("point_cloud")

        extra_images = []
        for position, extra_struct in enumerate(pc_struct.get("extra_images")):
            extra_images.append(
                ImageMapper.backward_extra_image(extra_struct, idx, table, position)
            )

        stored_path = pc_struct.get("path").as_py()
        if stored_path:
            return PointCloud.from_file(path=stored_path, extra_images=extra_images)

        # Unlike the path branch, the bytes are read from the file right away.
        reader = pa.ipc.open_file(pa.memory_map(table_path, "r"))
        media_row = reader.read_all().column("media")[idx]
        raw = media_row.get("point_cloud").get("bytes").as_py()
        return PointCloud.from_bytes(data=raw, extra_images=extra_images)