Source code for otx.core.ov.ops.activations

"""Activation-related modules for otx.core.ov.ops.activations."""
# Copyright (C) 2023 Intel Corporation
#
# SPDX-License-Identifier: Apache-2.0

import math
from dataclasses import dataclass, field

import torch
from torch.nn import functional as F

from otx.core.ov.ops.builder import OPS
from otx.core.ov.ops.op import Attribute, Operation


@dataclass
class SoftMaxV0Attribute(Attribute):
    """Attribute container for the ``SoftMaxV0`` operation."""

    # Dimension along which softmax is computed; defaults to 1.
    axis: int = 1


@OPS.register()
class SoftMaxV0(Operation[SoftMaxV0Attribute]):
    """Softmax operation registered in ``OPS`` as TYPE "Softmax", VERSION "opset1"."""

    TYPE = "Softmax"
    VERSION = "opset1"
    ATTRIBUTE_FACTORY = SoftMaxV0Attribute

    def forward(self, inputs):
        """Apply softmax to ``inputs`` along ``self.attrs.axis``.

        Args:
            inputs: Tensor handed to ``torch.nn.functional.softmax``.

        Returns:
            The result of ``F.softmax`` computed over dimension ``self.attrs.axis``.
        """
        softmax_dim = self.attrs.axis
        return F.softmax(inputs, dim=softmax_dim)
@dataclass
class SoftMaxV1Attribute(Attribute):
    """Attribute container for the ``SoftMaxV1`` operation."""

    # Dimension along which softmax is computed; defaults to 1.
    axis: int = 1
@OPS.register()
class SoftMaxV1(Operation[SoftMaxV1Attribute]):
    """Softmax operation registered in ``OPS`` as TYPE "Softmax", VERSION "opset8"."""

    TYPE = "Softmax"
    VERSION = "opset8"
    ATTRIBUTE_FACTORY = SoftMaxV1Attribute

    def forward(self, inputs):
        """Apply softmax to ``inputs`` along ``self.attrs.axis``.

        Args:
            inputs: Tensor handed to ``torch.nn.functional.softmax``.

        Returns:
            The result of ``F.softmax`` computed over dimension ``self.attrs.axis``.
        """
        softmax_dim = self.attrs.axis
        return F.softmax(inputs, dim=softmax_dim)
@dataclass
class ReluV0Attribute(Attribute):
    """Attribute container for ``ReluV0``; declares no fields of its own."""
@OPS.register()
class ReluV0(Operation[ReluV0Attribute]):
    """ReLU operation registered in ``OPS`` as TYPE "Relu", VERSION "opset1"."""

    TYPE = "Relu"
    VERSION = "opset1"
    ATTRIBUTE_FACTORY = ReluV0Attribute

    def forward(self, inputs):
        """Return ``torch.nn.functional.relu`` applied to ``inputs``.

        Args:
            inputs: Tensor to rectify.

        Returns:
            The output of ``F.relu`` for ``inputs``.
        """
        rectified = F.relu(input=inputs)
        return rectified
@dataclass
class SwishV4Attribute(Attribute):
    """Attribute container for ``SwishV4``; declares no fields of its own."""
@OPS.register()
class SwishV4(Operation[SwishV4Attribute]):
    """Swish operation registered in ``OPS`` as TYPE "Swish", VERSION "opset1"."""

    TYPE = "Swish"
    VERSION = "opset1"
    ATTRIBUTE_FACTORY = SwishV4Attribute

    def forward(self, inputs, beta=1.0):
        """Compute ``inputs * sigmoid(inputs * beta)``.

        Args:
            inputs: Tensor to activate.
            beta: Multiplier applied to ``inputs`` inside the sigmoid; defaults to 1.0.

        Returns:
            ``inputs`` multiplied element-wise by ``torch.sigmoid(inputs * beta)``.
        """
        gate = torch.sigmoid(inputs * beta)
        return inputs * gate
@dataclass
class SigmoidV0Attribute(Attribute):
    """Attribute container for ``SigmoidV0``; declares no fields of its own."""
@OPS.register()
class SigmoidV0(Operation[SigmoidV0Attribute]):
    """Sigmoid operation registered in ``OPS`` as TYPE "Sigmoid", VERSION "opset1"."""

    TYPE = "Sigmoid"
    VERSION = "opset1"
    ATTRIBUTE_FACTORY = SigmoidV0Attribute

    def forward(self, inputs):
        """Return ``torch.sigmoid`` applied to ``inputs``.

        Args:
            inputs: Tensor to activate.

        Returns:
            The output of ``torch.sigmoid`` for ``inputs``.
        """
        activated = torch.sigmoid(inputs)
        return activated
@dataclass
class ClampV0Attribute(Attribute):
    """Attribute container for ``ClampV0``.

    Both fields are declared without defaults, so a value must be supplied
    for each of them when the dataclass is constructed.
    """

    # Lower bound of the clamp range.
    min: float
    # Upper bound of the clamp range.
    max: float
@OPS.register()
class ClampV0(Operation[ClampV0Attribute]):
    """Clamp operation registered in ``OPS`` as TYPE "Clamp", VERSION "opset1"."""

    TYPE = "Clamp"
    VERSION = "opset1"
    ATTRIBUTE_FACTORY = ClampV0Attribute

    def forward(self, inputs):
        """Clamp ``inputs`` to the range ``[self.attrs.min, self.attrs.max]``.

        Args:
            inputs: Object exposing ``clamp(min=..., max=...)``
                (presumably a ``torch.Tensor`` -- confirm against callers).

        Returns:
            The result of ``inputs.clamp`` with the configured bounds.
        """
        lower = self.attrs.min
        upper = self.attrs.max
        return inputs.clamp(min=lower, max=upper)
@dataclass
class PReluV0Attribute(Attribute):
    """Attribute container for ``PReluV0``; declares no fields of its own."""
@OPS.register()
class PReluV0(Operation[PReluV0Attribute]):
    """PReLU operation registered in ``OPS`` as TYPE "PRelu", VERSION "opset1"."""

    TYPE = "PRelu"
    VERSION = "opset1"
    ATTRIBUTE_FACTORY = PReluV0Attribute

    def forward(self, inputs, slope):
        """Return ``torch.nn.functional.prelu`` of ``inputs`` using ``slope`` as weight.

        Args:
            inputs: Tensor to activate.
            slope: Value forwarded as the ``weight`` argument of ``F.prelu``.

        Returns:
            The output of ``F.prelu``.
        """
        return F.prelu(inputs, weight=slope)
@dataclass
class TanhV0Attribute(Attribute):
    """Attribute container for ``TanhV0``; declares no fields of its own."""
@OPS.register()
class TanhV0(Operation[TanhV0Attribute]):
    """Tanh operation registered in ``OPS`` as TYPE "Tanh", VERSION "opset1"."""

    TYPE = "Tanh"
    VERSION = "opset1"
    ATTRIBUTE_FACTORY = TanhV0Attribute

    def forward(self, inputs):
        """Return the element-wise hyperbolic tangent of ``inputs``.

        Args:
            inputs: Tensor to activate.

        Returns:
            The output of ``torch.tanh`` for ``inputs``.
        """
        # torch.tanh replaces the deprecated torch.nn.functional.tanh alias,
        # which emits a warning on older PyTorch releases; the numerical
        # result is the same.
        return torch.tanh(inputs)
@dataclass
class EluV0Attribute(Attribute):
    """Attribute container for ``EluV0``.

    ``alpha`` is declared without a default, so a value must be supplied
    when the dataclass is constructed.
    """

    # Value forwarded as the ``alpha`` argument of ELU.
    alpha: float
@OPS.register()
class EluV0(Operation[EluV0Attribute]):
    """ELU operation registered in ``OPS`` as TYPE "Elu", VERSION "opset1"."""

    TYPE = "Elu"
    VERSION = "opset1"
    ATTRIBUTE_FACTORY = EluV0Attribute

    def forward(self, inputs):
        """Return ``torch.nn.functional.elu`` of ``inputs`` with ``self.attrs.alpha``.

        Args:
            inputs: Tensor to activate.

        Returns:
            The output of ``F.elu`` using the configured ``alpha``.
        """
        elu_alpha = self.attrs.alpha
        return F.elu(inputs, alpha=elu_alpha)
@dataclass
class SeluV0Attribute(Attribute):
    """Attribute container for ``SeluV0``; declares no fields of its own."""
@OPS.register()
class SeluV0(Operation[SeluV0Attribute]):
    """SELU operation registered in ``OPS`` as TYPE "Selu", VERSION "opset1"."""

    TYPE = "Selu"
    VERSION = "opset1"
    ATTRIBUTE_FACTORY = SeluV0Attribute

    def forward(self, inputs, alpha, lambda_):
        """Compute ``lambda_ * elu(inputs, alpha)``.

        Args:
            inputs: Tensor to activate.
            alpha: Value forwarded as the ``alpha`` argument of ``F.elu``.
            lambda_: Factor multiplied with the ELU output.

        Returns:
            The ELU of ``inputs`` scaled by ``lambda_``.
        """
        elu_out = F.elu(inputs, alpha=alpha)
        return lambda_ * elu_out
@dataclass
class MishV4Attribute(Attribute):
    """Attribute container for ``MishV4``; declares no fields of its own."""
@OPS.register()
class MishV4(Operation[MishV4Attribute]):
    """Mish operation registered in ``OPS`` as TYPE "Mish", VERSION "opset1"."""

    TYPE = "Mish"
    VERSION = "opset1"
    ATTRIBUTE_FACTORY = MishV4Attribute

    def forward(self, inputs):
        """Compute ``inputs * tanh(softplus(inputs))``.

        Args:
            inputs: Tensor to activate.

        Returns:
            ``inputs`` multiplied element-wise by ``tanh(F.softplus(inputs))``.
        """
        # NOTE: pytorch 1.8.2 does not have a mish function, so the formula is
        # spelled out by hand. torch.tanh is used instead of the deprecated
        # torch.nn.functional.tanh alias; the result is the same.
        return inputs * torch.tanh(F.softplus(inputs))
@dataclass
class HSwishV4Attribute(Attribute):
    """Attribute container for ``HSwishV4``; declares no fields of its own."""
@OPS.register()
class HSwishV4(Operation[HSwishV4Attribute]):
    """Hard-swish operation registered in ``OPS`` as TYPE "HSwish", VERSION "opset1"."""

    TYPE = "HSwish"
    VERSION = "opset1"
    ATTRIBUTE_FACTORY = HSwishV4Attribute

    def forward(self, inputs):
        """Return ``torch.nn.functional.hardswish`` applied to ``inputs``.

        Args:
            inputs: Tensor to activate.

        Returns:
            The output of ``F.hardswish`` for ``inputs``.
        """
        return F.hardswish(inputs)
@dataclass
class HSigmoidV5Attribute(Attribute):
    """Attribute container for ``HSigmoidV5``; declares no fields of its own."""
@OPS.register()
class HSigmoidV5(Operation[HSigmoidV5Attribute]):
    """Hard-sigmoid operation registered in ``OPS`` as TYPE "HSigmoid", VERSION "opset1"."""

    TYPE = "HSigmoid"
    VERSION = "opset1"
    ATTRIBUTE_FACTORY = HSigmoidV5Attribute

    def forward(self, inputs):
        """Return ``torch.nn.functional.hardsigmoid`` applied to ``inputs``.

        Args:
            inputs: Tensor to activate.

        Returns:
            The output of ``F.hardsigmoid`` for ``inputs``.
        """
        return F.hardsigmoid(inputs)
@dataclass
class ExpV0Attribute(Attribute):
    """Attribute container for ``ExpV0``; declares no fields of its own."""
@OPS.register()
class ExpV0(Operation[ExpV0Attribute]):
    """Exponential operation registered in ``OPS`` as TYPE "Exp", VERSION "opset1"."""

    TYPE = "Exp"
    VERSION = "opset1"
    ATTRIBUTE_FACTORY = ExpV0Attribute

    def forward(self, inputs):
        """Return ``torch.exp`` applied to ``inputs``.

        Args:
            inputs: Tensor to exponentiate.

        Returns:
            The output of ``torch.exp`` for ``inputs``.
        """
        exponentiated = torch.exp(inputs)
        return exponentiated
@dataclass
class HardSigmoidV0Attribute(Attribute):
    """Attribute container for ``HardSigmoidV0``; declares no fields of its own."""
@OPS.register()
class HardSigmoidV0(Operation[HardSigmoidV0Attribute]):
    """Hard-sigmoid operation registered in ``OPS`` as TYPE "HardSigmoid", VERSION "opset1"."""

    TYPE = "HardSigmoid"
    VERSION = "opset1"
    ATTRIBUTE_FACTORY = HardSigmoidV0Attribute

    def forward(self, inputs, alpha, beta):
        """Compute ``max(0, min(1, inputs * alpha + beta))`` element-wise.

        Args:
            inputs: Tensor to activate.
            alpha: Multiplier applied to ``inputs``.
            beta: Offset added after the multiplication.

        Returns:
            The affine transform of ``inputs`` limited to the range [0, 1].
        """
        floor = torch.zeros_like(inputs)
        ceiling = torch.ones_like(inputs)
        # Cap at one first, then lift anything below zero up to zero.
        capped = torch.minimum(ceiling, inputs * alpha + beta)
        return torch.maximum(floor, capped)
@dataclass
class GeluV7Attribute(Attribute):
    """Attribute container for ``GeluV7``.

    Raises:
        ValueError: From ``__post_init__`` when ``approximation_mode`` is
            neither "ERF" nor "tanh".
    """

    # Selects the GELU formula; accepted values are "ERF" and "tanh".
    approximation_mode: str = "ERF"

    def __post_init__(self):
        """Run the parent's post-init hook, then validate ``approximation_mode``."""
        super().__post_init__()
        valid_approximation_mode = ["ERF", "tanh"]
        if self.approximation_mode in valid_approximation_mode:
            return
        raise ValueError(
            f"Invalid approximation_mode {self.approximation_mode}. "
            f"It must be one of {valid_approximation_mode}."
        )
@OPS.register()
class GeluV7(Operation[GeluV7Attribute]):
    """GELU operation registered in ``OPS`` as TYPE "Gelu", VERSION "opset1".

    The formula is chosen by ``self.attrs.approximation_mode``: "ERF" uses
    ``F.gelu`` with its default arguments, "tanh" uses the tanh approximation.
    """

    TYPE = "Gelu"
    VERSION = "opset1"
    ATTRIBUTE_FACTORY = GeluV7Attribute

    def forward(self, inputs):
        """Apply GELU to ``inputs`` using the configured approximation mode.

        Args:
            inputs: Tensor to activate.

        Returns:
            The activated tensor for mode "ERF" or "tanh"; ``None`` for any
            other mode value.
        """
        mode = self.attrs.approximation_mode
        if mode == "ERF":
            return F.gelu(input=inputs)
        if mode == "tanh":
            # sqrt(2 / pi) as a Python float: avoids allocating a float32 CPU
            # tensor on every call and keeps full precision for float64 inputs.
            coeff = math.sqrt(2.0 / math.pi)
            # torch.tanh replaces the deprecated torch.nn.functional.tanh alias.
            return inputs * 0.5 * (1 + torch.tanh(coeff * (inputs + 0.044715 * inputs**3)))
        # Fallthrough kept from the original implementation for any mode other
        # than "ERF" or "tanh".
        return None