# Source code for otx.algo.classification.backbones.efficientnet

# Copyright (C) 2023-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
"""EfficientNet Module."""

from __future__ import annotations

import math
from pathlib import Path
from typing import Any, Callable, ClassVar, Literal

import torch
from pytorchcv.models.model_store import download_model
from torch import nn
from torch.nn import functional, init

from otx.algo.modules.activation import Swish, build_activation_layer
from otx.algo.modules.conv_module import Conv2dModule
from otx.algo.modules.norm import build_norm_layer

PRETRAINED_ROOT = "https://github.com/osmr/imgclsmob/releases/download/v0.0.364/"
# Download URLs of the pretrained checkpoints hosted in the imgclsmob release above, keyed by model name.
pretrained_urls = {
    "efficientnet_b0": f"{PRETRAINED_ROOT}efficientnet_b0-0752-0e386130.pth.zip",
}


def conv1x1_block(
    in_channels: int,
    out_channels: int,
    stride: int | tuple[int, int] = 1,
    padding: int | tuple[int, int] = 0,
    groups: int = 1,
    bias: bool = False,
    use_bn: bool = True,
    bn_eps: float = 1e-5,
    activation: Callable[..., nn.Module] | None = nn.ReLU,
) -> Conv2dModule:
    """Build a 1x1 (pointwise) ``Conv2dModule`` with optional BatchNorm and activation.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        stride (int | tuple[int, int]): Convolution stride. Defaults to 1.
        padding (int | tuple[int, int]): Convolution padding. Defaults to 0.
        groups (int): Number of convolution groups. Defaults to 1.
        bias (bool): Whether the convolution has a bias term. Defaults to False.
        use_bn (bool): Whether to attach an ``nn.BatchNorm2d`` over ``out_channels``. Defaults to True.
        bn_eps (float): ``eps`` forwarded to the BatchNorm layer. Defaults to 1e-5.
        activation (Callable[..., nn.Module] | None): Activation layer type handed to
            ``build_activation_layer``. Defaults to ``nn.ReLU``.

    Returns:
        Conv2dModule: The assembled convolution block.
    """
    norm_layer = None
    if use_bn:
        norm_layer = build_norm_layer(nn.BatchNorm2d, num_features=out_channels, eps=bn_eps)
    activation_layer = build_activation_layer(activation)
    return Conv2dModule(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=1,
        stride=stride,
        padding=padding,
        groups=groups,
        bias=bias,
        normalization=norm_layer,
        activation=activation_layer,
    )


def conv3x3_block(
    in_channels: int,
    out_channels: int,
    stride: int | tuple[int, int] = 1,
    padding: int | tuple[int, int] = 1,
    dilation: int = 1,
    groups: int = 1,
    bias: bool = False,
    use_bn: bool = True,
    bn_eps: float = 1e-5,
    activation: Callable[..., nn.Module] | None = nn.ReLU,
) -> Conv2dModule:
    """Build a 3x3 ``Conv2dModule`` with optional BatchNorm and activation.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        stride (int | tuple[int, int]): Convolution stride. Defaults to 1.
        padding (int | tuple[int, int]): Convolution padding. Defaults to 1.
        dilation (int): Convolution dilation. Defaults to 1.
        groups (int): Number of convolution groups. Defaults to 1.
        bias (bool): Whether the convolution has a bias term. Defaults to False.
        use_bn (bool): Whether to attach an ``nn.BatchNorm2d`` over ``out_channels``. Defaults to True.
        bn_eps (float): ``eps`` forwarded to the BatchNorm layer. Defaults to 1e-5.
        activation (Callable[..., nn.Module] | None): Activation layer type handed to
            ``build_activation_layer``. Defaults to ``nn.ReLU``.

    Returns:
        Conv2dModule: The assembled convolution block.
    """
    norm_layer = None
    if use_bn:
        norm_layer = build_norm_layer(nn.BatchNorm2d, num_features=out_channels, eps=bn_eps)
    activation_layer = build_activation_layer(activation)
    return Conv2dModule(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=3,
        stride=stride,
        padding=padding,
        dilation=dilation,
        groups=groups,
        bias=bias,
        normalization=norm_layer,
        activation=activation_layer,
    )


def dwconv3x3_block(
    in_channels: int,
    out_channels: int,
    stride: int | tuple[int, int] = 1,
    padding: int | tuple[int, int] = 1,
    dilation: int = 1,
    bias: bool = False,
    use_bn: bool = True,
    bn_eps: float = 1e-5,
    activation: Callable[..., nn.Module] | None = nn.ReLU,
) -> Conv2dModule:
    """Build a 3x3 grouped ``Conv2dModule`` with ``groups=out_channels`` (depthwise when in == out).

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels; also used as the group count.
        stride (int | tuple[int, int]): Convolution stride. Defaults to 1.
        padding (int | tuple[int, int]): Convolution padding. Defaults to 1.
        dilation (int): Convolution dilation. Defaults to 1.
        bias (bool): Whether the convolution has a bias term. Defaults to False.
        use_bn (bool): Whether to attach an ``nn.BatchNorm2d`` over ``out_channels``. Defaults to True.
        bn_eps (float): ``eps`` forwarded to the BatchNorm layer. Defaults to 1e-5.
        activation (Callable[..., nn.Module] | None): Activation layer type handed to
            ``build_activation_layer``. Defaults to ``nn.ReLU``.

    Returns:
        Conv2dModule: The assembled convolution block.
    """
    norm_layer = None
    if use_bn:
        norm_layer = build_norm_layer(nn.BatchNorm2d, num_features=out_channels, eps=bn_eps)
    activation_layer = build_activation_layer(activation)
    return Conv2dModule(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=3,
        stride=stride,
        padding=padding,
        dilation=dilation,
        groups=out_channels,
        bias=bias,
        normalization=norm_layer,
        activation=activation_layer,
    )


def dwconv5x5_block(
    in_channels: int,
    out_channels: int,
    stride: int | tuple[int, int] = 1,
    padding: int | tuple[int, int] = 2,
    dilation: int = 1,
    bias: bool = False,
    use_bn: bool = True,
    bn_eps: float = 1e-5,
    activation: Callable[..., nn.Module] | None = nn.ReLU,
) -> Conv2dModule:
    """Build a 5x5 grouped ``Conv2dModule`` with ``groups=out_channels`` (depthwise when in == out).

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels; also used as the group count.
        stride (int | tuple[int, int]): Convolution stride. Defaults to 1.
        padding (int | tuple[int, int]): Convolution padding. Defaults to 2.
        dilation (int): Convolution dilation. Defaults to 1.
        bias (bool): Whether the convolution has a bias term. Defaults to False.
        use_bn (bool): Whether to attach an ``nn.BatchNorm2d`` over ``out_channels``. Defaults to True.
        bn_eps (float): ``eps`` forwarded to the BatchNorm layer. Defaults to 1e-5.
        activation (Callable[..., nn.Module] | None): Activation layer type handed to
            ``build_activation_layer``. Defaults to ``nn.ReLU``.

    Returns:
        Conv2dModule: The assembled convolution block.
    """
    norm_layer = None
    if use_bn:
        norm_layer = build_norm_layer(nn.BatchNorm2d, num_features=out_channels, eps=bn_eps)
    activation_layer = build_activation_layer(activation)
    return Conv2dModule(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=5,
        stride=stride,
        padding=padding,
        dilation=dilation,
        groups=out_channels,
        bias=bias,
        normalization=norm_layer,
        activation=activation_layer,
    )


def round_channels(channels: float, divisor: int = 8) -> int:
    """Round a (possibly width-scaled) channel count to a multiple of ``divisor``.

    For positive ``channels`` the value is rounded to the nearest multiple of ``divisor``
    (halves round up) and clamped to at least ``divisor``. If that result falls below 90%
    of ``channels``, one more ``divisor`` is added.

    Args:
        channels (float): Original number of channels (int or float).
        divisor (int): Alignment value. Defaults to 8.

    Returns:
        int: The aligned channel count.
    """
    nearest_multiple = int(channels + divisor / 2.0) // divisor * divisor
    aligned = nearest_multiple if nearest_multiple > divisor else divisor
    # Do not let rounding shrink the layer by more than 10%.
    return aligned + divisor if float(aligned) < 0.9 * channels else aligned


def calc_tf_padding(
    x: torch.Tensor,
    kernel_size: int | tuple[int, int],
    stride: int | tuple[int, int] = 1,
    dilation: int | tuple[int, int] = 1,
) -> tuple[int, int, int, int]:
    """Calculate TF "SAME"-like padding for a convolution applied to ``x``.

    The total padding per spatial axis makes a convolution with the given geometry produce
    ``ceil(size / stride)`` outputs. Any odd padding unit goes to the bottom/right side.

    Args:
        x (torch.Tensor): Input tensor; its last two dimensions are read as (height, width).
        kernel_size (int | tuple[int, int]): Convolution window size, scalar or (height, width).
        stride (int | tuple[int, int]): Strides of the convolution, scalar or (height, width). Defaults to 1.
        dilation (int | tuple[int, int]): Dilation of the convolution, scalar or (height, width). Defaults to 1.

    Returns:
        tuple[int, int, int, int]: ``(left, right, top, bottom)`` padding, i.e. the last-dimension-first
        order that ``torch.nn.functional.pad`` expects for the two spatial dimensions.
    """

    def _as_pair(value: int | tuple[int, int]) -> tuple[int, int]:
        # Broadcast a scalar hyper-parameter to (height, width).
        if isinstance(value, int):
            return value, value
        value_h, value_w = value
        return value_h, value_w

    def _same_pad(size: int, kernel: int, step: int, dil: int) -> int:
        # Total padding needed so the conv yields ceil(size / step) output positions.
        out_size = math.ceil(size / step)
        return max((out_size - 1) * step + (kernel - 1) * dil + 1 - size, 0)

    kernel_h, kernel_w = _as_pair(kernel_size)
    stride_h, stride_w = _as_pair(stride)
    dilation_h, dilation_w = _as_pair(dilation)

    height, width = x.size()[-2:]
    pad_h = _same_pad(height, kernel_h, stride_h, dilation_h)
    pad_w = _same_pad(width, kernel_w, stride_w, dilation_w)
    # functional.pad pads the LAST dimension first: (left, right, top, bottom).
    # The previous (top, bottom, left, right) order swapped axes for non-square inputs.
    return pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2


class SEBlock(nn.Module):
    """Squeeze-and-Excitation block from 'Squeeze-and-Excitation Networks' (https://arxiv.org/abs/1709.01507).

    The input is average-pooled to 1x1 and squeezed to ``mid_channels``. It is then expanded back to
    ``channels`` by two 1x1 convolutions (or two linear layers), and the resulting per-channel gate
    multiplies the input.

    Args:
        channels (int): Number of channels.
        reduction (int): Squeeze reduction value, used when ``mid_channels`` is None. Defaults to 16.
        mid_channels (int | None): Number of middle channels. Defaults to None.
        round_mid (bool): Whether to round the derived middle channel number (make divisible by 8).
            Defaults to False.
        use_conv (bool): Whether to use 1x1 convolutional layers instead of fully-connected ones.
            Defaults to True.
        mid_activation (Callable[..., nn.Module]): Activation layer type after the first layer.
            Defaults to ``nn.ReLU``.
        out_activation (Callable[..., nn.Module]): Activation layer type after the last layer.
            Defaults to ``nn.Sigmoid``.
    """

    def __init__(
        self,
        channels: int,
        reduction: int = 16,
        mid_channels: int | None = None,
        round_mid: bool = False,
        use_conv: bool = True,
        mid_activation: Callable[..., nn.Module] = nn.ReLU,
        out_activation: Callable[..., nn.Module] = nn.Sigmoid,
    ):
        super().__init__()
        self.use_conv = use_conv
        if mid_channels is None:
            if round_mid:
                mid_channels = round_channels(float(channels) / reduction)
            else:
                mid_channels = channels // reduction

        def _projection(n_in: int, n_out: int) -> nn.Module:
            # A 1x1 conv works on the pooled (N, C, 1, 1) map; Linear needs the flattened (N, C) view.
            if use_conv:
                return nn.Conv2d(
                    in_channels=n_in,
                    out_channels=n_out,
                    kernel_size=1,
                    stride=1,
                    groups=1,
                    bias=True,
                )
            return nn.Linear(in_features=n_in, out_features=n_out)

        # Submodule names/creation order are kept stable: they define the state_dict keys.
        self.pool = nn.AdaptiveAvgPool2d(output_size=1)
        if use_conv:
            self.conv1 = _projection(channels, mid_channels)
        else:
            self.fc1 = _projection(channels, mid_channels)
        self.activ = mid_activation()
        if use_conv:
            self.conv2 = _projection(mid_channels, channels)
        else:
            self.fc2 = _projection(mid_channels, channels)
        self.sigmoid = out_activation()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Rescale ``x`` channel-wise by the learned gate."""
        pooled = self.pool(x)
        if self.use_conv:
            gate = self.sigmoid(self.conv2(self.activ(self.conv1(pooled))))
        else:
            flat = pooled.view(x.size(0), -1)
            gate = self.sigmoid(self.fc2(self.activ(self.fc1(flat))))
            gate = gate.unsqueeze(2).unsqueeze(3)
        return x * gate


class EffiDwsConvUnit(nn.Module):
    """EfficientNet specific depthwise separable conv block/unit with BatchNorms and activations at each conv.

    The unit applies a 3x3 depthwise conv block, a Squeeze-and-Excitation block and a 1x1 pointwise
    conv block without activation. The input is added back when channels and resolution are preserved.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        stride (int | tuple[int, int]): Stride of the depthwise (3x3) convolution.
        bn_eps (float): Small float added to variance in Batch norm.
        activation (Callable[..., nn.Module]): Activation layer module.
        tf_mode (bool): Whether to use TF-like "same" padding computed at runtime.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int | tuple[int, int],
        bn_eps: float,
        activation: Callable[..., nn.Module],
        tf_mode: bool,
    ):
        super().__init__()
        self.tf_mode = tf_mode
        self.stride = stride
        # Identity shortcut only when the unit keeps both channel count and resolution.
        self.residual = (in_channels == out_channels) and (stride == 1)

        self.dw_conv = dwconv3x3_block(
            in_channels=in_channels,
            out_channels=in_channels,
            # Previously omitted, so a non-unit stride was silently ignored.
            stride=stride,
            # In TF mode padding is applied dynamically in ``forward``.
            padding=(0 if tf_mode else 1),
            bn_eps=bn_eps,
            activation=activation,
        )
        self.se = SEBlock(channels=in_channels, reduction=4, mid_activation=activation)
        self.pw_conv = conv1x1_block(
            in_channels=in_channels,
            out_channels=out_channels,
            bn_eps=bn_eps,
            activation=None,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply depthwise conv -> SE -> pointwise conv, plus the shortcut when enabled."""
        identity = x
        if self.tf_mode:
            x = functional.pad(x, pad=calc_tf_padding(x, kernel_size=3, stride=self.stride))
        x = self.dw_conv(x)
        x = self.se(x)
        x = self.pw_conv(x)
        if self.residual:
            x = x + identity
        return x


class EffiInvResUnit(nn.Module):
    """EfficientNet inverted residual unit.

    Pipeline: a 1x1 expansion conv block, then a ``kernel_size`` depthwise conv block, then an
    optional SE block, then a 1x1 projection conv block without activation. An identity shortcut
    is added when channels and resolution are preserved.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        kernel_size (int): Depthwise convolution window size; must be 3 or 5.
        stride (int | tuple[int, int]): Strides of the second (depthwise) convolution layer.
        exp_factor (int): Factor for expansion of channels.
        se_factor (int): SE reduction factor for each unit; the SE block is skipped when it is not positive.
        bn_eps (float): Small float added to variance in Batch norm.
        activation (Callable[..., nn.Module]): Activation layer module.
        tf_mode (bool): Whether to use TF-like "same" padding computed at runtime.

    Raises:
        ValueError: If ``kernel_size`` is neither 3 nor 5.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int | tuple[int, int],
        exp_factor: int,
        se_factor: int,
        bn_eps: float,
        activation: Callable[..., nn.Module],
        tf_mode: bool,
    ):
        super().__init__()
        # Only 3x3 and 5x5 depthwise builders exist. Any other size used to silently build a 5x5
        # conv padded with ``kernel_size // 2``, which gives wrong output shapes.
        if kernel_size not in (3, 5):
            msg = f"EffiInvResUnit supports kernel_size 3 or 5, got {kernel_size!r}."
            raise ValueError(msg)

        self.kernel_size = kernel_size
        self.stride = stride
        self.tf_mode = tf_mode
        self.residual = (in_channels == out_channels) and (stride == 1)
        self.use_se = se_factor > 0
        mid_channels = in_channels * exp_factor
        dwconv_block_fn = dwconv3x3_block if kernel_size == 3 else dwconv5x5_block

        self.conv1 = conv1x1_block(
            in_channels=in_channels,
            out_channels=mid_channels,
            bn_eps=bn_eps,
            activation=activation,
        )
        self.conv2 = dwconv_block_fn(
            in_channels=mid_channels,
            out_channels=mid_channels,
            stride=stride,
            # In TF mode padding is applied dynamically in ``forward``.
            padding=(0 if tf_mode else kernel_size // 2),
            bn_eps=bn_eps,
            activation=activation,
        )
        if self.use_se:
            self.se = SEBlock(
                channels=mid_channels,
                reduction=(exp_factor * se_factor),
                mid_activation=activation,
            )
        self.conv3 = conv1x1_block(
            in_channels=mid_channels,
            out_channels=out_channels,
            bn_eps=bn_eps,
            activation=None,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply expansion -> depthwise conv -> (SE) -> projection, plus the shortcut when enabled."""
        identity = x
        x = self.conv1(x)
        if self.tf_mode:
            x = functional.pad(
                x,
                pad=calc_tf_padding(x, kernel_size=self.kernel_size, stride=self.stride),
            )
        x = self.conv2(x)
        if self.use_se:
            x = self.se(x)
        x = self.conv3(x)
        if self.residual:
            x = x + identity
        return x


class EffiInitBlock(nn.Module):
    """EfficientNet stem: a single stride-2 3x3 convolution block.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        bn_eps (float): Small float added to variance in Batch norm.
        activation (Callable[..., nn.Module] | None): Activation layer module.
        tf_mode (bool): Whether to use TF-like "same" padding computed at runtime
            instead of a static padding of 1.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        bn_eps: float,
        activation: Callable[..., nn.Module] | None,
        tf_mode: bool,
    ):
        super().__init__()
        self.tf_mode = tf_mode
        # TF mode pads in ``forward``; the conv itself then runs unpadded.
        static_padding = 0 if tf_mode else 1
        self.conv = conv3x3_block(
            in_channels=in_channels,
            out_channels=out_channels,
            stride=2,
            padding=static_padding,
            bn_eps=bn_eps,
            activation=activation,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the stem convolution (with dynamic padding in TF mode)."""
        if not self.tf_mode:
            return self.conv(x)
        padded = functional.pad(x, pad=calc_tf_padding(x, kernel_size=3, stride=2))
        return self.conv(padded)


class EfficientNet(nn.Module):
    """EfficientNet feature extractor.

    ``self.features`` is an ``nn.Sequential`` made of:

    - ``init_block``;
    - one ``stage{i}`` per entry of ``channels``, whose units are named ``unit{j}``;
    - a 1x1 ``final_block``.

    ``forward`` returns the final feature map wrapped in a one-element tuple.

    Args:
        channels (list[list[int]]): Output channels of each unit, grouped per stage.
        init_block_channels (int): Output channels of the initial (stem) block.
        final_block_channels (int): Output channels of the final block of the feature extractor.
        kernel_sizes (list[list[int]]): Depthwise kernel size of each unit, grouped per stage. Not read for
            the first stage, whose units are always 3x3 depthwise-separable units.
        strides_per_stage (list[int]): Stride of the first unit of each stage.
        expansion_factors (list[list[int]]): Channel expansion factor of each unit, grouped per stage.
            Not read for the first stage.
        tf_mode (bool): Whether to use TF-like "same" padding. Defaults to False.
        bn_eps (float): Small float added to variance in Batch norm. Defaults to 1e-5.
        in_channels (int): Number of input channels. Defaults to 3.
        in_size (tuple[int, int]): Spatial size of the expected input image, stored as ``self.in_size``.
            Defaults to (224, 224).
        pooling_type (str | None): Pooling type, stored as ``self.pooling_type`` (not consumed by this class).
            Defaults to "avg".
        bn_eval (bool): Stored as ``self.bn_eval`` (not consumed by this class). Defaults to False.
        bn_frozen (bool): Stored as ``self.bn_frozen`` (not consumed by this class). Defaults to False.
        instance_norm_first (bool): Whether to apply an affine ``nn.InstanceNorm2d`` to the input first.
            Defaults to False.
        **kwargs: Forwarded to ``nn.Module.__init__``.
    """

    def __init__(
        self,
        channels: list[list[int]],
        init_block_channels: int,
        final_block_channels: int,
        kernel_sizes: list[list[int]],
        strides_per_stage: list[int],
        expansion_factors: list[list[int]],
        tf_mode: bool = False,
        bn_eps: float = 1e-5,
        in_channels: int = 3,
        in_size: tuple[int, int] = (224, 224),
        pooling_type: str | None = "avg",
        bn_eval: bool = False,
        bn_frozen: bool = False,
        instance_norm_first: bool = False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.num_classes = 1000
        self.in_size = in_size
        # Sized by the real input channel count (was hard-coded to 3).
        self.input_IN = nn.InstanceNorm2d(in_channels, affine=True) if instance_norm_first else None
        self.bn_eval = bn_eval
        self.bn_frozen = bn_frozen
        self.pooling_type = pooling_type
        self.num_features = self.num_head_features = final_block_channels
        activation = Swish

        # Submodule names below define the state_dict keys used for pretrained weights.
        self.features = nn.Sequential()
        self.features.add_module(
            "init_block",
            EffiInitBlock(
                in_channels=in_channels,
                out_channels=init_block_channels,
                bn_eps=bn_eps,
                activation=activation,
                tf_mode=tf_mode,
            ),
        )
        in_channels = init_block_channels
        for i, channels_per_stage in enumerate(channels):
            stage = nn.Sequential()
            for j, out_channels in enumerate(channels_per_stage):
                # Only the first unit of a stage may change resolution.
                stride = strides_per_stage[i] if j == 0 else 1
                unit: nn.Module
                if i == 0:
                    # First stage: depthwise-separable units without channel expansion.
                    unit = EffiDwsConvUnit(
                        in_channels=in_channels,
                        out_channels=out_channels,
                        stride=stride,
                        bn_eps=bn_eps,
                        activation=activation,
                        tf_mode=tf_mode,
                    )
                else:
                    unit = EffiInvResUnit(
                        in_channels=in_channels,
                        out_channels=out_channels,
                        kernel_size=kernel_sizes[i][j],
                        stride=stride,
                        exp_factor=expansion_factors[i][j],
                        se_factor=4,
                        bn_eps=bn_eps,
                        activation=activation,
                        tf_mode=tf_mode,
                    )
                stage.add_module(f"unit{j + 1}", unit)
                in_channels = out_channels
            self.features.add_module(f"stage{i + 1}", stage)
        self.features.add_module(
            "final_block",
            conv1x1_block(
                in_channels=in_channels,
                out_channels=final_block_channels,
                bn_eps=bn_eps,
                activation=activation,
            ),
        )
        self._init_params()

    def _init_params(self) -> None:
        """Kaiming-uniform init for every ``nn.Conv2d`` weight; zero its bias when present."""
        # ``named_modules()`` yields (name, module) pairs, so the previous isinstance check never
        # matched and no initialization happened; iterate the modules themselves.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                init.kaiming_uniform_(module.weight)
                if module.bias is not None:
                    init.constant_(module.bias, 0)

    def forward(
        self,
        x: torch.Tensor,
        **kwargs,
    ) -> tuple | list[torch.Tensor] | torch.Tensor:
        """Extract features.

        Args:
            x (torch.Tensor): Input batch.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            tuple[torch.Tensor]: One-element tuple holding the output of ``self.features``.
        """
        if self.input_IN is not None:
            x = self.input_IN(x)

        y = self.features(x)
        return (y,)


# Supported EfficientNet variants; mirrors the keys of the backbone configuration table.
EFFICIENTNET_VERSION = Literal[
    "b0",
    "b1",
    "b2",
    "b3",
    "b4",
    "b5",
    "b6",
    "b7",
    "b8",
]


class EfficientNetBackbone:
    """EfficientNetBackbone class represents the backbone architecture of EfficientNet models.

    Calling ``EfficientNetBackbone(version, ...)`` does not produce an ``EfficientNetBackbone``
    instance: ``__new__`` builds and returns a configured :class:`EfficientNet` module.

    Attributes:
        EFFICIENTNET_CFG (ClassVar[dict[str, Any]]): Per-version input size, depth factor and width factor.
        init_block_channels (ClassVar[int]): Number of channels in the initial block (before width scaling).
        layers (ClassVar[list[int]]): Number of units in each layer group (before depth scaling).
        downsample (ClassVar[list[int]]): For each layer group, non-zero if it starts a new stage,
            zero if its units are appended to the previous stage.
        channels_per_layers (ClassVar[list[int]]): Output channels of each layer group (before width scaling).
        expansion_factors_per_layers (ClassVar[list[int]]): Expansion factor of each layer group.
        kernel_sizes_per_layers (ClassVar[list[int]]): Depthwise kernel size of each layer group.
        strides_per_stage (ClassVar[list[int]]): Stride of each layer group; only the value of the
            first group of each stage is used.
        final_block_channels (ClassVar[int]): Number of channels in the final block (width-scaled only
            when the width factor exceeds 1.0).
    """

    EFFICIENTNET_CFG: ClassVar[dict[str, Any]] = {
        "b0": {
            "input_size": (224, 224),
            "depth_factor": 1.0,
            "width_factor": 1.0,
        },
        "b1": {
            "input_size": (240, 240),
            "depth_factor": 1.1,
            "width_factor": 1.0,
        },
        "b2": {
            "input_size": (260, 260),
            "depth_factor": 1.2,
            "width_factor": 1.1,
        },
        "b3": {
            "input_size": (300, 300),
            "depth_factor": 1.4,
            "width_factor": 1.2,
        },
        "b4": {
            "input_size": (380, 380),
            "depth_factor": 1.8,
            "width_factor": 1.4,
        },
        "b5": {
            "input_size": (456, 456),
            "depth_factor": 2.2,
            "width_factor": 1.6,
        },
        "b6": {
            "input_size": (528, 528),
            "depth_factor": 2.6,
            "width_factor": 1.8,
        },
        "b7": {
            "input_size": (600, 600),
            "depth_factor": 3.1,
            "width_factor": 2.0,
        },
        "b8": {
            "input_size": (672, 672),
            "depth_factor": 3.6,
            "width_factor": 2.2,
        },
    }

    init_block_channels: ClassVar[int] = 32
    layers: ClassVar[list[int]] = [1, 2, 2, 3, 3, 4, 1]
    downsample: ClassVar[list[int]] = [1, 1, 1, 1, 0, 1, 0]
    channels_per_layers: ClassVar[list[int]] = [16, 24, 40, 80, 112, 192, 320]
    expansion_factors_per_layers: ClassVar[list[int]] = [1, 6, 6, 6, 6, 6, 6]
    kernel_sizes_per_layers: ClassVar[list[int]] = [3, 3, 5, 3, 5, 5, 3]
    strides_per_stage: ClassVar[list[int]] = [1, 2, 2, 2, 1, 2, 1]
    final_block_channels: ClassVar[int] = 1280

    @staticmethod
    def _group_by_stage(
        values: list[int],
        repeats: list[int],
        new_stage_flags: list[int],
    ) -> list[list[int]]:
        """Expand per-layer-group values into per-unit lists grouped by stage.

        Each ``value`` is repeated ``repeat`` times. A non-zero flag opens a new stage; a zero flag
        appends the repeated values to the current stage (opening one if none exists yet).

        Args:
            values (list[int]): One value per layer group.
            repeats (list[int]): Number of units in each layer group.
            new_stage_flags (list[int]): Per layer group, non-zero to start a new stage.

        Returns:
            list[list[int]]: Per-stage lists of per-unit values.
        """
        stages: list[list[int]] = []
        for value, count, starts_new in zip(values, repeats, new_stage_flags):
            if starts_new or not stages:
                stages.append([])
            stages[-1].extend([value] * count)
        return stages

    def __new__(
        cls,
        version: EFFICIENTNET_VERSION,
        input_size: tuple[int, int] | None = None,
        pretrained: bool = True,
        **kwargs,
    ) -> EfficientNet:
        """Create a new instance of the EfficientNet class.

        Args:
            version (EFFICIENTNET_VERSION): The version of EfficientNet to use.
            input_size (tuple[int, int] | None, optional): The input size of the model. Defaults to None,
                meaning the version's reference input size.
            pretrained (bool, optional): Whether to load pretrained weights. Defaults to True.
            **kwargs: Additional keyword arguments to be passed to the EfficientNet constructor.

        Returns:
            EfficientNet: The created EfficientNet model instance.
        """
        # Read config entries by key rather than relying on dict value order.
        cfg = cls.EFFICIENTNET_CFG[version]
        input_size = input_size or cfg["input_size"]
        depth_factor = cfg["depth_factor"]
        width_factor = cfg["width_factor"]

        # Compound scaling: more units per layer group (depth) and wider layers (width).
        effnet_layers = [math.ceil(li * depth_factor) for li in cls.layers]
        channels_per_layers = [round_channels(ci * width_factor) for ci in cls.channels_per_layers]

        channels = cls._group_by_stage(channels_per_layers, effnet_layers, cls.downsample)
        kernel_sizes = cls._group_by_stage(cls.kernel_sizes_per_layers, effnet_layers, cls.downsample)
        expansion_factors = cls._group_by_stage(cls.expansion_factors_per_layers, effnet_layers, cls.downsample)
        # EfficientNet consumes one stride per stage: the one of its first unit.
        strides_per_stage = [
            stage[0] for stage in cls._group_by_stage(cls.strides_per_stage, effnet_layers, cls.downsample)
        ]

        init_block_channels = round_channels(cls.init_block_channels * width_factor)
        final_block_channels = cls.final_block_channels
        if width_factor > 1.0:
            final_block_channels = round_channels(final_block_channels * width_factor)

        model = EfficientNet(
            channels=channels,
            init_block_channels=init_block_channels,
            final_block_channels=final_block_channels,
            kernel_sizes=kernel_sizes,
            strides_per_stage=strides_per_stage,
            expansion_factors=expansion_factors,
            tf_mode=False,
            bn_eps=1e-5,
            in_size=input_size,
            **kwargs,
        )
        if pretrained:
            cache_dir = Path.home() / ".cache" / "torch" / "hub" / "checkpoints"
            download_model(net=model, model_name=f"efficientnet_{version}", local_model_store_dir_path=str(cache_dir))
            print(f"Download model weight in {cache_dir!s}")
        return model