Source code for otx.core.ov.ops.op

"""Operation-related modules for otx.core.ov.ops."""
# Copyright (C) 2023 Intel Corporation
#
# SPDX-License-Identifier: Apache-2.0

import re
from dataclasses import dataclass, fields
from typing import Generic, Optional, Tuple, Type, TypeVar, Union

import torch

from ..utils import get_op_name  # type: ignore[attr-defined]
from .utils import get_dynamic_shape


@dataclass
class Attribute:
    """Base container for the attributes attached to an operation.

    The only field declared here is ``shape``. ``None`` is accepted as-is;
    any other value has to be a ``tuple`` instance, otherwise construction
    fails with ``ValueError``. The elements of the tuple are not inspected.
    """

    shape: Optional[Union[Tuple[Tuple[int]], Tuple[int]]]

    def __post_init__(self):
        """Check ``shape`` right after the generated ``__init__`` has run.

        Raises:
            ValueError: if ``shape`` is neither ``None`` nor a ``tuple``.
        """
        shape = self.shape
        # Guard clause: both accepted forms leave early, everything else is rejected.
        if shape is None or isinstance(shape, tuple):
            return
        raise ValueError("shape must be a tuple of ints or a tuple of tuples of ints.")
# Type variable used as the Generic parameter of Operation; it is bound to
# Attribute, so it can only be specialised with Attribute or a subclass of it.
_T = TypeVar("_T", bound=Attribute)
class Operation(torch.nn.Module, Generic[_T]):  # pylint: disable=abstract-method, invalid-overridden-method
    """Base ``torch.nn.Module`` for an operation with a name and attributes.

    Class attributes:
        TYPE: type identifier; empty on this base class.
        VERSION: version identifier; empty on this base class.
        ATTRIBUTE_FACTORY: class called with the constructor's keyword
            arguments to build the attribute object stored on the instance.
    """

    TYPE = ""
    VERSION = ""
    ATTRIBUTE_FACTORY: Type[Attribute] = Attribute

    def __init__(self, name: str, **kwargs):
        """Store ``name`` and build the attribute object from ``kwargs``.

        Args:
            name: name of this operation.
            **kwargs: forwarded unchanged to ``ATTRIBUTE_FACTORY``.
        """
        super().__init__()
        self._name = name
        self._attrs = self.ATTRIBUTE_FACTORY(**kwargs)

    @classmethod
    def from_ov(cls, ov_op):
        """Create an instance of ``cls`` from ``ov_op``.

        Args:
            ov_op: object providing ``get_type_name()``, ``get_type_info()``,
                ``get_attributes()`` and ``outputs()``.

        Returns:
            ``cls`` constructed with the name from ``get_op_name(ov_op)`` and
            the attributes reported by ``ov_op``.

        Raises:
            AssertionError: if ``cls.TYPE`` or ``cls.VERSION`` is empty, or
                if either one differs from what ``ov_op`` reports.
        """
        op_type = ov_op.get_type_name()
        op_version = ov_op.get_type_info().version_id
        op_name = get_op_name(ov_op)

        assert cls.TYPE and cls.VERSION
        assert op_type == cls.TYPE
        assert op_version == cls.VERSION

        attrs = ov_op.get_attributes()
        if "shape" not in attrs:
            # No explicit shape attribute: fall back to one tuple per output.
            per_output = [get_dynamic_shape(output) for output in ov_op.outputs()]
            attrs["shape"] = tuple(map(tuple, per_output))

        return cls(name=op_name, **attrs)

    @property
    def type(self) -> str:  # pylint: disable=invalid-overridden-method
        """Return the class-level ``TYPE``."""
        return self.TYPE

    @property
    def version(self) -> str:
        """Return the class-level ``VERSION``."""
        return self.VERSION

    @property
    def name(self) -> str:
        """Return the name given at construction time."""
        return self._name

    @property
    def attrs(self):
        """Return the attribute object built by ``ATTRIBUTE_FACTORY``."""
        return self._attrs

    @property
    def shape(self) -> Optional[Union[Tuple[Tuple[int]], Tuple[int]]]:
        """Return the ``shape`` field of the attribute object."""
        return self.attrs.shape

    def __repr__(self):
        """Return ``ClassName(name=..., key=value, ...)``, leaving out ``shape``."""
        attrs = self.attrs
        parts = [f"name={self.name}"]
        parts.extend(
            f"{field.name}={getattr(attrs, field.name)}" for field in fields(attrs) if field.name != "shape"
        )
        return f"{self.__class__.__name__}({', '.join(parts)})"