Source code for otx.core.ov.ops.infrastructures

"""Infrastructure-related modules for otx.core.ov.ops."""
# Copyright (C) 2023 Intel Corporation
#
# SPDX-License-Identifier: Apache-2.0

from collections import OrderedDict
from dataclasses import dataclass, field
from typing import Optional, Tuple

import numpy as np
import torch

from otx.utils.logger import get_logger

from ..utils import get_op_name  # type: ignore[attr-defined]
from .builder import OPS
from .op import Attribute, Operation
from .type_conversions import ConvertV0
from .utils import get_dynamic_shape

NODE_TYPES_WITH_WEIGHT = {
    "Convolution",
    "GroupConvolution",
    "MatMul",
    "BatchNormInference",
    "Multiply",
    "Divide",
    "Add",
    "Subtract",
}
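# NOTE: ConstantV0.from_ov below uses this set to decide whether a constant is a
# weight: a constant whose single consumer is one of these node types, connected
# through a non-first input port, and whose data is a multi-element
# floating-point (or complex) tensor is registered as a trainable
# torch.nn.Parameter instead of a plain buffer.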

logger = get_logger()


@dataclass
class ParameterV0Attribute(Attribute):
    """ParameterV0Attribute class."""

    element_type: Optional[str] = field(default=None)

    layout: Optional[Tuple[str, ...]] = field(default=None)
    permute: Optional[Tuple[int, ...]] = field(default=None)
    verify_shape: bool = field(default=True)

    def __post_init__(self):
        """ParameterV0Attribute's post-init function."""
        super().__post_init__()
        # fmt: off
        valid_element_type = [
            None,
            "u1", "u4", "u8", "u16", "u32", "u64",
            "i4", "i8", "i16", "i32", "i64", "f16", "f32", "boolean", "bf16"
        ]
        # fmt: on
        if self.element_type not in valid_element_type:
            raise ValueError(f"Invalid element_type {self.element_type}. " f"It must be one of {valid_element_type}.")
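
# For illustration only (hypothetical values, not part of the original module):
# the validation above accepts OpenVINO IR element type names and rejects
# anything else, e.g.
#
#   ParameterV0Attribute(shape=((1, 3, 224, 224),), element_type="f32")  # ok
#   ParameterV0Attribute(shape=((1, 3, 224, 224),), element_type="f64")  # ValueError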


@OPS.register()
class ParameterV0(Operation[ParameterV0Attribute]):
    """ParameterV0 class."""

    TYPE = "Parameter"
    VERSION = "opset1"
    ATTRIBUTE_FACTORY = ParameterV0Attribute

    def forward(self, inputs):
        """ParameterV0's forward function."""
        # TODO: validate shape
        # need to handle new generated op from reshaped model
        if self.attrs.verify_shape:
            assert self.shape is not None
            ov_shape = self.shape[0]
            torch_shape = list(inputs.shape)
            for ov_shape_, torch_shape_ in zip(ov_shape, torch_shape):
                if ov_shape_ == -1:
                    continue
                assert ov_shape_ == torch_shape_, f"input shape {torch_shape} does not match with ov shape {ov_shape}"

        if self.attrs.permute:
            inputs = inputs.permute(self.attrs.permute)

        return inputs

    @classmethod
    def from_ov(cls, ov_op):
        """ParameterV0's from_ov function."""
        op_type = ov_op.get_type_name()
        op_name = get_op_name(ov_op)
        op_version = ov_op.get_type_info().version_id
        assert cls.TYPE and cls.VERSION
        assert op_type == cls.TYPE
        assert op_version == cls.VERSION

        attrs = ov_op.get_attributes()
        if "shape" not in attrs:
            shapes = []
            for output in ov_op.outputs():
                shapes.append(get_dynamic_shape(output))
            shapes = tuple(tuple(shape) for shape in shapes)
            attrs["shape"] = shapes

        layout = ov_op.get_layout()
        if not layout.empty:
            layout = layout.to_string()[1:-1].split(",")
            attrs["layout"] = tuple(layout)

            # N, C, H, W
            input_layout = OrderedDict(
                {
                    "N": 0,
                    "C": 1,
                    "H": 2,
                    "W": 3,
                }
            )
            if not set(layout).symmetric_difference(input_layout.keys()):
                permute = []
                for layout_ in layout:
                    # N, H, W, C
                    permute.append(input_layout[layout_])
                attrs["permute"] = tuple(permute)

                # TODO: here, we force the batch dim to be dynamic
                # but this should be done when loading ov model
                i = layout.index("N")
                new_shape = []
                for shape in attrs["shape"]:
                    new_shape.append([-1 if j == i else k for j, k in enumerate(shape)])
                new_shape = [tuple(shape) for shape in new_shape]
                attrs["shape"] = tuple(new_shape)

                # change shape and layout based on permute
                if "permute" in attrs and attrs["permute"] != (0, 1, 2, 3):
                    assert len(attrs["shape"]) == 1
                    permute = []
                    for layout_ in input_layout.keys():
                        permute.append(layout.index(layout_))

                    new_shape = []
                    for shape in attrs["shape"]:
                        new_shape.append([shape[i] for i in permute])
                    attrs["shape"] = tuple(tuple(shape) for shape in new_shape)
                    attrs["layout"] = tuple(attrs["layout"][i] for i in permute)

        return cls(name=op_name, **attrs)
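

# A minimal usage sketch, not part of the original module: it assumes the
# Operation base class accepts attribute values as keyword arguments, as the
# `cls(name=op_name, **attrs)` call in from_ov above does. The shape and
# permute values are illustrative assumptions.
def _example_parameter_usage():
    # An NHWC input with a dynamic batch dimension, permuted to NCHW.
    param = ParameterV0(
        name="input",
        shape=((-1, 224, 224, 3),),
        permute=(0, 3, 1, 2),
        verify_shape=True,
    )
    nhwc = torch.randn(8, 224, 224, 3)
    return param(nhwc)  # shape check passes; result has shape (8, 3, 224, 224)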


@dataclass
class ResultV0Attribute(Attribute):
    """ResultV0Attribute class."""

    pass  # pylint: disable=unnecessary-pass


@OPS.register()
class ResultV0(Operation[ResultV0Attribute]):
    """ResultV0 class."""

    TYPE = "Result"
    VERSION = "opset1"
    ATTRIBUTE_FACTORY = ResultV0Attribute

    def forward(self, inputs):
        """ResultV0's forward function."""
        return inputs


@dataclass
class ConstantV0Attribute(Attribute):
    """ConstantV0Attribute class."""

    element_type: str

    offset: int = field(default=0)
    size: int = field(default=0)

    is_parameter: bool = field(default=False)

    def __post_init__(self):
        """ConstantV0Attribute's post-init function."""
        super().__post_init__()
        # fmt: off
        valid_element_type = [
            "u1", "u4", "u8", "u16", "u32", "u64",
            "i4", "i8", "i16", "i32", "i64", "f16", "f32", "boolean", "bf16"
        ]
        # fmt: on
        if self.element_type not in valid_element_type:
            raise ValueError(f"Invalid element_type {self.element_type}. " f"It must be one of {valid_element_type}.")


@OPS.register()
class ConstantV0(Operation[ConstantV0Attribute]):
    """ConstantV0 class."""

    TYPE = "Constant"
    VERSION = "opset1"
    ATTRIBUTE_FACTORY = ConstantV0Attribute

    def __init__(self, *args, **kwargs):
        """ConstantV0's __init__ function."""
        data = kwargs.pop("data", None)
        if data is None:
            raise KeyError("data is not provided")
        assert isinstance(data, torch.Tensor)
        kwargs["element_type"] = ConvertV0.convert_torch_type(data.dtype)
        super().__init__(*args, **kwargs)
        if self.attrs.is_parameter:
            self.data = torch.nn.Parameter(data)
        else:
            self.register_buffer("data", data)

    def forward(self):
        """ConstantV0's forward function."""
        return self.data

    @classmethod
    def from_ov(cls, ov_op):
        """ConstantV0's from_ov function."""
        op_type = ov_op.get_type_name()
        op_name = get_op_name(ov_op)
        op_version = ov_op.get_type_info().version_id
        assert cls.TYPE and cls.VERSION
        assert op_type == cls.TYPE
        assert op_version == cls.VERSION

        attrs = ov_op.get_attributes()
        attrs["shape"] = tuple(attrs["shape"])

        data = ov_op.get_data()
        if data.dtype == np.uint64:
            data_ = data.astype(np.int64)
            if not np.array_equal(data, data_):
                logger.warning(f"Overflow detected in {op_name}")
            data = torch.from_numpy(data_)
        elif data.dtype == np.uint16:
            data = torch.from_numpy(data.astype(np.int32))
        else:
            data = torch.from_numpy(data)

        in_port_indices = []
        op_node_types = []
        for out_port in ov_op.outputs():
            for in_port in list(out_port.get_target_inputs()):
                in_port_index = in_port.get_index()
                in_port_indices.append(in_port_index)
                node = in_port.get_node()
                op_node_types.append(node.get_type_name())

        # FIXME: need a better way to distinguish if it is parameter or not
        is_parameter = False
        # pylint: disable=too-many-boolean-expressions
        if (
            set(op_node_types).intersection(NODE_TYPES_WITH_WEIGHT)
            and len(in_port_indices) == 1
            and in_port_indices[0] != 0
            and data.numel() > 1
            and (data.is_floating_point() or data.is_complex())
        ):
            is_parameter = True
        attrs["is_parameter"] = is_parameter

        return cls(name=op_name, data=data, **attrs)
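

# A minimal sketch, not part of the original module: constructing ConstantV0
# directly, mirroring the `cls(name=op_name, data=data, **attrs)` call in
# from_ov above. The names and attribute values are illustrative assumptions;
# element_type is derived from the data dtype by __init__ and need not be passed.
def _example_constant_usage():
    weight = ConstantV0(
        name="fc_weight",
        data=torch.randn(16, 8),
        shape=((16, 8),),
        is_parameter=True,  # stored as torch.nn.Parameter -> trainable
    )
    bias = ConstantV0(
        name="fc_bias",
        data=torch.zeros(16),
        shape=((16,),),
        is_parameter=False,  # stored as a buffer -> frozen
    )
    return weight(), bias()  # forward() simply returns the stored tensor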