# Copyright (C) 2020-2023 Intel Corporation
#
# SPDX-License-Identifier: MIT
import logging as log
from typing import Dict, Iterable, Iterator, List, Optional, Set, Tuple, Type, Union
from datumaro.components.annotation import AnnotationType, LabelCategories
from datumaro.components.dataset_base import (
DEFAULT_SUBSET_NAME,
CategoriesInfo,
DatasetBase,
DatasetInfo,
DatasetItem,
IDataset,
)
from datumaro.components.dataset_item_storage import (
DatasetItemStorage,
DatasetItemStorageDatasetView,
ItemStatus,
)
from datumaro.components.errors import (
CategoriesRedefinedError,
ConflictingCategoriesError,
DatasetInfosRedefinedError,
MediaTypeError,
NotAvailableError,
RepeatedItemError,
)
from datumaro.components.importer import _ImportFail
from datumaro.components.media import MediaElement
from datumaro.components.transformer import ItemTransform, Transform
from datumaro.util import is_method_redefined
__all__ = ["DatasetPatch", "DatasetStorage"]
class DatasetPatch:
class DatasetPatchWrapper(DatasetItemStorageDatasetView):
        # The purpose of this class is to indicate that the input dataset is
        # a patch, and to autofill the patch info in the Exporter
def __init__(self, patch: "DatasetPatch", parent: IDataset):
super().__init__(
patch.data,
infos=parent.infos(),
categories=parent.categories(),
media_type=parent.media_type(),
ann_types=parent.ann_types(),
)
self.patch = patch
def subsets(self):
return {s: self.get_subset(s) for s in self.patch.updated_subsets}
def __init__(
self,
data: DatasetItemStorage,
infos: DatasetInfo,
categories: CategoriesInfo,
updated_items: Dict[Tuple[str, str], ItemStatus],
        updated_subsets: Optional[Dict[str, ItemStatus]] = None,
):
self.data = data
self.infos = infos
self.categories = categories
self.updated_items = updated_items
self._updated_subsets = updated_subsets
@property
def updated_subsets(self) -> Dict[str, ItemStatus]:
if self._updated_subsets is None:
self._updated_subsets = {s: ItemStatus.modified for s in self.data.subsets()}
return self._updated_subsets
def __contains__(self, x: Union[DatasetItem, Tuple[str, str]]) -> bool:
return x in self.data
def as_dataset(self, parent: IDataset) -> IDataset:
return __class__.DatasetPatchWrapper(self, parent)
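# A minimal illustrative sketch of how a patch can be wrapped for export (not
# executed here; `parent` is a hypothetical IDataset providing infos/categories):
#
#   storage = DatasetItemStorage()
#   storage.put(DatasetItem(id="1", subset="train"))
#   patch = DatasetPatch(
#       storage,
#       infos=parent.infos(),
#       categories=parent.categories(),
#       updated_items={("1", "train"): ItemStatus.modified},
#   )
#   patch_view = patch.as_dataset(parent)  # exposes only patch.updated_subsets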
class _StackedTransform(Transform):
def __init__(self, source: IDataset, transforms: List[Transform]):
super().__init__(source)
self.is_local = True
self.transforms: List[Transform] = []
self.malformed_transform_indices: Dict[int, Exception] = {}
for idx, transform in enumerate(transforms):
try:
source = transform[0](source, *transform[1], **transform[2])
except Exception as e:
self.malformed_transform_indices[idx] = e
self.transforms.append(source)
if self.is_local and not isinstance(source, ItemTransform):
self.is_local = False
def transform_item(self, item: DatasetItem) -> DatasetItem:
for t in self.transforms:
if item is None:
break
item = t.transform_item(item)
return item
def __iter__(self) -> Iterator[DatasetItem]:
yield from self.transforms[-1]
def infos(self) -> DatasetInfo:
return self.transforms[-1].infos()
def categories(self) -> CategoriesInfo:
return self.transforms[-1].categories()
def media_type(self) -> Type[MediaElement]:
return self.transforms[-1].media_type()
def ann_types(self) -> Set[AnnotationType]:
return self.transforms[-1].ann_types()
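# Each entry of the `transforms` list above is a 3-tuple of
# (transform class, positional args, keyword args), as consumed by
# `transform[0](source, *transform[1], **transform[2])`.
# A purely illustrative sketch, assuming `source` and the referenced transform
# classes exist (hypothetical names, not executed here):
#
#   stacked = _StackedTransform(
#       source,
#       [
#           (SomeItemTransform, (), {"param": 42}),
#           (SomeDatasetTransform, (), {}),
#       ],
#   )
#   for item in stacked:
#       ...  # items pass through the whole transform stack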
class DatasetStorage(IDataset):
def __init__(
self,
source: Union[IDataset, DatasetItemStorage],
infos: Optional[DatasetInfo] = None,
categories: Optional[CategoriesInfo] = None,
media_type: Optional[Type[MediaElement]] = None,
ann_types: Optional[Set[AnnotationType]] = None,
):
if source is None and categories is None:
categories = {}
elif isinstance(source, IDataset) and categories is not None:
raise ValueError("Can't use both source and categories")
self._categories = categories
if source is None and infos is None:
infos = {}
elif isinstance(source, IDataset) and infos is not None:
raise ValueError("Can't use both source and infos")
self._infos = infos
if media_type:
pass
elif isinstance(source, IDataset) and source.media_type():
media_type = source.media_type()
else:
raise ValueError("Media type must be provided for a dataset")
assert issubclass(media_type, MediaElement)
self._media_type = media_type
if ann_types:
pass
elif isinstance(source, IDataset) and source.ann_types():
ann_types = source.ann_types()
else:
ann_types = set()
self._ann_types = ann_types
# Possible combinations:
# 1. source + storage
# - Storage contains a patch to the Source data.
# 2. no source + storage
# - a dataset created from scratch
# - a dataset from a source or transform, which was cached
if isinstance(source, DatasetItemStorage):
self._source = None
self._storage = source
else:
self._source = source
self._storage = DatasetItemStorage() # patch or cache
self._transforms = [] # A stack of postponed transforms
# Describes changes in the dataset since initialization
self._updated_items = {} # (id, subset) -> ItemStatus
self._flush_changes = False # Deferred flush indicator
self._length = len(self._storage) if self._source is None else None
def is_cache_initialized(self) -> bool:
return self._source is None and not self._transforms
@property
def _is_unchanged_wrapper(self) -> bool:
return self._source is not None and self._storage.is_empty() and not self._transforms
def init_cache(self) -> None:
if not self.is_cache_initialized():
for _ in self._iter_init_cache():
pass
def _iter_init_cache(self) -> Iterable[DatasetItem]:
try:
            # Can't just return the inner generator from this method, because then
            # the exception handling below would not wrap the actual iteration.
            # This covers cases when a null error handler is saved in the source.
for item in self._iter_init_cache_unchecked():
yield item
except _ImportFail as e:
raise e.__cause__
def _iter_init_cache_unchecked(self) -> Iterable[DatasetItem]:
# Merges the source, source transforms and patch, caches the result
# and provides an iterator for the resulting item sequence.
#
# If iterated in parallel, the result is undefined.
# If storage is changed during iteration, the result is undefined.
#
# TODO: can potentially be optimized by sharing
# the cache between parallel consumers and introducing some kind of lock
#
# Cases:
# 1. Has source and patch
# 2. Has source, transforms and patch
# a. Transforms affect only an item (i.e. they are local)
# b. Transforms affect whole dataset
#
# The patch is always applied on top of the source / transforms stack.
def _update_status(item_id, new_status: ItemStatus):
current_status = self._updated_items.get(item_id)
if current_status is None:
self._updated_items[item_id] = new_status
elif new_status == ItemStatus.removed:
if current_status == ItemStatus.added:
self._updated_items.pop(item_id)
else:
self._updated_items[item_id] = ItemStatus.removed
elif new_status == ItemStatus.modified:
if current_status != ItemStatus.added:
self._updated_items[item_id] = ItemStatus.modified
elif new_status == ItemStatus.added:
if current_status != ItemStatus.added:
self._updated_items[item_id] = ItemStatus.modified
else:
assert False, "Unknown status %s" % new_status
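        # Status merging in `_update_status`, summarized (current + new -> result):
        #   None  + any       -> new status
        #   added + removed   -> entry dropped (the item never existed outside this run)
        #   other + removed   -> removed
        #   added + modified  -> stays added
        #   other + modified  -> modified
        #   added + added     -> stays added; any other current status -> modified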
def _add_ann_types(item: DatasetItem):
for ann in item.annotations:
if ann.type == AnnotationType.hash_key:
continue
self._ann_types.add(ann.type)
media_type = self._media_type
patch = self._storage # must be empty after transforming
cache = DatasetItemStorage()
source = self._source or DatasetItemStorageDatasetView(
self._storage,
infos=self._infos,
categories=self._categories,
media_type=media_type,
ann_types=self._ann_types,
)
transform = None
if self._transforms:
transform = _StackedTransform(source, self._transforms)
if transform.is_local:
# An optimized way to find modified items:
# Transform items inplace and analyze transform outputs
pass
else:
# A generic way to find modified items:
# Collect all the dataset original ids and compare
# with transform outputs.
# TODO: introduce DatasetBase.items() / .ids() to avoid extra
# dataset traversals?
old_ids = set((item.id, item.subset) for item in source)
source = transform
if not issubclass(transform.media_type(), media_type):
# TODO: make it statically available
                raise MediaTypeError(
                    "Transforms are not allowed to change the media type of dataset items"
                )
self._drop_malformed_transforms(transform.malformed_transform_indices)
i = -1
for i, item in enumerate(source):
if item.media and not isinstance(item.media, media_type):
raise MediaTypeError(
"Unexpected media type of a dataset item '%s'. "
"Expected '%s', actual '%s' " % (item.id, media_type, type(item.media))
)
if transform and transform.is_local:
old_id = (item.id, item.subset)
item = transform.transform_item(item)
item_id = (item.id, item.subset) if item else None
if item_id in cache:
raise RepeatedItemError(item_id)
if item in patch:
# Apply changes from the patch
item = patch.get(*item_id)
elif transform and not self._flush_changes:
# Find changes made by transforms, if not overridden by patch
if transform.is_local:
if not item:
_update_status(old_id, ItemStatus.removed)
elif old_id != item_id:
_update_status(old_id, ItemStatus.removed)
_update_status(item_id, ItemStatus.added)
else:
# Consider all items modified without comparison,
# because such comparison would be very expensive
_update_status(old_id, ItemStatus.modified)
else:
if item:
if item_id not in old_ids:
_update_status(item_id, ItemStatus.added)
else:
_update_status(item_id, ItemStatus.modified)
if not item:
continue
cache.put(item)
yield item
_add_ann_types(item)
if i == -1:
cache = patch
for item in patch:
if not self._flush_changes:
_update_status((item.id, item.subset), ItemStatus.added)
yield item
_add_ann_types(item)
else:
for item in patch:
if item in cache: # already processed
continue
if not self._flush_changes:
_update_status((item.id, item.subset), ItemStatus.added)
cache.put(item)
yield item
_add_ann_types(item)
if not self._flush_changes and transform and not transform.is_local:
# Mark removed items that were not produced by transforms
for old_id in old_ids:
if old_id not in self._updated_items:
self._updated_items[old_id] = ItemStatus.removed
self._storage = cache
self._length = len(cache)
if transform:
source_cat = transform.categories()
else:
source_cat = source.categories()
if source_cat is not None:
# Don't need to override categories if already defined
self._categories = source_cat
if transform:
source_infos = transform.infos()
else:
source_infos = source.infos()
if source_infos is not None:
self._infos = source_infos
self._source = None
self._transforms = []
if self._flush_changes:
self._flush_changes = False
self._updated_items = {}
def __iter__(self) -> Iterator[DatasetItem]:
if self._is_unchanged_wrapper:
yield from self._iter_init_cache()
else:
yield from self._merged()
def _merged(self) -> IDataset:
if self._is_unchanged_wrapper:
return self._source
elif self._source is not None:
self.init_cache()
return DatasetItemStorageDatasetView(
self._storage,
infos=self._infos,
categories=self._categories,
media_type=self._media_type,
ann_types=self._ann_types,
)
def __len__(self) -> int:
if self._length is None:
self.init_cache()
return self._length
def infos(self) -> DatasetInfo:
if self.is_cache_initialized():
return self._infos
elif self._infos is not None:
return self._infos
elif any(is_method_redefined("infos", Transform, t[0]) for t in self._transforms):
self.init_cache()
return self._infos
else:
return self._source.infos()
def define_infos(self, infos: DatasetInfo):
if self._infos or self._source is not None:
raise DatasetInfosRedefinedError()
self._infos = infos
def categories(self) -> CategoriesInfo:
if self.is_cache_initialized():
return self._categories
elif self._categories is not None:
return self._categories
elif any(is_method_redefined("categories", Transform, t[0]) for t in self._transforms):
self.init_cache()
return self._categories
else:
return self._source.categories()
def define_categories(self, categories: CategoriesInfo):
if self._categories or self._source is not None:
raise CategoriesRedefinedError()
self._categories = categories
def ann_types(self) -> Set[AnnotationType]:
return self._ann_types
def put(self, item: DatasetItem) -> None:
if item.media and not isinstance(item.media, self._media_type):
raise MediaTypeError(
"Mismatching item media type '%s', "
"the dataset contains '%s' items." % (type(item.media), self._media_type)
)
ann_types = set([ann.type for ann in item.annotations])
        # hash_key can be included in any task
ann_types.discard(AnnotationType.hash_key)
is_new = self._storage.put(item)
if not self.is_cache_initialized() or is_new:
self._updated_items[(item.id, item.subset)] = ItemStatus.added
else:
self._updated_items[(item.id, item.subset)] = ItemStatus.modified
if is_new and not self.is_cache_initialized():
self._length = None
self._ann_types = set()
if self._length is not None:
self._length += is_new
def get(self, id: str, subset: Optional[str] = None) -> Optional[DatasetItem]:
id = str(id)
subset = subset or DEFAULT_SUBSET_NAME
item = self._storage.get(id, subset)
if item is None and not self.is_cache_initialized():
if self._source.get.__func__ == DatasetBase.get:
# can be improved if IDataset is ABC
self.init_cache()
item = self._storage.get(id, subset)
else:
item = self._source.get(id, subset)
if item:
self._storage.put(item)
return item
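    # Note on `get()` above: the patch/cache is consulted first; for a non-cached
    # storage the item is then requested from the source and memoized in
    # `self._storage`, so repeated lookups do not hit the source again.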
def remove(self, id: str, subset: Optional[str] = None) -> None:
id = str(id)
subset = subset or DEFAULT_SUBSET_NAME
self._storage.remove(id, subset)
is_removed = self._updated_items.get((id, subset)) != ItemStatus.removed
if is_removed:
self._updated_items[(id, subset)] = ItemStatus.removed
if is_removed and not self.is_cache_initialized():
self._length = None
self._ann_types = set()
if self._length is not None:
self._length -= is_removed
def get_subset(self, name: str) -> IDataset:
return self._merged().get_subset(name)
def subsets(self) -> Dict[str, IDataset]:
# TODO: check if this can be optimized in case of transforms
# and other cases
return self._merged().subsets()
def get_annotated_items(self) -> int:
return self._storage.get_annotated_items()
def get_annotations(self) -> int:
return self._storage.get_annotations()
def get_datasetitem_by_path(self, path: str) -> Optional[DatasetItem]:
return self._storage.get_datasetitem_by_path(path)
def has_updated_items(self):
return bool(self._transforms) or bool(self._updated_items)
def get_patch(self) -> DatasetPatch:
# Patch includes only added or modified items.
# To find removed items, one needs to consult updated_items list.
if self._transforms:
self.init_cache()
        # The current patch (the storage)
        # - can be missing some of the removals that were done, so we add them manually
        # - can include items that are not actually part of the patch
        #   (e.g. an item could get there after the source was cached)
        # So we reconstruct the patch instead of copying the storage.
patch = DatasetItemStorage()
for (item_id, subset), status in self._updated_items.items():
if status is ItemStatus.removed:
patch.remove(item_id, subset)
else:
patch.put(self._storage.get(item_id, subset))
return DatasetPatch(
patch, infos=self._infos, categories=self._categories, updated_items=self._updated_items
)
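    # A minimal illustrative sketch of the patch workflow (not executed here;
    # `storage` stands for a DatasetStorage built from some source):
    #
    #   storage.put(DatasetItem(id="new", subset="train"))
    #   storage.remove("obsolete", "train")
    #   patch = storage.get_patch()
    #   # patch.updated_items == {("new", "train"): ItemStatus.added,
    #   #                         ("obsolete", "train"): ItemStatus.removed}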
def flush_changes(self):
self._updated_items = {}
if not (self.is_cache_initialized() or self._is_unchanged_wrapper):
self._flush_changes = True
def update(self, source: Union[DatasetPatch, IDataset, Iterable[DatasetItem]]):
# TODO: provide a more efficient implementation with patch reuse
if isinstance(source, DatasetPatch):
if source.categories() != self.categories():
raise ConflictingCategoriesError()
for item_id, status in source.updated_items.items():
if status == ItemStatus.removed:
self.remove(*item_id)
else:
self.put(source.data.get(*item_id))
elif isinstance(source, IDataset):
from datumaro.plugins.transforms import ProjectLabels
for item in ProjectLabels(
source, self.categories().get(AnnotationType.label, LabelCategories())
):
self.put(item)
else:
for item in source:
self.put(item)
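    # `update()` accepts a DatasetPatch, an IDataset, or a plain iterable of items.
    # A minimal illustrative sketch (hypothetical `other_storage`/`other_dataset`,
    # not executed here):
    #
    #   storage.update([DatasetItem(id="a", subset="train")])  # iterable of items
    #   storage.update(other_storage.get_patch())              # apply another patch
    #   storage.update(other_dataset)                          # labels are re-projected first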
def _drop_malformed_transforms(self, malformed_transform_indices: Dict[int, Exception]) -> None:
safe_transforms = []
for idx, transform in enumerate(self._transforms):
if idx in malformed_transform_indices:
                log.error(
                    f"Automatically dropping {transform} from the transform stack "
                    "because it raised an error; the dataset will not be transformed "
                    "by this transformation.",
                    exc_info=malformed_transform_indices[idx],
                )
continue
safe_transforms += [transform]
self._transforms = safe_transforms
def __getitem__(self, idx: int) -> DatasetItem:
try:
return self._storage[idx]
except IndexError: # Data storage should be initialized
self.init_cache()
return self._storage[idx]
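# A minimal illustrative sketch of DatasetStorage's lazy caching (hypothetical
# `source` implementing IDataset; not executed here):
#
#   storage = DatasetStorage(source)
#   storage.is_cache_initialized()  # False: the source has not been iterated yet
#   items = list(storage)           # merges source, transforms and patch, filling the cache
#   storage.is_cache_initialized()  # True: items now live in a DatasetItemStorage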
class StreamSubset(IDataset):
def __init__(self, source: IDataset, subset: str) -> None:
if not source.is_stream:
raise ValueError("source should be a stream.")
self._source = source
self._subset = subset
self._length = None
def __iter__(self) -> Iterator[DatasetItem]:
for item in self._source:
if item.subset == self._subset:
yield item
def __len__(self) -> int:
if self._length is None:
self._length = sum(1 for _ in self)
return self._length
def subsets(self) -> Dict[str, IDataset]:
raise NotAvailableError("Cannot get subsets of the subset.")
def get_subset(self, name) -> IDataset:
raise NotAvailableError("Cannot get a subset of the subset.")
def infos(self) -> DatasetInfo:
return self._source.infos()
def categories(self) -> CategoriesInfo:
return self._source.categories()
def get(self, id: str, subset: Optional[str] = None) -> Optional[DatasetItem]:
        raise NotAvailableError(
            "Random access to dataset items is not allowed in streaming. "
            "You can access dataset items only by iterating over the dataset."
        )
def ann_types(self) -> Set[AnnotationType]:
return self._source.ann_types()
@property
def is_stream(self) -> bool:
return True
class StreamDatasetStorage(DatasetStorage):
def __init__(
self,
source: IDataset,
infos: Optional[DatasetInfo] = None,
categories: Optional[CategoriesInfo] = None,
media_type: Optional[Type[MediaElement]] = None,
ann_types: Optional[Set[AnnotationType]] = None,
):
if not source.is_stream:
raise ValueError("source should be a stream.")
self._subset_names = list(source.subsets().keys())
self._transform_ids_for_latest_subset_names = []
super().__init__(source, infos, categories, media_type, ann_types)
def is_cache_initialized(self) -> bool:
log.debug("This function has no effect on streaming.")
return True
def init_cache(self) -> None:
log.debug("This function has no effect on streaming.")
pass
@property
def stacked_transform(self) -> IDataset:
if self._transforms:
transform = _StackedTransform(self._source, self._transforms)
self._drop_malformed_transforms(transform.malformed_transform_indices)
else:
transform = self._source
self._flush_changes = True
return transform
def __iter__(self) -> Iterator[DatasetItem]:
for item in self.stacked_transform:
yield item
for ann in item.annotations:
if ann.type == AnnotationType.hash_key:
continue
self._ann_types.add(ann.type)
def __len__(self) -> int:
if self._length is None:
self._length = len(self._source)
return self._length
def put(self, item: DatasetItem) -> None:
raise NotAvailableError("Drop-in replacement is not allowed in streaming.")
def get(self, id: str, subset: Optional[str] = None) -> Optional[DatasetItem]:
        raise NotAvailableError(
            "Random access to dataset items is not allowed in streaming. "
            "You can access dataset items only by iterating over the dataset."
        )
def remove(self, id: str, subset: Optional[str] = None) -> None:
raise NotAvailableError("Drop-in removal is not allowed in streaming.")
def get_subset(self, name: str) -> IDataset:
return self.subsets()[name]
@property
def subset_names(self):
if self._transform_ids_for_latest_subset_names != [id(t) for t in self._transforms]:
self._subset_names = {item.subset for item in self}
self._transform_ids_for_latest_subset_names = [id(t) for t in self._transforms]
return self._subset_names
def subsets(self) -> Dict[str, IDataset]:
return {subset: StreamSubset(self, subset) for subset in self.subset_names}
def get_annotated_items(self) -> int:
return super().get_annotated_items()
def get_annotations(self) -> int:
return super().get_annotations()
def get_datasetitem_by_path(self, path: str) -> Optional[DatasetItem]:
raise NotAvailableError("Get dataset item by path is not allowed in streaming.")
def get_patch(self):
raise NotAvailableError("Get patch is not allowed in streaming.")
def flush_changes(self):
raise NotAvailableError("Flush changes is not allowed in streaming.")
def update(self, source: Union[DatasetPatch, IDataset, Iterable[DatasetItem]]):
raise NotAvailableError("Update is not allowed in streaming.")
def infos(self) -> DatasetInfo:
return self.stacked_transform.infos()
def categories(self) -> CategoriesInfo:
return self.stacked_transform.categories()
@property
def is_stream(self) -> bool:
return True
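# A minimal illustrative sketch of the streaming storage (hypothetical `stream_source`
# with `is_stream == True`; not executed here):
#
#   storage = StreamDatasetStorage(stream_source)
#   for item in storage:                   # items are produced lazily, nothing is cached
#       ...
#   train = storage.get_subset("train")    # a StreamSubset view over the same stream
#   storage.get("some_id", "train")        # raises NotAvailableError: no random access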