Source code for otx.core.ov.ops.matmuls

"""MatMul-related modules for otx.core.ov.ops."""
# Copyright (C) 2023 Intel Corporation
#
# SPDX-License-Identifier: Apache-2.0

from dataclasses import dataclass, field

import torch

from otx.core.ov.ops.builder import OPS
from otx.core.ov.ops.op import Attribute, Operation


@dataclass
class MatMulV0Attribute(Attribute):
    """MatMulV0Attribute class."""

    transpose_a: bool = field(default=False)
    transpose_b: bool = field(default=False)


[docs] @OPS.register() class MatMulV0(Operation[MatMulV0Attribute]): """MatMulV0 class.""" TYPE = "MatMul" VERSION = "opset1" ATTRIBUTE_FACTORY = MatMulV0Attribute
[docs] def forward(self, input_a, input_b): """MatMulV0's forward function.""" if self.attrs.transpose_a: input_a = torch.transpose(input_a, -1, -2) if self.attrs.transpose_b: input_b = torch.transpose(input_b, -1, -2) return torch.matmul(input_a, input_b)
@dataclass class EinsumV7Attribute(Attribute): """EinsumV7Attribute class.""" equation: str
[docs] @OPS.register() class EinsumV7(Operation[EinsumV7Attribute]): """EinsumV7 class.""" TYPE = "Einsum" VERSION = "opset1" ATTRIBUTE_FACTORY = EinsumV7Attribute
[docs] def forward(self, *inputs): """EinsumV7's forward function.""" return torch.einsum(self.attrs.equation, *inputs)