Source code for otx.core.ov.ops.image_processings

"""Image Processings-related code for otx.core.ov.ops."""
# Copyright (C) 2023 Intel Corporation
#
# SPDX-License-Identifier: Apache-2.0

from dataclasses import dataclass, field
from typing import List

import numpy as np
from torch.nn import functional as F

from .builder import OPS
from .movements import PadV1
from .op import Attribute, Operation

# pylint: disable=too-many-instance-attributes, too-many-branches


@dataclass
class InterpolateV4Attribute(Attribute):
    """InterpolateV4Attribute class."""

    mode: str
    shape_calculation_mode: str
    coordinate_transformation_mode: str = field(default="half_pixel")
    nearest_mode: str = field(default="round_prefer_floor")
    antialias: bool = field(default=False)
    pads_begin: List[int] = field(default_factory=lambda: [0])
    pads_end: List[int] = field(default_factory=lambda: [0])
    cube_coeff: float = field(default=-0.75)

    def __post_init__(self):
        """InterpolateV4Attribute's post-init function."""
        super().__post_init__()
        valid_mode = ["nearest", "linear", "linear_onnx", "cubic"]
        if self.mode not in valid_mode:
            raise ValueError(f"Invalid mode {self.mode}. " f"It must be one of {valid_mode}.")
        valid_shape_calculation_mode = ["sizes", "scales"]
        if self.shape_calculation_mode not in valid_shape_calculation_mode:
            raise ValueError(
                f"Invalid shape_calculation_mode {self.shape_calculation_mode}. "
                f"It must be one of {valid_shape_calculation_mode}."
            )
        valid_coordinate_transformation_mode = [
            "half_pixel",
            "pytorch_half_pixel",
            "asymmetric",
            "tf_half_pixel_for_nn",
            "align_corners",
        ]
        if self.coordinate_transformation_mode not in valid_coordinate_transformation_mode:
            raise ValueError(
                f"Invalid coordinate_transformation_mode {self.coordinate_transformation_mode}. "
                f"It must be one of {valid_coordinate_transformation_mode}."
            )
        valid_nearest_mode = [
            "round_prefer_floor",
            "round_prefer_ceil",
            "floor",
            "ceil",
            "simple",
        ]
        if self.nearest_mode not in valid_nearest_mode:
            raise ValueError(f"Invalid nearest_mode {self.nearest_mode}. " f"It must be one of {valid_nearest_mode}.")


[docs] @OPS.register() class InterpolateV4(Operation[InterpolateV4Attribute]): """InterpolateV4 class.""" TYPE = "Interpolate" VERSION = "opset1" ATTRIBUTE_FACTORY = InterpolateV4Attribute def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.pad = PadV1("tmp", shape=self.shape, pad_mode="constant")
[docs] def forward(self, inputs, sizes, scales, axes=None): """InterpolateV4's forward function.""" # TODO list # - handle 'linear_onnx' mode # - coordinate_transformation_mode # - nearest_mode # - cube_coeff # - antialias if axes is None: axes = list(range(inputs.dim())) else: axes = axes.detach().cpu().tolist() output = self.pad(inputs, self.attrs.pads_begin, self.attrs.pads_end, 0) mode = self.attrs.mode if mode in ("linear", "linear_onnx"): align_corners = False if output.dim() == 3: pass elif output.dim() == 4: mode = "bilinear" elif output.dim() == 5: mode = "trilinear" elif mode == "cubic": align_corners = False if output.dim() == 3: raise NotImplementedError if output.dim() == 4: mode = "bicubic" elif output.dim() == 5: raise NotImplementedError elif mode == "nearest": align_corners = None pass # pylint: disable=unnecessary-pass else: raise NotImplementedError if self.attrs.shape_calculation_mode == "sizes": sizes = sizes.detach().cpu().numpy() sizes = sizes[np.argsort(axes)].tolist() if output.dim() == len(sizes): sizes = sizes[2:] return F.interpolate( input=output, size=sizes, scale_factor=None, mode=mode, align_corners=align_corners, ) scales = scales.detach().cpu().numpy() scales = scales[np.argsort(axes)].tolist() if output.dim() == len(scales): scales = scales[2:] return F.interpolate( input=output, size=None, scale_factor=scales, mode=mode, align_corners=align_corners, )