Source code for otx.core.ov.registry
"""Registry Class for otx.core.ov."""
# Copyright (C) 2023 Intel Corporation
#
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, Optional
[docs]
class Registry:
    """Registry Class for OMZ model.

    Keeps a mapping from a name (any hashable key) to a registered object and
    offers a decorator-based API for filling it::

        OPS = Registry("ops")

        @OPS.register()          # registered under "Foo"
        class Foo: ...

        OPS.get("Foo") is Foo    # True
        Foo in OPS               # True -- membership is tested on *objects*
    """

    # Attribute written onto registered objects when ``add_name_as_attr`` is set.
    REGISTERED_NAME_ATTR = "_registered_name"

    def __init__(self, name: str, add_name_as_attr: bool = False):
        """Create an empty registry.

        Args:
            name: Label of this registry; it is only interpolated into the
                ``KeyError`` messages raised by this class.
            add_name_as_attr: When true, ``register`` stores the registered
                name on each object under ``REGISTERED_NAME_ATTR``.
        """
        self._name = name
        self._registry_dict: Dict[Any, Any] = {}
        self._add_name_as_attr = add_name_as_attr

    @property
    def registry_dict(self) -> Dict[Any, Any]:
        """Dictionary of registered module.

        This is the live internal mapping, not a copy.
        """
        return self._registry_dict

    def _register(self, obj: Any, name: Any):
        """Register obj with name.

        Raises:
            KeyError: If ``name`` is already present in the registry.
        """
        if name in self._registry_dict:
            raise KeyError(f"{name} is already registered in {self._name}")
        self._registry_dict[name] = obj

    def register(self, name: Optional[Any] = None):
        """Register from name.

        Args:
            name: Key to register the decorated object under. When ``None``,
                the object's ``__name__`` is used instead.

        Returns:
            A decorator that registers the object it receives and returns that
            same object unchanged (apart from the optional name attribute).

        Raises:
            KeyError: Raised by the returned decorator when the resolved name
                is already registered.
        """

        def wrap(obj):
            cls_name = obj.__name__ if name is None else name
            # Register before tagging: the duplicate-name check lives in
            # ``_register``, and a rejected object must not be left carrying a
            # name that resolves to a different object in this registry.
            self._register(obj, cls_name)
            if self._add_name_as_attr:
                setattr(obj, self.REGISTERED_NAME_ATTR, cls_name)
            return obj

        return wrap

    def get(self, key: Any) -> Any:
        """Get from module name (key).

        Raises:
            KeyError: If ``key`` has not been registered.
        """
        if key not in self._registry_dict:
            # Delegated to a separate method so the failure path lives in one place.
            self._key_not_found(key)
        return self._registry_dict[key]

    def _key_not_found(self, key: Any):
        """Raise KeyError when key is not found."""
        raise KeyError(f"{key} is not found in {self._name}")

    def __contains__(self, item):
        """Check containing of item.

        Note that this tests the registered *objects* (the dictionary values),
        not the names they were registered under.
        """
        return item in self._registry_dict.values()