# Source code for otx.core.ov.ops.poolings
"""Pooling-related modules for otx.core.ov.ops."""
# Copyright (C) 2023 Intel Corporation
#
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass, field
from typing import Callable, List
from torch.nn import functional as F
from otx.core.ov.ops.builder import OPS
from otx.core.ov.ops.movements import get_torch_padding
from otx.core.ov.ops.op import Attribute, Operation
# pylint: disable=too-many-instance-attributes
@dataclass
class MaxPoolV0Attribute(Attribute):
    """Attributes of the OpenVINO ``MaxPool`` operation.

    Attributes:
        strides: Pooling stride for each spatial axis.
        pads_begin: Padding added at the start of each spatial axis.
        pads_end: Padding added at the end of each spatial axis.
        kernel: Pooling window size for each spatial axis.
        rounding_type: Output-size rounding, ``"floor"`` (default) or ``"ceil"``.
        auto_pad: Padding mode, one of ``"explicit"`` (default), ``"same_upper"``,
            ``"same_lower"`` or ``"valid"``.
        dilations: Window dilation for each spatial axis. When empty, it is
            filled with 1 for every entry of ``strides``.
        index_element_type: Element type of the indices output, ``"i32"`` or
            ``"i64"`` (default).
        axis: Only ``0`` (default) is supported.
    """

    strides: List[int]
    pads_begin: List[int]
    pads_end: List[int]
    kernel: List[int]
    rounding_type: str = field(default="floor")
    auto_pad: str = field(default="explicit")
    dilations: List[int] = field(default_factory=list)
    index_element_type: str = field(default="i64")
    axis: int = field(default=0)

    def __post_init__(self):
        """Validate the attribute values and fill in default dilations.

        Raises:
            ValueError: If ``auto_pad``, ``rounding_type`` or
                ``index_element_type`` is not one of the supported values.
            NotImplementedError: If ``axis`` is not 0.
        """
        super().__post_init__()
        # Lower-case spelling as in the OpenVINO opset spec; the former
        # "same_Lower" entry wrongly rejected the valid "same_lower" mode.
        valid_auto_pad = ["explicit", "same_upper", "same_lower", "valid"]
        if self.auto_pad not in valid_auto_pad:
            raise ValueError(f"Invalid auto_pad {self.auto_pad}. " f"It must be one of {valid_auto_pad}.")
        valid_rounding_type = ["ceil", "floor"]
        if self.rounding_type not in valid_rounding_type:
            raise ValueError(
                f"Invalid rounding_type {self.rounding_type}. " f"It must be one of {valid_rounding_type}."
            )
        valid_index_element_type = ["i32", "i64"]
        if self.index_element_type not in valid_index_element_type:
            raise ValueError(
                f"Invalid index_element_type {self.index_element_type}. "
                f"It must be one of {valid_index_element_type}."
            )
        if not self.dilations:
            # No dilation given: use a dense window (dilation 1) on every axis.
            self.dilations = [1] * len(self.strides)
        if self.axis != 0:
            raise NotImplementedError(f"Only axis=0 is supported, got axis={self.axis}.")
@OPS.register()
class MaxPoolV0(Operation[MaxPoolV0Attribute]):
    """MaxPoolV0 class.

    Maps the OpenVINO ``MaxPool`` (opset8) operation onto
    ``torch.nn.functional.max_pool1d/2d/3d``.
    """

    TYPE = "MaxPool"
    VERSION = "opset8"
    ATTRIBUTE_FACTORY = MaxPoolV0Attribute

    def forward(self, inputs):
        """MaxPoolV0's forward function.

        Args:
            inputs: Tensor of rank 3, 4 or 5. The rank selects the 1D/2D/3D
                pooling function, and ``inputs.shape[2:]`` is treated as the
                spatial size.

        Returns:
            The ``(values, indices)`` result of torch max pooling called with
            ``return_indices=True``.

        Raises:
            NotImplementedError: If ``inputs`` is not of rank 3, 4 or 5.
        """
        rank = inputs.dim()
        if rank == 3:
            func = F.max_pool1d
        elif rank == 4:
            func = F.max_pool2d
        elif rank == 5:
            func = F.max_pool3d
        else:
            raise NotImplementedError(f"MaxPool supports inputs of rank 3, 4 or 5, got rank {rank}.")

        padding = get_torch_padding(
            self.attrs.pads_begin,
            self.attrs.pads_end,
            self.attrs.auto_pad,
            list(inputs.shape[2:]),
            self.attrs.kernel,
            self.attrs.strides,
        )
        if isinstance(padding, Callable):
            # get_torch_padding returned a padding function instead of a value
            # torch can take directly: pre-pad the input and pool without padding.
            # NOTE(review): the returned indices then index into the pre-padded
            # tensor, and the fill value is decided by get_torch_padding; confirm
            # both match OpenVINO MaxPool semantics.
            inputs = padding(input=inputs)
            padding = 0
        return func(
            input=inputs,
            kernel_size=self.attrs.kernel,
            stride=self.attrs.strides,
            padding=padding,
            dilation=self.attrs.dilations,
            ceil_mode=self.attrs.rounding_type == "ceil",
            return_indices=True,
        )
@dataclass
class AvgPoolV1Attribute(Attribute):
    """Attributes of the OpenVINO ``AvgPool`` operation.

    Attributes:
        exclude_pad: Whether padded elements are excluded from the average.
        strides: Pooling stride for each spatial axis.
        pads_begin: Padding added at the start of each spatial axis.
        pads_end: Padding added at the end of each spatial axis.
        kernel: Pooling window size for each spatial axis.
        rounding_type: Output-size rounding, ``"floor"`` (default) or ``"ceil"``.
        auto_pad: Padding mode, one of ``"explicit"`` (default), ``"same_upper"``,
            ``"same_lower"`` or ``"valid"``.
    """

    exclude_pad: bool
    strides: List[int]
    pads_begin: List[int]
    pads_end: List[int]
    kernel: List[int]
    rounding_type: str = field(default="floor")
    auto_pad: str = field(default="explicit")

    def __post_init__(self):
        """Validate ``auto_pad`` and ``rounding_type``.

        Raises:
            ValueError: If ``auto_pad`` or ``rounding_type`` is not one of the
                supported values.
        """
        super().__post_init__()
        # Lower-case spelling as in the OpenVINO opset spec; the former
        # "same_Lower" entry wrongly rejected the valid "same_lower" mode.
        valid_auto_pad = ["explicit", "same_upper", "same_lower", "valid"]
        if self.auto_pad not in valid_auto_pad:
            raise ValueError(f"Invalid auto_pad {self.auto_pad}. " f"It must be one of {valid_auto_pad}.")
        valid_rounding_type = ["ceil", "floor"]
        if self.rounding_type not in valid_rounding_type:
            raise ValueError(
                f"Invalid rounding_type {self.rounding_type}. " f"It must be one of {valid_rounding_type}."
            )
@OPS.register()
class AvgPoolV1(Operation[AvgPoolV1Attribute]):
    """AvgPoolV1 class.

    Maps the OpenVINO ``AvgPool`` (opset1) operation onto
    ``torch.nn.functional.avg_pool1d/2d/3d``.
    """

    TYPE = "AvgPool"
    VERSION = "opset1"
    ATTRIBUTE_FACTORY = AvgPoolV1Attribute

    def __init__(self, *args, **kwargs):
        """Rename a hyphenated ``exclude-pad`` keyword before delegating.

        ``exclude-pad`` is not a valid Python identifier, so it is renamed to
        the ``exclude_pad`` field of ``AvgPoolV1Attribute``. Presumably the
        hyphenated spelling comes from the OpenVINO IR attribute name.
        """
        if "exclude-pad" in kwargs:
            kwargs["exclude_pad"] = kwargs.pop("exclude-pad")
        super().__init__(*args, **kwargs)

    def forward(self, inputs):
        """AvgPoolV1's forward function.

        Args:
            inputs: Tensor of rank 3, 4 or 5. The rank selects the 1D/2D/3D
                pooling function, and ``inputs.shape[2:]`` is treated as the
                spatial size.

        Returns:
            The average-pooled tensor.

        Raises:
            NotImplementedError: If ``inputs`` is not of rank 3, 4 or 5.
        """
        rank = inputs.dim()
        if rank == 3:
            func = F.avg_pool1d
        elif rank == 4:
            func = F.avg_pool2d
        elif rank == 5:
            func = F.avg_pool3d
        else:
            raise NotImplementedError(f"AvgPool supports inputs of rank 3, 4 or 5, got rank {rank}.")

        padding = get_torch_padding(
            self.attrs.pads_begin,
            self.attrs.pads_end,
            self.attrs.auto_pad,
            list(inputs.shape[2:]),
            self.attrs.kernel,
            self.attrs.strides,
        )
        if isinstance(padding, Callable):
            # get_torch_padding returned a padding function instead of a value
            # torch can take directly: pre-pad the input and pool without padding.
            # NOTE(review): torch then sees padding=0, so count_include_pad has no
            # effect and the pre-padded elements are always counted in the average,
            # even when exclude_pad is True. Confirm against OpenVINO AvgPool.
            inputs = padding(input=inputs)
            padding = 0
        return func(
            input=inputs,
            kernel_size=self.attrs.kernel,
            stride=self.attrs.strides,
            padding=padding,
            ceil_mode=self.attrs.rounding_type == "ceil",
            count_include_pad=not self.attrs.exclude_pad,
        )