# Source code for otx.core.ov.ops.shape_manipulations

"""Shape-manipulation-related modules for otx.core.ov.ops."""
# Copyright (C) 2023 Intel Corporation
#
# SPDX-License-Identifier: Apache-2.0

from dataclasses import dataclass, field

import torch

from otx.core.ov.ops.builder import OPS
from otx.core.ov.ops.op import Attribute, Operation
from otx.core.ov.ops.type_conversions import ConvertV0


@dataclass
class SqueezeV0Attribute(Attribute):
    """Attributes of the opset1 ``Squeeze`` operation; adds no fields to ``Attribute``."""


@OPS.register()
class SqueezeV0(Operation[SqueezeV0Attribute]):
    """opset1 ``Squeeze``: remove size-1 dimensions from the input tensor."""

    TYPE = "Squeeze"
    VERSION = "opset1"
    ATTRIBUTE_FACTORY = SqueezeV0Attribute

    def forward(self, inputs, dims=None):
        """SqueezeV0's forward function.

        Args:
            inputs: tensor to squeeze.
            dims: optional integer tensor (0-d or 1-d) listing the axes to
                squeeze. Negative values count back from ``inputs.dim()``.
                When ``None``, every size-1 dimension is removed.

        Returns:
            The squeezed tensor. Each listed axis is squeezed at most once.
        """
        if dims is None:
            return torch.squeeze(inputs)
        if dims.dim() == 0:
            dims = torch.unsqueeze(dims, 0)
        rank = inputs.dim()
        # Normalize negative axes and collapse duplicates. An axis given twice
        # (e.g. 1 and -1 on a rank-2 input) must be squeezed only once;
        # otherwise the second squeeze would remove a different dimension.
        axes = {dim + rank if dim < 0 else dim for dim in dims.detach().cpu().tolist()}
        output = inputs
        # Go from the highest axis down so each removal does not shift the
        # indices of the axes still to be processed.
        for axis in sorted(axes, reverse=True):
            output = torch.squeeze(output, axis)
        return output
@dataclass
class UnsqueezeV0Attribute(Attribute):
    """Attributes of the opset1 ``Unsqueeze`` operation; adds no fields to ``Attribute``."""
@OPS.register()
class UnsqueezeV0(Operation[UnsqueezeV0Attribute]):
    """opset1 ``Unsqueeze``: insert size-1 dimensions into the input tensor."""

    TYPE = "Unsqueeze"
    VERSION = "opset1"
    ATTRIBUTE_FACTORY = UnsqueezeV0Attribute

    def forward(self, inputs, dims):
        """UnsqueezeV0's forward function.

        Args:
            inputs: tensor to unsqueeze.
            dims: integer tensor (0-d or 1-d) of axes at which size-1 dimensions
                appear in the OUTPUT. Negative values count back from the
                output rank, which is ``inputs.dim() + len(dims)``.

        Returns:
            A tensor whose dimensions at the requested output positions have
            size 1. The remaining dimensions are those of ``inputs``, in order.
            For example, input (2, 3) with axes [0, 3] gives (1, 2, 3, 1).
        """
        if dims.dim() == 0:
            dims = torch.unsqueeze(dims, 0)
        axes = dims.detach().cpu().tolist()
        # Axes index the output tensor, so negatives are resolved against the
        # output rank, not the input rank. For a single axis this matches
        # torch.unsqueeze's own handling of negative dims.
        out_rank = inputs.dim() + len(axes)
        axes = [axis + out_rank if axis < 0 else axis for axis in axes]
        output = inputs
        # Insert in ascending order: each inserted axis is final once placed,
        # and later (larger) axes are then valid positions in the growing tensor.
        for axis in sorted(axes):
            output = torch.unsqueeze(output, axis)
        return output
@dataclass
class ReshapeV1Attribute(Attribute):
    """Attributes of the opset1 ``Reshape`` operation."""

    # Read by ReshapeV1.forward: when true, a 0 in the requested shape is
    # replaced by the input's corresponding dimension.
    special_zero: bool
@OPS.register()
class ReshapeV1(Operation[ReshapeV1Attribute]):
    """opset1 ``Reshape``: reshape the input to a shape supplied as a second input."""

    TYPE = "Reshape"
    VERSION = "opset1"
    ATTRIBUTE_FACTORY = ReshapeV1Attribute

    def forward(self, inputs, shape):
        """ReshapeV1's forward function.

        Args:
            inputs: tensor to reshape.
            shape: integer tensor holding the requested shape. A ``-1`` entry
                is inferred by ``torch.reshape``.

        Returns:
            ``inputs`` reshaped to the resolved target shape.

        When ``self.attrs.special_zero`` is true, a ``0`` at position ``i`` of
        ``shape`` becomes ``inputs.shape[i]``, the input dimension at the same
        index. This applies wherever any ``-1`` appears. When false, ``0`` is
        passed through unchanged.
        """
        target_shape = shape.detach().cpu().tolist()
        if self.attrs.special_zero:
            origin_shape = list(inputs.shape)
            # Zeros are resolved by front-aligned index. Only positions that
            # exist in the input can be copied; anything beyond the input rank
            # is left for torch.reshape to accept or reject.
            for i, (origin_dim, target_dim) in enumerate(zip(origin_shape, target_shape)):
                if target_dim == 0:
                    target_shape[i] = origin_dim
        return torch.reshape(inputs, target_shape)
@dataclass
class ShapeOfV0Attribute(Attribute):
    """Attributes of the opset1 ``ShapeOf`` operation; adds no fields to ``Attribute``."""
@OPS.register()
class ShapeOfV0(Operation[ShapeOfV0Attribute]):
    """opset1 ``ShapeOf``: report the input's shape as a tensor."""

    TYPE = "ShapeOf"
    VERSION = "opset1"
    ATTRIBUTE_FACTORY = ShapeOfV0Attribute

    def forward(self, inputs):
        """Return ``inputs.shape`` as a tensor on the same device as ``inputs``."""
        target_device = inputs.device
        dims = inputs.shape
        return torch.tensor(dims, device=target_device)
@dataclass
class ShapeOfV3Attribute(Attribute):
    """Attributes of the opset3 ``ShapeOf`` operation.

    ``output_type`` selects the element type of the produced shape tensor. It
    must be ``"i64"`` (the default) or ``"i32"``.
    """

    output_type: str = field(default="i64")

    def __post_init__(self):
        """Run the base post-init, then reject unsupported ``output_type`` values."""
        super().__post_init__()
        valid_output_type = ["i64", "i32"]
        is_supported = self.output_type in valid_output_type
        if not is_supported:
            raise ValueError(f"Invalid output_type {self.output_type}. " f"It must be one of {valid_output_type}.")
@OPS.register()
class ShapeOfV3(Operation[ShapeOfV3Attribute]):
    """opset3 ``ShapeOf``: report the input's shape as a tensor of ``attrs.output_type``."""

    TYPE = "ShapeOf"
    VERSION = "opset3"
    ATTRIBUTE_FACTORY = ShapeOfV3Attribute

    def forward(self, inputs):
        """Build the shape tensor of ``inputs`` and convert it with a ``ConvertV0`` op.

        The conversion target is ``self.attrs.output_type``. The shape tensor is
        created on ``inputs.device``.
        """
        converter = ConvertV0("temp", shape=self.shape, destination_type=self.attrs.output_type)
        raw_shape = torch.tensor(inputs.shape, device=inputs.device)
        return converter(raw_shape)