"""Movement-related modules for otx.core.ov.ops."""
# Copyright (C) 2023 Intel Corporation
#
# SPDX-License-Identifier: Apache-2.0
import math
from dataclasses import dataclass, field
from functools import partial
from typing import List
import torch
from torch.nn import functional as F
from .builder import OPS
from .op import Attribute, Operation
# pylint: disable=too-many-branches
@dataclass
class PadV1Attribute(Attribute):
"""PadV1Attribute class."""
pad_mode: str
def __post_init__(self):
"""PadV1Attribute's post-init function."""
super().__post_init__()
valid_pad_mode = ["constant", "edge", "reflect", "symmetric"]
if self.pad_mode not in valid_pad_mode:
raise ValueError(f"Invalid pad_mode {self.pad_mode}. " f"It must be one of {valid_pad_mode}.")
@OPS.register()
class PadV1(Operation[PadV1Attribute]):
"""PadV1 class."""
TYPE = "Pad"
VERSION = "opset1"
ATTRIBUTE_FACTORY = PadV1Attribute
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._pad_mode = self.get_torch_pad_mode(self.attrs.pad_mode)
@staticmethod
def get_torch_pad_mode(pad_mode):
"""PadV1's get_torch_pad_mode function."""
if pad_mode == "constant":
return "constant"
if pad_mode == "edge":
return "replicate"
if pad_mode == "reflect":
return "reflect"
raise NotImplementedError
@staticmethod
def get_torch_pad_dim(pads_begin, pads_end):
"""PadV1's get_torch_pad_dim function."""
# reverse padding
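        # F.pad expects pads for the last dimension first, interleaved as
        # (begin, end) pairs. Illustrative example (not from the source):
        # pads_begin=[0, 0, 1, 2], pads_end=[0, 0, 3, 4] on an NCHW tensor
        # becomes [2, 4, 1, 3, 0, 0, 0, 0].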
return [val for tup in zip(pads_begin[::-1], pads_end[::-1]) for val in tup]
def forward(self, inputs, pads_begin, pads_end, pad_value=0):
"""PadV1's forward function."""
pads_begin = pads_begin if isinstance(pads_begin, list) else pads_begin.detach().cpu().tolist()
pads_end = pads_end if isinstance(pads_end, list) else pads_end.detach().cpu().tolist()
pad = self.get_torch_pad_dim(pads_begin, pads_end)
pad = list(map(math.ceil, pad))
return F.pad(input=inputs, pad=pad, mode=self._pad_mode, value=pad_value)
@dataclass
class ConcatV0Attribute(Attribute):
"""ConcatV0Attribute class."""
axis: int
@OPS.register()
class ConcatV0(Operation[ConcatV0Attribute]):
"""ConcatV0 class."""
TYPE = "Concat"
VERSION = "opset1"
ATTRIBUTE_FACTORY = ConcatV0Attribute
def forward(self, *inputs):
"""ConcatV0's forward function."""
return torch.cat(inputs, self.attrs.axis)
@dataclass
class TransposeV1Attribute(Attribute):
"""TransposeV1Attribute class."""
pass # pylint: disable=unnecessary-pass
@OPS.register()
class TransposeV1(Operation[TransposeV1Attribute]):
"""TransposeV1 class."""
TYPE = "Transpose"
VERSION = "opset1"
ATTRIBUTE_FACTORY = TransposeV1Attribute
def forward(self, inputs, order):
"""TransposeV1's forward function."""
if order.numel() == 0:
order = list(range(inputs.dim()))[::-1]
elif isinstance(order, torch.Tensor):
order = order.detach().cpu().tolist()
return inputs.permute(order)
@dataclass
class GatherV0Attribute(Attribute):
"""GatherV0Attribute class."""
batch_dims: int = field(default=0)
@OPS.register()
class GatherV0(Operation[GatherV0Attribute]):
"""GatherV0 class."""
TYPE = "Gather"
VERSION = "opset1"
ATTRIBUTE_FACTORY = GatherV0Attribute
def forward(self, inputs, indices, axis):
"""GatherV0's forward function."""
assert axis.numel() == 1
axis = axis.squeeze()
squeeze_axis = indices.dim() == 0
batch_dims = self.attrs.batch_dims
if batch_dims < 0:
batch_dims = indices.dim() + batch_dims
indices_shape = torch.tensor(indices.shape)
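        # torch.gather needs `index` to have the same rank as `input`, so the
        # block below unsqueezes `indices` up to inputs.dim() and then repeats it
        # along every non-gather dimension until the shapes match.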
if batch_dims < axis:
indices = indices.reshape(*indices_shape[:batch_dims], -1)
indices_shape = indices_shape[batch_dims:]
if indices.dim() != inputs.dim():
if indices.dim() != 0:
while indices.dim() - 1 < axis:
indices = indices.unsqueeze(batch_dims)
while indices.dim() < inputs.dim():
indices = indices.unsqueeze(-1)
repeat = []
for i, (j, k) in enumerate(zip(inputs.shape, indices.shape)):
if i == axis:
repeat.append(1)
else:
assert j % k == 0
repeat.append(j // k)
indices = indices.repeat(repeat)
output = torch.gather(input=inputs, dim=axis, index=indices.type(torch.int64))
if squeeze_axis:
output = output.squeeze(axis)
return output
@dataclass
class GatherV1Attribute(Attribute):
"""GatherV1Attribute class."""
pass # pylint: disable=unnecessary-pass
@OPS.register()
class GatherV1(Operation[GatherV1Attribute]):
"""GatherV1 class."""
TYPE = "Gather"
VERSION = "opset2"
ATTRIBUTE_FACTORY = GatherV1Attribute
def forward(self, inputs, indices, axis):
"""GatherV1's forward function."""
return torch.gather(input=inputs, dim=axis, index=indices)
@dataclass
class StridedSliceV1Attribute(Attribute):
"""StridedSliceV1Attribute class."""
begin_mask: List[int]
end_mask: List[int]
new_axis_mask: List[int] = field(default_factory=lambda: [0])
shrink_axis_mask: List[int] = field(default_factory=lambda: [0])
ellipsis_mask: List[int] = field(default_factory=lambda: [0])
@OPS.register()
class StridedSliceV1(Operation[StridedSliceV1Attribute]):
"""StridedSliceV1 class."""
TYPE = "StridedSlice"
VERSION = "opset1"
ATTRIBUTE_FACTORY = StridedSliceV1Attribute
def forward(self, inputs, begin, end, stride=None):
"""StridedSliceV1's forward function."""
if sum(self.attrs.ellipsis_mask) > 0:
raise NotImplementedError
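        # A set bit in begin_mask / end_mask means "ignore the supplied value for
        # that axis": begin falls back to 0 and end to the full dimension size,
        # analogous to numpy's `arr[:stop]` / `arr[start:]` slicing.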
for i, mask in enumerate(self.attrs.begin_mask):
if mask == 1:
begin[i] = 0
for i, mask in enumerate(self.attrs.end_mask):
if mask == 1:
end[i] = inputs.size(i)
if stride is None:
stride = torch.tensor([1 for _ in begin], dtype=begin.dtype)
output = inputs
for i, (b, e, stride_0) in enumerate(zip(begin, end, stride)):
length = inputs.size(i)
# begin index is inclusive
b = torch.clamp(b, -length, length - 1)
# end index is exclusive
e = torch.clamp(e, -length - 1, length)
if stride_0 > 0:
b = b + length if b < 0 else b
e = e + length if e < 0 else e
indices = torch.arange(b, e, stride_0, device=inputs.device)
else:
b = b - length if b >= 0 else b
e = e - length if e >= 0 else e
indices = torch.arange(b, e, stride_0, device=inputs.device)
indices += length
output = torch.index_select(output, i, indices)
for i, mask in enumerate(self.attrs.new_axis_mask[::-1]):
if mask == 1:
i = abs(i - len(self.attrs.new_axis_mask) + 1)
output = output.unsqueeze(i)
for i, mask in enumerate(self.attrs.shrink_axis_mask[::-1]):
if mask == 1:
                i = abs(i - len(self.attrs.shrink_axis_mask) + 1)
if output.size(i) != 1:
raise NotImplementedError
output = output.squeeze(i)
return output
@dataclass
class SplitV1Attribute(Attribute):
"""SplitV1Attribute class."""
num_splits: int
@OPS.register()
class SplitV1(Operation[SplitV1Attribute]):
"""SplitV1 class."""
TYPE = "Split"
VERSION = "opset1"
ATTRIBUTE_FACTORY = SplitV1Attribute
def forward(self, inputs, axis):
"""SplitV1's forward function."""
split_size = inputs.shape[axis] // self.attrs.num_splits
return torch.split(tensor=inputs, split_size_or_sections=split_size, dim=axis)
@dataclass
class VariadicSplitV1Attribute(Attribute):
"""VariadicSplitV1Attribute class."""
pass # pylint: disable=unnecessary-pass
@OPS.register()
class VariadicSplitV1(Operation[VariadicSplitV1Attribute]):
"""VariadicSplitV1 class."""
TYPE = "VariadicSplit"
VERSION = "opset1"
ATTRIBUTE_FACTORY = VariadicSplitV1Attribute
def forward(self, inputs, axis, split_lengths):
"""VariadicSplitV1's forward function."""
idx = [i for i, j in enumerate(split_lengths) if j == -1]
if idx:
assert len(idx) == 1
idx = idx[0]
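            # A single -1 entry means "take whatever is left". The -1 is still part
            # of sum(split_lengths) here, so subtracting the sum and one more gives
            # size(axis) minus the other, explicitly given lengths.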
split_lengths[idx] = inputs.size(axis) - sum(split_lengths) - 1
assert inputs.size(axis) == sum(split_lengths)
outputs = []
start_idx = 0
for length in split_lengths:
outputs.append(
torch.index_select(
inputs,
axis,
torch.arange(start_idx, start_idx + length, device=inputs.device),
)
)
start_idx += length
return tuple(outputs)
@dataclass
class ShuffleChannelsV0Attribute(Attribute):
"""ShuffleChannelsV0Attribute class."""
axis: int = field(default=1)
group: int = field(default=1)
@OPS.register()
class ShuffleChannelsV0(Operation[ShuffleChannelsV0Attribute]):
"""ShuffleChannelsV0 class."""
TYPE = "ShuffleChannels"
VERSION = "opset1"
ATTRIBUTE_FACTORY = ShuffleChannelsV0Attribute
def forward(self, inputs):
"""ShuffleChannelsV0's forward function."""
        # n, c, h, w = inputs.shape
assert inputs.dim() == 4
origin_shape = inputs.shape
origin_dim = inputs.dim()
assert origin_shape[self.attrs.axis] % self.attrs.group == 0
axis = self.attrs.axis
axis = axis if axis >= 0 else axis + inputs.dim()
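        # Channel-shuffle via reshape -> transpose -> reshape: fold the target axis
        # into (group, channels_per_group), swap those two axes, then flatten back.
        # Illustrative example (not from the source): (N, 4, H, W) with group=2 is
        # viewed as (N, 2, 2, H*W), the two middle axes are swapped, and the result
        # is reshaped back to (N, 4, H, W) with interleaved channels.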
target_shape = [
0,
self.attrs.group,
int(origin_shape[axis] / self.attrs.group),
0,
]
if axis == 0:
target_shape[0] = 1
target_shape[-1] = math.prod([origin_shape[i] for i in range(axis + 1, origin_dim)])
elif axis == inputs.dim() - 1:
target_shape[0] = math.prod([origin_shape[i] for i in range(0, axis)])
target_shape[-1] = 1
else:
target_shape[0] = math.prod([origin_shape[i] for i in range(0, axis)])
target_shape[-1] = math.prod([origin_shape[i] for i in range(axis + 1, origin_dim)])
output = inputs.reshape(target_shape)
output = output.permute([0, 2, 1, 3])
output = output.reshape(origin_shape)
return output
@dataclass
class BroadcastV3Attribute(Attribute):
"""BroadcastV3Attribute class."""
mode: str = field(default="numpy")
def __post_init__(self):
"""BroadcastV3Attribute's post-init function."""
super().__post_init__()
valid_mode = ["numpy", "explicit", "bidirectional"]
if self.mode not in valid_mode:
raise ValueError(f"Invalid mode {self.mode}. " f"It must be one of {valid_mode}.")
@OPS.register()
class BroadcastV3(Operation[BroadcastV3Attribute]):
"""BroadcastV3 class."""
TYPE = "Broadcast"
VERSION = "opset1"
ATTRIBUTE_FACTORY = BroadcastV3Attribute
def forward(self, inputs, target_shape, axes_mapping=None):
"""BroadcastV3's forward function."""
if self.attrs.mode == "numpy":
return inputs.expand(*target_shape)
if self.attrs.mode == "bidirectional":
return torch.ones(*target_shape, device=inputs.device) * inputs
assert axes_mapping is not None
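        # "explicit" mode: axes_mapping lists the target axes that the input
        # dimensions occupy; singleton axes are inserted for every unmapped target
        # axis so that expand() can broadcast the remaining dimensions.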
prev = -1
for axes in axes_mapping:
prev += 1
while axes - prev > 0:
inputs = inputs.unsqueeze(axes - 1)
prev += 1
while inputs.dim() < len(target_shape):
inputs = inputs.unsqueeze(-1)
return inputs.expand(*target_shape)
@dataclass
class ScatterNDUpdateV3Attribute(Attribute):
"""ScatterNDUpdateV3Attribute class."""
pass # pylint: disable=unnecessary-pass
@OPS.register()
class ScatterNDUpdateV3(Operation[ScatterNDUpdateV3Attribute]):
"""ScatterNDUpdateV3 class."""
TYPE = "ScatterNDUpdate"
VERSION = "opset1"
ATTRIBUTE_FACTORY = ScatterNDUpdateV3Attribute
def forward(self, inputs, indicies, updates):
"""ScatterNDUpdateV3's forward function."""
# TODO: need to verify
if updates.numel() == 1:
raise NotImplementedError
# FIXME: hard-coded
last_dim = indicies.shape[-1]
assert last_dim == 2
assert indicies[..., -2].sum() == 0
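        # Only the restricted case of two-component indices of the form [0, j]
        # (updates along dim 1, batch index fixed to 0) is handled; the index is
        # tiled over the trailing dimensions so torch.scatter can consume it.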
index = indicies[..., -1]
for i in inputs.shape[indicies.shape[-1] :]:
index = index.unsqueeze(-1).tile((i,))
output = torch.scatter(inputs, 1, index, updates)
return output
@dataclass
class ScatterUpdateV3Attribute(Attribute):
"""ScatterUpdateV3Attribute class."""
pass # pylint: disable=unnecessary-pass
@OPS.register()
class ScatterUpdateV3(Operation[ScatterUpdateV3Attribute]):
"""ScatterUpdateV3 class."""
TYPE = "ScatterUpdate"
VERSION = "opset1"
ATTRIBUTE_FACTORY = ScatterUpdateV3Attribute
def forward(self, inputs, indicies, updates, axis):
"""ScatterUpdateV3's forward function."""
# TODO: need to verify
axis = axis.item()
if inputs.dtype != updates.dtype:
updates = updates.type(inputs.dtype)
        if indicies.dim() == 0:
            # torch.scatter requires the index tensor to match the input rank,
            # so the scalar-index case is handled separately.
            assert axis == 0
            output = inputs
            output[indicies] = updates
            return output
        output = torch.scatter(inputs, axis, indicies, updates)
return output
@dataclass
class TileV0Attribute(Attribute):
"""TileV0Attribute class."""
pass # pylint: disable=unnecessary-pass
@OPS.register()
class TileV0(Operation[TileV0Attribute]):
"""TileV0 class."""
TYPE = "Tile"
VERSION = "opset1"
ATTRIBUTE_FACTORY = TileV0Attribute
def forward(self, inputs, repeats):
"""TileV0's forward function."""
return torch.tile(inputs, repeats.tolist())
def get_torch_padding(pads_begin, pads_end, auto_pad, input_size, weight_size, stride, dilation=None):
"""Getter function for torch padding."""
if dilation is None:
dilation = [1 for _ in input_size]
if auto_pad == "valid":
return 0
if auto_pad in ("same_upper", "same_lower"):
assert len(set(dilation)) == 1 and dilation[0] == 1
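        # SAME padding: pad each spatial dim so the output size is ceil(input / stride).
        # Illustrative example (not from the source): input=5, kernel=3, stride=2 gives
        # out=3 and padding_needed=2, split as 1/1; when padding_needed is odd,
        # "same_upper" puts the extra element at the end and "same_lower" at the start.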
pads_begin = []
pads_end = []
for input_size_, weight_size_, stride_, _ in zip(input_size, weight_size, stride, dilation):
out_size = math.ceil(input_size_ / stride_)
padding_needed = max(0, (out_size - 1) * stride_ + weight_size_ - input_size_)
padding_lhs = int(padding_needed / 2)
padding_rhs = padding_needed - padding_lhs
pads_begin.append(padding_lhs if auto_pad == "same_upper" else padding_rhs)
pads_end.append(padding_rhs if auto_pad == "same_upper" else padding_lhs)
pad = PadV1.get_torch_pad_dim(pads_begin, pads_end)
return partial(F.pad, pad=pad, mode="constant", value=0)
if auto_pad == "explicit":
pad = PadV1.get_torch_pad_dim(pads_begin, pads_end)
return partial(F.pad, pad=pad, mode="constant", value=0)
raise NotImplementedError