# Source code for otx.core.ov.ops.convolutions

"""Convolutions-related module for otx.core.ov.ops."""
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from dataclasses import dataclass, field
from typing import Callable, List

from torch.nn import functional as F

from otx.core.ov.ops.builder import OPS
from otx.core.ov.ops.movements import get_torch_padding
from otx.core.ov.ops.op import Attribute, Operation

@dataclass
class ConvolutionV1Attribute(Attribute):
    """Attributes of the opset1 ``Convolution`` operation.

    Attributes:
        strides: Stride values, passed as ``stride`` to the torch convolution.
        pads_begin: Begin-side padding values, forwarded to ``get_torch_padding``.
        pads_end: End-side padding values, forwarded to ``get_torch_padding``.
        dilations: Dilation values, passed as ``dilation`` to the torch convolution.
        auto_pad: Padding mode; one of ``"explicit"``, ``"same_upper"``,
            ``"same_lower"`` or ``"valid"``. Defaults to ``"explicit"``.
    """

    strides: List[int]
    pads_begin: List[int]
    pads_end: List[int]
    dilations: List[int]
    auto_pad: str = field(default="explicit")

    def __post_init__(self):
        """Validate ``auto_pad`` once the dataclass fields are populated.

        Raises:
            ValueError: If ``auto_pad`` is not one of the supported modes.
        """
        # Overriding __post_init__ hides the parent's hook, so chain to it
        # explicitly when the base class defines one.
        parent_post_init = getattr(super(), "__post_init__", None)
        if callable(parent_post_init):
            parent_post_init()

        valid_auto_pad = ["explicit", "same_upper", "same_lower", "valid"]
        if self.auto_pad not in valid_auto_pad:
            raise ValueError(f"Invalid auto_pad {self.auto_pad}. " f"It must be one of {valid_auto_pad}.")

@OPS.register()
class ConvolutionV1(Operation[ConvolutionV1Attribute]):
    """Torch-backed implementation of the opset1 ``Convolution`` operation."""

    TYPE = "Convolution"
    VERSION = "opset1"
    ATTRIBUTE_FACTORY = ConvolutionV1Attribute

    def forward(self, inputs, weight):
        """Convolve ``inputs`` with ``weight`` without a bias term.

        Args:
            inputs: Input tensor. Its dimensions from index 2 onwards are handed
                to ``get_torch_padding`` as the input size.
            weight: Kernel tensor. Its rank selects the torch function:
                3 -> ``conv1d``, 4 -> ``conv2d``, 5 -> ``conv3d``.

        Returns:
            The output of the selected ``torch.nn.functional`` convolution.

        Raises:
            NotImplementedError: If ``weight`` has a rank other than 3, 4 or 5.
        """
        conv_by_rank = {3: F.conv1d, 4: F.conv2d, 5: F.conv3d}
        conv_fn = conv_by_rank.get(weight.dim())
        if conv_fn is None:
            raise NotImplementedError

        attrs = self.attrs
        input_size = list(inputs.shape[2:])
        kernel_size = list(weight.shape[2:])
        pad = get_torch_padding(
            attrs.pads_begin,
            attrs.pads_end,
            attrs.auto_pad,
            input_size,
            kernel_size,
            attrs.strides,
            attrs.dilations,
        )

        # When the helper returns a callable, padding is applied to the input
        # up front and the convolution itself then runs with no padding.
        if isinstance(pad, Callable):
            inputs = pad(input=inputs)
            pad = 0

        return conv_fn(
            input=inputs,
            weight=weight,
            bias=None,
            stride=attrs.strides,
            padding=pad,
            dilation=attrs.dilations,
        )
@dataclass
class GroupConvolutionV1Attribute(ConvolutionV1Attribute):
    """Attribute container used by ``GroupConvolutionV1``.

    Declares no fields of its own; everything comes from
    ``ConvolutionV1Attribute``.
    """
@OPS.register()
class GroupConvolutionV1(Operation[GroupConvolutionV1Attribute]):
    """Torch-backed implementation of the opset1 ``GroupConvolution`` operation."""

    TYPE = "GroupConvolution"
    VERSION = "opset1"
    ATTRIBUTE_FACTORY = GroupConvolutionV1Attribute

    def forward(self, inputs, weight):
        """Run a grouped convolution of ``inputs`` with ``weight``, without bias.

        Args:
            inputs: Input tensor. Its dimensions from index 2 onwards are handed
                to ``get_torch_padding`` as the input size.
            weight: Kernel tensor whose first axis is the group count and whose
                second axis is the per-group output dimension. Its rank selects
                the torch function: 4 -> ``conv1d``, 5 -> ``conv2d``,
                6 -> ``conv3d``.

        Returns:
            The output of the selected ``torch.nn.functional`` convolution,
            called with ``groups=weight.shape[0]``.

        Raises:
            NotImplementedError: If ``weight`` has a rank other than 4, 5 or 6.
        """
        weight_rank = weight.dim()
        if weight_rank == 4:
            func = F.conv1d
        elif weight_rank == 5:
            func = F.conv2d
        elif weight_rank == 6:
            func = F.conv3d
        else:
            raise NotImplementedError(
                f"GroupConvolution supports weights of rank 4, 5 or 6, got rank {weight_rank}."
            )

        n_groups = weight.shape[0]
        # Merge the group and output axes into one, which is the layout torch
        # expects for grouped kernels. reshape() is used rather than view() so
        # that non-contiguous weights are handled instead of raising; it only
        # copies when a view is impossible.
        weight = weight.reshape(-1, *weight.shape[2:])

        padding = get_torch_padding(
            self.attrs.pads_begin,
            self.attrs.pads_end,
            self.attrs.auto_pad,
            list(inputs.shape[2:]),
            list(weight.shape[2:]),
            self.attrs.strides,
            self.attrs.dilations,
        )
        # When the helper returns a callable, padding is applied to the input
        # up front and the convolution itself then runs with no padding.
        if isinstance(padding, Callable):
            inputs = padding(input=inputs)
            padding = 0

        return func(
            input=inputs,
            weight=weight,
            bias=None,
            stride=self.attrs.strides,
            padding=padding,
            dilation=self.attrs.dilations,
            groups=n_groups,
        )