Source code for otx.core.ov.ops.reductions

"""Redunction-related modules for otx.core.ov.ops."""
# Copyright (C) 2023 Intel Corporation
#
# SPDX-License-Identifier: Apache-2.0

from dataclasses import dataclass, field

import torch

from otx.core.ov.ops.builder import OPS
from otx.core.ov.ops.op import Attribute, Operation


@dataclass
class ReduceMeanV1Attribute(Attribute):
    """Attribute container for the ReduceMean (opset1) operation.

    Attributes:
        keep_dims: Forwarded as ``keepdim`` to ``torch.mean`` by
            ``ReduceMeanV1.forward``. Defaults to ``False``.
    """

    keep_dims: bool = False


@OPS.register()
class ReduceMeanV1(Operation[ReduceMeanV1Attribute]):
    """ReduceMeanV1 class.

    Registered in ``OPS`` as type ``"ReduceMean"`` for ``"opset1"``; reduces
    the input with ``torch.mean`` over the requested axes.
    """

    TYPE = "ReduceMean"
    VERSION = "opset1"
    ATTRIBUTE_FACTORY = ReduceMeanV1Attribute

    def forward(self, inputs, axes):
        """ReduceMeanV1's forward function.

        Args:
            inputs: Tensor to reduce.
            axes: Axis or axes to reduce over. May be a single int, a
                list/tuple of ints, or a ``torch.Tensor`` (converted with
                ``tolist()``). ``None`` or an empty sequence means "no
                reduction".

        Returns:
            ``inputs`` unchanged when there is nothing to reduce, otherwise
            the mean over ``axes`` with ``keepdim=self.attrs.keep_dims``.
        """
        if isinstance(axes, torch.Tensor):
            axes = axes.tolist()
        if axes is None:
            return inputs

        # Wrap scalars BEFORE the emptiness check: a scalar axis of 0 is
        # falsy and would otherwise be mistaken for "no axes given".
        if not isinstance(axes, (list, tuple)):
            axes = [axes]
        if not axes:
            return inputs

        return torch.mean(input=inputs, dim=axes, keepdim=self.attrs.keep_dims)
@dataclass
class ReduceProdV1Attribute(Attribute):
    """Attribute container for the ReduceProd (opset1) operation.

    Attributes:
        keep_dims: Read by ``ReduceProdV1.forward`` to decide whether the
            reduced result is squeezed. Defaults to ``False``.
    """

    keep_dims: bool = False
@OPS.register()
class ReduceProdV1(Operation[ReduceProdV1Attribute]):
    """ReduceProdV1 class.

    Registered in ``OPS`` as type ``"ReduceProd"`` for ``"opset1"``; reduces
    the input with ``torch.prod``, one axis at a time.
    """

    TYPE = "ReduceProd"
    VERSION = "opset1"
    ATTRIBUTE_FACTORY = ReduceProdV1Attribute

    def forward(self, inputs, axes):
        """ReduceProdV1's forward function.

        Args:
            inputs: Tensor to reduce.
            axes: Axis or axes to reduce over. May be a single int, a
                list/tuple of ints, or a ``torch.Tensor`` (converted with
                ``tolist()``). ``None`` or an empty sequence means "no
                reduction".

        Returns:
            ``inputs`` unchanged when there is nothing to reduce, otherwise
            the product over ``axes``. The reduced axes are kept with size 1
            when ``self.attrs.keep_dims`` is true and removed otherwise.
        """
        if isinstance(axes, torch.Tensor):
            axes = axes.tolist()
        if axes is None:
            return inputs

        # Wrap scalars BEFORE the emptiness check: a scalar axis of 0 is
        # falsy and would otherwise be mistaken for "no axes given".
        if not isinstance(axes, (list, tuple)):
            axes = [axes]
        if not axes:
            return inputs

        # Reduce with keepdim=True so axis indices stay valid across iterations.
        output = inputs
        for axis in axes:
            output = torch.prod(input=output, dim=axis, keepdim=True)

        if not self.attrs.keep_dims:
            # Squeeze only the reduced axes. A bare torch.squeeze(output)
            # would also drop unrelated size-1 dimensions (e.g. batch == 1).
            ndim = output.dim()
            reduced = sorted({axis % ndim for axis in axes}, reverse=True) if ndim else []
            # Highest index first so earlier removals do not shift later ones.
            for axis in reduced:
                output = torch.squeeze(output, dim=axis)

        return output
@dataclass
class ReduceMinV1Attribute(Attribute):
    """Attribute container for the ReduceMin (opset1) operation.

    Attributes:
        keep_dims: Read by ``ReduceMinV1.forward`` to decide whether the
            reduced result is squeezed. Defaults to ``False``.
    """

    keep_dims: bool = False
@OPS.register()
class ReduceMinV1(Operation[ReduceMinV1Attribute]):
    """ReduceMinV1 class.

    Registered in ``OPS`` as type ``"ReduceMin"`` for ``"opset1"``; reduces
    the input with ``torch.min``, one axis at a time.
    """

    TYPE = "ReduceMin"
    VERSION = "opset1"
    ATTRIBUTE_FACTORY = ReduceMinV1Attribute

    def forward(self, inputs, axes):
        """ReduceMinV1's forward function.

        Args:
            inputs: Tensor to reduce.
            axes: Axis or axes to reduce over. May be a single int, a
                list/tuple of ints, or a ``torch.Tensor`` (converted with
                ``tolist()``). ``None`` or an empty sequence means "no
                reduction".

        Returns:
            ``inputs`` unchanged when there is nothing to reduce, otherwise
            the minimum over ``axes``. The reduced axes are kept with size 1
            when ``self.attrs.keep_dims`` is true and removed otherwise.
        """
        if isinstance(axes, torch.Tensor):
            axes = axes.tolist()
        if axes is None:
            return inputs

        # Wrap scalars BEFORE the emptiness check: a scalar axis of 0 is
        # falsy and would otherwise be mistaken for "no axes given".
        if not isinstance(axes, (list, tuple)):
            axes = [axes]
        if not axes:
            return inputs

        # Reduce with keepdim=True so axis indices stay valid across iterations.
        output = inputs
        for axis in axes:
            # torch.min with ``dim`` returns (values, indices); keep the values.
            output = torch.min(input=output, dim=axis, keepdim=True)[0]

        if not self.attrs.keep_dims:
            # Squeeze only the reduced axes. A bare torch.squeeze(output)
            # would also drop unrelated size-1 dimensions (e.g. batch == 1).
            ndim = output.dim()
            reduced = sorted({axis % ndim for axis in axes}, reverse=True) if ndim else []
            # Highest index first so earlier removals do not shift later ones.
            for axis in reduced:
                output = torch.squeeze(output, dim=axis)

        return output
@dataclass
class ReduceSumV1Attribute(Attribute):
    """Attribute container for the ReduceSum (opset1) operation.

    Attributes:
        keep_dims: Forwarded as ``keepdim`` to ``torch.sum`` by
            ``ReduceSumV1.forward``. Defaults to ``False``.
    """

    keep_dims: bool = False
@OPS.register()
class ReduceSumV1(Operation[ReduceSumV1Attribute]):
    """ReduceSumV1 class.

    Registered in ``OPS`` as type ``"ReduceSum"`` for ``"opset1"``; reduces
    the input with ``torch.sum`` over the requested axes.
    """

    TYPE = "ReduceSum"
    VERSION = "opset1"
    ATTRIBUTE_FACTORY = ReduceSumV1Attribute

    def forward(self, inputs, axes):
        """ReduceSumV1's forward function.

        Args:
            inputs: Tensor to reduce.
            axes: Axis or axes to reduce over. May be a single int, a
                list/tuple of ints, or a ``torch.Tensor`` (converted with
                ``tolist()``). ``None`` or an empty sequence means "no
                reduction".

        Returns:
            ``inputs`` unchanged when there is nothing to reduce, otherwise
            the sum over ``axes`` with ``keepdim=self.attrs.keep_dims``.
        """
        if isinstance(axes, torch.Tensor):
            axes = axes.tolist()
        if axes is None:
            return inputs

        # Wrap scalars BEFORE the emptiness check: a scalar axis of 0 is
        # falsy and would otherwise be mistaken for "no axes given".
        if not isinstance(axes, (list, tuple)):
            axes = [axes]
        if not axes:
            return inputs

        return torch.sum(input=inputs, dim=axes, keepdim=self.attrs.keep_dims)