Source code for datumaro.components.registry
# Copyright (C) 2024 Intel Corporation
#
# SPDX-License-Identifier: MIT
from collections import defaultdict
from inspect import isclass
from typing import (
Dict,
Generator,
Generic,
Iterable,
Iterator,
Optional,
Tuple,
Type,
TypeVar,
Union,
)
from datumaro.components.cli_plugin import CliPlugin
from datumaro.components.dataset_base import DatasetBase, SubsetBase
from datumaro.components.exporter import Exporter
from datumaro.components.generator import DatasetGenerator
from datumaro.components.importer import Importer
from datumaro.components.launcher import Launcher
from datumaro.components.lazy_plugin import LazyPlugin
from datumaro.components.transformer import ItemTransform, Transform
from datumaro.components.validator import Validator
T = TypeVar("T")


class Registry(Generic[T]):
    """A simple mapping from string names to registered objects.

    Lookups made through ``registry[name]`` and :meth:`items` are routed
    through :meth:`get`, so a subclass that overrides :meth:`get` changes the
    result of all of them.
    """

    def __init__(self):
        # name -> registered object, in insertion order
        self._items: Dict[str, T] = {}

    def register(self, name: str, value: T) -> T:
        """Store ``value`` under ``name`` (replacing any previous entry) and return it."""
        self._items[name] = value
        return value

    def unregister(self, name: str) -> Optional[T]:
        """Remove ``name`` and return its value, or ``None`` if it was not registered."""
        removed = self._items.pop(name, None)
        return removed

    def get(self, key: str) -> T:
        """Return the object registered under ``key``.

        Raises:
            KeyError: if nothing is registered under ``key``.
        """
        return self._items[key]

    def __getitem__(self, key: str) -> T:
        # Delegate to get() so subclass overrides apply to subscript access too.
        return self.get(key)

    def __contains__(self, key) -> bool:
        return key in self._items

    def __iter__(self) -> Iterator[str]:
        # Iterating a registry yields the registered names.
        return iter(self._items)

    def items(self) -> Generator[Tuple[str, T], None, None]:
        """Lazily yield ``(name, value)`` pairs, resolving each value via :meth:`get`."""
        for name in self:
            value = self.get(name)
            yield name, value
class PluginRegistry(Registry[Type[CliPlugin]]):
    """Registry of plugin classes belonging to one plugin family.

    Subclasses configure which classes :meth:`batch_register` keeps through
    class attributes:

    * ``_ACCEPT`` -- required; only strict subclasses of it are kept.
    * ``_SKIP`` -- a class, or an iterable of classes, rejected by identity.
    * ``_DECLINE`` -- optional; this class and all its subclasses are rejected.

    Entries that are :class:`LazyPlugin` subclasses are resolved to the actual
    plugin class when they are fetched with :meth:`get`.
    """

    _ACCEPT: Optional[Type[CliPlugin]] = None
    _SKIP: Optional[Union[Type[CliPlugin], Iterable[Type[CliPlugin]]]] = None
    _DECLINE: Optional[Type[CliPlugin]] = None

    def __init__(self):
        super().__init__()
        if self._ACCEPT is None:
            raise NotImplementedError(
                f"{self.__class__.__name__} requires an _ACCEPT class attribute"
                " to specify the accepted type of stored instances."
            )

    def _filter(self, t) -> bool:
        """Return ``True`` if ``t`` may be stored by :meth:`batch_register`."""
        # issubclass() raises TypeError for non-classes; such candidates are
        # simply not plugins.
        if not isclass(t):
            return False

        # _SKIP may be a single class or an iterable; the accepted base class
        # itself is never a registrable plugin.
        skip = {self._SKIP} if isclass(self._SKIP) else set(self._SKIP or ())
        skip.add(self._ACCEPT)

        if (
            not issubclass(t, self._ACCEPT)
            or t in skip
            or (self._DECLINE and issubclass(t, self._DECLINE))
        ):
            return False

        # Classes can opt out explicitly with a truthy ``__not_plugin__``.
        if getattr(t, "__not_plugin__", None):
            return False

        return True

    def get(self, key: str) -> Type[CliPlugin]:
        """Returns a class or a factory function

        A registered :class:`LazyPlugin` subclass is replaced by the result of
        its ``get_plugin_cls()``.

        Raises:
            KeyError: if nothing is registered under ``key``.
        """
        item = self._items[key]
        # register() does not filter, so the entry may be a non-class (e.g. a
        # factory function); issubclass() would raise TypeError on it.
        if isclass(item) and issubclass(item, LazyPlugin):
            return item.get_plugin_cls()
        return item

    def batch_register(self, values: Iterable[Type[CliPlugin]]):
        """Register every acceptable class in ``values`` under its ``NAME``."""
        for v in values:
            if not self._filter(v):
                continue
            self.register(v.NAME, v)
class DatasetBaseRegistry(PluginRegistry):
    """Plugin registry for ``DatasetBase`` subclasses.

    When batch-registering, ``SubsetBase``, ``Transform`` and ``ItemTransform``
    themselves are skipped, and ``Transform`` together with all of its
    subclasses is declined.
    """

    # Required base class of every batch-registered plugin.
    _ACCEPT = DatasetBase
    # Classes rejected by identity.
    _SKIP = (
        SubsetBase,
        Transform,
        ItemTransform,
    )
    # Base class whose whole hierarchy is rejected.
    _DECLINE = Transform
class ImporterRegistry(PluginRegistry):
    """Plugin registry for ``Importer`` subclasses, indexed by file extension.

    ``extension_groups`` maps each file extension reported by a registered
    importer to the list of ``(name, importer)`` pairs registered for it. The
    index is kept in sync with the registry: replacing or unregistering a name
    removes the pairs previously recorded under that name.
    """

    _ACCEPT = Importer

    def __init__(self):
        super().__init__()
        # extension -> [(plugin name, plugin class), ...]
        self.extension_groups = defaultdict(list)

    def _forget_extensions(self, name: str) -> None:
        """Remove every ``extension_groups`` entry recorded under ``name``."""
        # Iterate over a snapshot of the keys because empty groups are deleted.
        for extension in list(self.extension_groups):
            entries = self.extension_groups[extension]
            # Mutate in place so outside references to the list stay current.
            entries[:] = [entry for entry in entries if entry[0] != name]
            if not entries:
                del self.extension_groups[extension]

    def register(
        self, name: str, value: Union[Type[Importer], Type[LazyPlugin]]
    ) -> Union[Type[Importer], Type[LazyPlugin]]:
        """Register ``value`` under ``name`` and index it by its file extensions.

        Lazy plugins provide their extensions through
        ``METADATA["file_extensions"]``; other importers through
        ``get_file_extensions()``.
        """
        super().register(name, value)

        # A re-registration replaces the stored plugin, so drop the pairs of
        # the previous one instead of leaving stale or duplicate entries.
        self._forget_extensions(name)

        if issubclass(value, LazyPlugin):
            file_extensions = value.METADATA["file_extensions"]
        else:
            file_extensions = value.get_file_extensions()

        for extension in file_extensions:
            self.extension_groups[extension].append((name, value))

        return value

    def unregister(self, name: str) -> Optional[Union[Type[Importer], Type[LazyPlugin]]]:
        """Remove ``name`` from the registry and from ``extension_groups``.

        Returns the removed plugin, or ``None`` if ``name`` was not registered.
        """
        removed = super().unregister(name)
        self._forget_extensions(name)
        return removed
class LauncherRegistry(PluginRegistry):
    """Plugin registry whose ``batch_register`` keeps only ``Launcher`` subclasses."""

    # Required base class of every batch-registered plugin.
    _ACCEPT = Launcher
class ExporterRegistry(PluginRegistry):
    """Plugin registry whose ``batch_register`` keeps only ``Exporter`` subclasses."""

    # Required base class of every batch-registered plugin.
    _ACCEPT = Exporter
class GeneratorRegistry(PluginRegistry):
    """Plugin registry whose ``batch_register`` keeps only ``DatasetGenerator`` subclasses."""

    # Required base class of every batch-registered plugin.
    _ACCEPT = DatasetGenerator
class ValidatorRegistry(PluginRegistry):
    """Plugin registry whose ``batch_register`` keeps only ``Validator`` subclasses."""

    # Required base class of every batch-registered plugin.
    _ACCEPT = Validator