Source code for otx.core.ov.ops.arithmetics
"""Arithmetics-related codes for otx.core.ov.ops.arithmetics."""
# Copyright (C) 2023 Intel Corporation
#
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass, field
import torch
from otx.core.ov.ops.builder import OPS
from otx.core.ov.ops.op import Attribute, Operation
@dataclass
class MultiplyV1Attribute(Attribute):
    """Attribute container for the ``MultiplyV1`` operation.

    Attributes:
        auto_broadcast: Broadcasting mode name; defaults to ``"numpy"``.
    """

    auto_broadcast: str = "numpy"
[docs]
@OPS.register()
class MultiplyV1(Operation[MultiplyV1Attribute]):
    """Element-wise multiplication operation registered as ``Multiply`` / ``opset1``."""

    TYPE = "Multiply"
    VERSION = "opset1"
    ATTRIBUTE_FACTORY = MultiplyV1Attribute

    def forward(self, input_0, input_1):
        """Return ``input_0 * input_1``.

        Args:
            input_0: Left operand (presumably a ``torch.Tensor`` -- confirm against callers).
            input_1: Right operand.

        Returns:
            The product computed with the ``*`` operator.

        Raises:
            NotImplementedError: If ``auto_broadcast`` is neither ``"none"`` nor ``"numpy"``.
            AssertionError: If ``auto_broadcast`` is ``"none"`` and the operand shapes differ.
        """
        mode = self.attrs.auto_broadcast
        if mode not in ("none", "numpy"):
            raise NotImplementedError
        if mode == "none":
            # Without broadcasting both operands must already have the same shape.
            assert input_0.shape == input_1.shape
        # In "numpy" mode any broadcasting is left to the ``*`` operator itself.
        return input_0 * input_1
@dataclass
class DivideV1Attribute(Attribute):
    """Attribute container for the ``DivideV1`` operation.

    Attributes:
        m_pythondiv: When true and both operands have integer dtypes, ``DivideV1.forward``
            converts the quotient back to the dtype of the first operand. Defaults to ``True``.
        auto_broadcast: Broadcasting mode name; defaults to ``"numpy"``.
    """

    m_pythondiv: bool = True
    auto_broadcast: str = "numpy"
[docs]
@OPS.register()
class DivideV1(Operation[DivideV1Attribute]):
    """Element-wise division operation registered as ``Divide`` / ``opset1``."""

    TYPE = "Divide"
    VERSION = "opset1"
    ATTRIBUTE_FACTORY = DivideV1Attribute

    @staticmethod
    def _is_integer_dtype(dtype):
        """Return True when ``dtype`` is neither floating point, complex nor bool."""
        return not (dtype.is_floating_point or dtype.is_complex or dtype == torch.bool)

    def forward(self, input_0, input_1):
        """Return ``input_0 / input_1``, applying Python-style floor division to integers.

        Args:
            input_0: Dividend tensor.
            input_1: Divisor tensor.

        Returns:
            The quotient. When ``m_pythondiv`` is set and both operands have integer
            dtypes, the quotient is rounded toward negative infinity (as Python's ``//``
            does) and cast back to ``input_0.dtype``. Otherwise the result of the ``/``
            operator is returned unchanged.

        Raises:
            NotImplementedError: If ``auto_broadcast`` is neither ``"none"`` nor ``"numpy"``.
            AssertionError: If ``auto_broadcast`` is ``"none"`` and the operand shapes differ.
        """
        broadcast = self.attrs.auto_broadcast
        if broadcast == "none":
            # Without broadcasting both operands must already have the same shape.
            assert input_0.shape == input_1.shape
        elif broadcast != "numpy":
            raise NotImplementedError

        output = input_0 / input_1

        # NOTE(review): with m_pythondiv disabled, integer operands keep the floating-point
        # result of ``/`` -- confirm whether truncating integer division is expected there.
        integer_operands = self._is_integer_dtype(input_0.dtype) and self._is_integer_dtype(input_1.dtype)
        if self.attrs.m_pythondiv and integer_operands:
            # A bare cast to an integer dtype truncates toward zero (-7 / 2 -> -3), whereas
            # Python-style division floors (-7 // 2 -> -4), so floor before casting.
            # The float quotient is kept (instead of an integer floor-divide) so that
            # division by zero still does not raise.
            output = torch.floor(output).type(input_0.dtype)

        return output
@dataclass
class AddV1Attribute(Attribute):
    """Attribute container for the ``AddV1`` operation.

    Attributes:
        auto_broadcast: Broadcasting mode name; defaults to ``"numpy"``.
    """

    auto_broadcast: str = "numpy"
[docs]
@OPS.register()
class AddV1(Operation[AddV1Attribute]):
    """Element-wise addition operation registered as ``Add`` / ``opset1``."""

    TYPE = "Add"
    VERSION = "opset1"
    ATTRIBUTE_FACTORY = AddV1Attribute

    def forward(self, input_0, input_1):
        """Return ``input_0 + input_1``.

        Args:
            input_0: Left operand (presumably a ``torch.Tensor`` -- confirm against callers).
            input_1: Right operand.

        Returns:
            The sum computed with the ``+`` operator.

        Raises:
            NotImplementedError: If ``auto_broadcast`` is neither ``"none"`` nor ``"numpy"``.
            AssertionError: If ``auto_broadcast`` is ``"none"`` and the operand shapes differ.
        """
        mode = self.attrs.auto_broadcast
        if mode not in ("none", "numpy"):
            raise NotImplementedError
        if mode == "none":
            # Without broadcasting both operands must already have the same shape.
            assert input_0.shape == input_1.shape
        # In "numpy" mode any broadcasting is left to the ``+`` operator itself.
        return input_0 + input_1
@dataclass
class SubtractV1Attribute(Attribute):
    """Attribute container for the ``SubtractV1`` operation.

    Attributes:
        auto_broadcast: Broadcasting mode name; defaults to ``"numpy"``.
    """

    auto_broadcast: str = "numpy"
[docs]
@OPS.register()
class SubtractV1(Operation[SubtractV1Attribute]):
    """Element-wise subtraction operation registered as ``Subtract`` / ``opset1``."""

    TYPE = "Subtract"
    VERSION = "opset1"
    ATTRIBUTE_FACTORY = SubtractV1Attribute

    def forward(self, input_0, input_1):
        """Return ``input_0 - input_1``.

        Args:
            input_0: Minuend (presumably a ``torch.Tensor`` -- confirm against callers).
            input_1: Subtrahend.

        Returns:
            The difference computed with the ``-`` operator.

        Raises:
            NotImplementedError: If ``auto_broadcast`` is neither ``"none"`` nor ``"numpy"``.
            AssertionError: If ``auto_broadcast`` is ``"none"`` and the operand shapes differ.
        """
        mode = self.attrs.auto_broadcast
        if mode not in ("none", "numpy"):
            raise NotImplementedError
        if mode == "none":
            # Without broadcasting both operands must already have the same shape.
            assert input_0.shape == input_1.shape
        # In "numpy" mode any broadcasting is left to the ``-`` operator itself.
        return input_0 - input_1
@dataclass
class TanV0Attribute(Attribute):
    """Attribute container for the ``TanV0`` operation; it declares no fields of its own."""
[docs]
@OPS.register()
class TanV0(Operation[TanV0Attribute]):
    """Element-wise tangent operation registered as ``Tan`` / ``opset1``."""

    TYPE = "Tan"
    VERSION = "opset1"
    ATTRIBUTE_FACTORY = TanV0Attribute

    def forward(self, inputs):
        """Return the tangent of ``inputs``.

        Args:
            inputs: Value passed to ``torch.tan``.

        Returns:
            The result of ``torch.tan(inputs)``.
        """
        tangent = torch.tan(inputs)
        return tangent