# Source code for otx.core.ov.ops.builder
"""OPS (OperationRegistry) module for otx.core.ov.ops.builder."""
# Copyright (C) 2023 Intel Corporation
#
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Optional
from otx.core.ov.registry import Registry
class OperationRegistry(Registry):
    """OperationRegistry class.

    Extends ``Registry`` (name-based lookup) with a secondary index keyed by
    each registered class's ``TYPE`` and ``VERSION`` attributes, stored as a
    nested ``{TYPE: {VERSION: cls}}`` mapping.
    """

    def __init__(self, name, add_name_as_attr=False):
        super().__init__(name, add_name_as_attr)
        # Secondary index: {TYPE: {VERSION: registered class}}.
        self._registry_dict_by_type = {}

    def register(self, name: Optional[Any] = None):
        """Return a class decorator that registers the decorated class.

        Args:
            name: Key to register the class under in the name index. When
                ``None``, the class's ``__name__`` is used.

        Returns:
            A decorator that registers its argument and returns it unchanged.

        Raises:
            AttributeError: If the decorated class lacks ``TYPE`` or ``VERSION``.
            ValueError: If ``TYPE`` or ``VERSION`` is falsy (e.g. empty).
            KeyError: If ``(TYPE, VERSION)`` is already registered.
        """

        def wrap(obj):
            layer_name = obj.__name__ if name is None else name
            layer_type = obj.TYPE
            layer_version = obj.VERSION
            # Explicit check rather than ``assert`` so validation still runs
            # when Python is started with ``-O``.
            if not (layer_type and layer_version):
                raise ValueError(
                    f"{obj!r} must define non-empty TYPE and VERSION "
                    f"(got TYPE={layer_type!r}, VERSION={layer_version!r})"
                )
            if self._add_name_as_attr:
                # Record the registered name on the class itself.
                setattr(obj, self.REGISTERED_NAME_ATTR, layer_name)
            self._register(obj, layer_name, layer_type, layer_version)
            return obj

        return wrap

    def _register(self, obj, name, types, version):
        """Register ``obj`` under ``name`` and under ``(types, version)``.

        The ``(types, version)`` duplicate check runs before anything is
        stored, so a conflicting registration leaves the name index and the
        type/version index both untouched.

        Raises:
            KeyError: If ``version`` is already registered for ``types``.
        """
        if version in self._registry_dict_by_type.get(types, {}):
            raise KeyError(f"{version} is already registered in {types}")
        # Name-based registration handled by the base ``Registry``.
        super()._register(obj, name)
        # Only reached if the base registration succeeded.
        self._registry_dict_by_type.setdefault(types, {})[version] = obj

    def get_by_name(self, name):
        """Get obj from name (delegates to the inherited ``get``)."""
        return self.get(name)

    def get_by_type_version(self, types, version):
        """Get the class registered for ``types`` and ``version``.

        Raises:
            KeyError: If ``types`` is unknown, or ``version`` is not
                registered for ``types``.
        """
        if types not in self._registry_dict_by_type:
            raise KeyError(f"type {types} is not registered in {self._name}")
        if version not in self._registry_dict_by_type[types]:
            raise KeyError(f"version {version} is not registered in {types} of {self._name}")
        return self._registry_dict_by_type[types][version]
# Module-wide registry that operation classes register into via ``@OPS.register()``.
OPS = OperationRegistry(name="ov ops", add_name_as_attr=False)