
"""Normalization-related modules for otx.core.ov.ops."""
# Copyright (C) 2023 Intel Corporation
#
# SPDX-License-Identifier: Apache-2.0

from dataclasses import dataclass, field

import torch
from torch.nn import functional as F

from otx.core.ov.ops.builder import OPS
from otx.core.ov.ops.op import Attribute, Operation
from otx.core.ov.ops.poolings import AvgPoolV1


@dataclass
class BatchNormalizationV0Attribute(Attribute):
    """BatchNormalizationV0Attribute class."""

    epsilon: float
    max_init_iter: int = field(default=2)


@OPS.register()
class BatchNormalizationV0(Operation[BatchNormalizationV0Attribute]):
    """BatchNormalizationV0 class."""

    TYPE = "BatchNormInference"
    VERSION = "opset1"
    ATTRIBUTE_FACTORY = BatchNormalizationV0Attribute

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.register_buffer("_num_init_iter", torch.tensor(0))

    def forward(self, inputs, gamma, beta, mean, variance):
        """BatchNormalizationV0's forward function."""
        output = F.batch_norm(
            input=inputs,
            running_mean=mean,
            running_var=variance,
            weight=gamma,
            bias=beta,
            training=self.training,
            momentum=0.1,
            eps=self.attrs.epsilon,
        )

        if self.training and self._num_init_iter < self.attrs.max_init_iter:
            # no parameter updates during the adaptive phase
            with torch.no_grad():
                n_dims = inputs.dim() - 2
                gamma = gamma.unsqueeze(0)
                beta = beta.unsqueeze(0)
                for _ in range(n_dims):
                    gamma = gamma.unsqueeze(-1)
                    beta = beta.unsqueeze(-1)
                output = inputs * gamma + beta
                self._num_init_iter += 1
                if self._num_init_iter >= self.attrs.max_init_iter:
                    # Adapt weight & bias using the first batch statistics
                    # to undo normalization approximately
                    gamma.data = gamma.data * mean
                    beta.data = beta.data + (mean / (variance + self.attrs.epsilon))

        return output
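

# A minimal plain-PyTorch sketch (not part of the op set above) of what the
# adaptive phase in BatchNormalizationV0.forward computes: for the first
# ``max_init_iter`` training steps the normalized output is discarded and an
# affine pass-through is returned instead. The helper name is illustrative.
def _batch_norm_adaptive_phase_reference(inputs, gamma, beta):
    """Affine pass-through used during the adaptive phase (sketch)."""
    n_dims = inputs.dim() - 2
    # broadcast the per-channel parameters over batch and spatial dims
    gamma = gamma.reshape(1, -1, *([1] * n_dims))
    beta = beta.reshape(1, -1, *([1] * n_dims))
    return inputs * gamma + beta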


@dataclass
class LocalResponseNormalizationV0Attribute(Attribute):
    """LocalResponseNormalizationV0Attribute class."""

    alpha: float
    beta: float
    bias: float
    size: int


@OPS.register()
class LocalResponseNormalizationV0(Operation[LocalResponseNormalizationV0Attribute]):
    """LocalResponseNormalizationV0 class."""

    TYPE = "LRN"
    VERSION = "opset1"
    ATTRIBUTE_FACTORY = LocalResponseNormalizationV0Attribute

    def forward(self, inputs, axes):
        """LocalResponseNormalizationV0's forward function."""
        dim = inputs.dim()

        axes = axes.detach().cpu().tolist()
        assert all(ax >= 1 for ax in axes)

        axes = [ax - 1 for ax in axes]
        kernel = [1 for _ in range(dim - 1)]
        stride = [1 for _ in range(dim - 1)]
        pads_begin = [0 for _ in range(dim - 1)]
        pads_end = [0 for _ in range(dim - 1)]
        for axis in axes:
            kernel[axis] = self.attrs.size
            pads_begin[axis] = self.attrs.size // 2
            pads_end[axis] = (self.attrs.size - 1) // 2

        avg_attrs = {
            "auto_pad": "explicit",
            "strides": stride,
            "kernel": kernel,
            "pads_begin": pads_begin,
            "pads_end": pads_end,
            "exclude-pad": True,
            "shape": self.shape,
        }
        avg_pool = AvgPoolV1("temp", **avg_attrs)

        div = inputs.mul(inputs).unsqueeze(1)
        div = avg_pool(div)
        div = div.squeeze(1)
        div = div.mul(self.attrs.alpha).add(self.attrs.bias).pow(self.attrs.beta)
        output = inputs / div
        return output
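

# A plain-PyTorch sketch of the average-pooling trick used above, specialized
# to the common channel-axis case (``axes == [1]``). The helper name is
# illustrative, not part of the op set; exclude-pad averaging is emulated by
# dividing each window's sum of squares by the number of real (non-padded)
# elements in that window.
def _lrn_channel_reference(x, size, alpha, beta, bias):
    """LRN over the channel axis via an explicit sliding window (sketch)."""
    sq = x.movedim(1, -1).pow(2)  # channels last so pad/unfold act on them
    pad = (size // 2, (size - 1) // 2)
    win_sum = F.pad(sq, pad).unfold(-1, size, 1).sum(-1)
    win_cnt = F.pad(torch.ones_like(sq), pad).unfold(-1, size, 1).sum(-1)
    mean_sq = (win_sum / win_cnt).movedim(-1, 1)  # exclude-pad style average
    return x / (mean_sq * alpha + bias).pow(beta)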


@dataclass
class NormalizeL2V0Attribute(Attribute):
    """NormalizeL2V0Attribute class."""

    eps: float
    eps_mode: str

    def __post_init__(self):
        """NormalizeL2V0Attribute post-init function."""
        super().__post_init__()
        valid_eps_mode = ["add", "max"]
        if self.eps_mode not in valid_eps_mode:
            raise ValueError(f"Invalid eps_mode {self.eps_mode}. It must be one of {valid_eps_mode}.")


@OPS.register()
class NormalizeL2V0(Operation[NormalizeL2V0Attribute]):
    """NormalizeL2V0 class."""

    TYPE = "NormalizeL2"
    VERSION = "opset1"
    ATTRIBUTE_FACTORY = NormalizeL2V0Attribute

    def forward(self, inputs, axes):
        """NormalizeL2V0's forward function."""
        eps = self.attrs.eps
        eps_mode = self.attrs.eps_mode

        if isinstance(axes, torch.Tensor):
            axes = axes.detach().cpu().tolist()
        if not isinstance(axes, (list, tuple)):
            axes = [axes]

        # compute the norm in FP32 for numerical stability during FP16 training
        input_float = inputs.float()
        if axes:
            norm = input_float.pow(2).sum(axes, keepdim=True)
        else:
            norm = input_float

        if eps_mode == "add":
            norm = norm + eps
        elif eps_mode == "max":
            # max(norm, eps): floor the squared norm at eps
            norm = torch.clamp(norm, min=eps)

        return (input_float / norm.sqrt()).type_as(inputs)
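

# A plain-PyTorch sketch of the two eps modes handled above; the helper name
# is illustrative. With ``eps_mode == "add"`` the squared norm is offset by
# ``eps`` before the square root, while ``"max"`` floors it at ``eps``.
def _normalize_l2_reference(x, axes, eps, eps_mode):
    """L2-normalize ``x`` over ``axes`` (sketch)."""
    norm = x.float().pow(2).sum(axes, keepdim=True)
    norm = norm + eps if eps_mode == "add" else norm.clamp(min=eps)
    return (x.float() / norm.sqrt()).type_as(x)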


@dataclass
class MVNV6Attribute(Attribute):
    """MVNV6Attribute class."""

    normalize_variance: bool
    eps: float
    eps_mode: str

    def __post_init__(self):
        """MVNV6Attribute's post-init function."""
        super().__post_init__()
        valid_eps_mode = ["INSIDE_SQRT", "OUTSIDE_SQRT"]
        if self.eps_mode not in valid_eps_mode:
            raise ValueError(f"Invalid eps_mode {self.eps_mode}. It must be one of {valid_eps_mode}.")


@OPS.register()
class MVNV6(Operation[MVNV6Attribute]):
    """MVNV6 class."""

    TYPE = "MVN"
    VERSION = "opset1"
    ATTRIBUTE_FACTORY = MVNV6Attribute

    def forward(self, inputs, axes):
        """MVNV6's forward function."""
        output = inputs - inputs.mean(axes.tolist(), keepdim=True)
        if self.attrs.normalize_variance:
            eps_mode = self.attrs.eps_mode
            eps = self.attrs.eps
            var = torch.square(output).mean(axes.tolist(), keepdim=True)
            if eps_mode == "INSIDE_SQRT":
                output = output / torch.sqrt(var + eps)
            elif eps_mode == "OUTSIDE_SQRT":
                output = output / (torch.sqrt(var) + eps)
        return output
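

# A plain-PyTorch sketch of the MVN math above; the helper name is
# illustrative. The input is centered over ``axes`` and, when
# ``normalize_variance`` is set, divided by the standard deviation with
# ``eps`` applied inside or outside the square root.
def _mvn_reference(x, axes, eps, eps_mode, normalize_variance=True):
    """Mean-variance normalization over ``axes`` (sketch)."""
    centered = x - x.mean(axes, keepdim=True)
    if not normalize_variance:
        return centered
    var = centered.square().mean(axes, keepdim=True)
    denom = (var + eps).sqrt() if eps_mode == "INSIDE_SQRT" else var.sqrt() + eps
    return centered / denom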