# Copyright (C) 2023-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
"""MSCAN backbone for SegNext model."""
from __future__ import annotations
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, ClassVar
import torch
from torch import nn
from torch.nn import SyncBatchNorm
from otx.algo.modules import build_norm_layer
from otx.algo.modules.base_module import BaseModule
from otx.algo.utils.mmengine_utils import load_checkpoint_to_model, load_from_http
if TYPE_CHECKING:
from torch import Tensor
def drop_path(x: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
    """Drop paths (Stochastic Depth) per sample.

    In training mode every sample along dimension 0 is either zeroed entirely (with
    probability ``drop_prob``) or kept and scaled by ``1 / (1 - drop_prob)``.

    Args:
        x (Tensor): The input tensor; dimension 0 is treated as the sample (batch) dimension.
        drop_prob (float): Probability of the path to be zeroed. Default: 0.0
        training (bool): The running mode; paths are only dropped when True. Default: False

    Returns:
        Tensor: ``x`` itself when nothing is dropped, otherwise a new tensor of the same shape.

    Raises:
        ValueError: If ``training`` is True and ``drop_prob`` is outside ``[0, 1]``.
    """
    if drop_prob == 0.0 or not training:
        return x
    if not 0.0 <= drop_prob <= 1.0:
        msg = f"drop_prob must be in [0, 1], got {drop_prob}"
        raise ValueError(msg)
    keep_prob = 1.0 - drop_prob
    # One draw per sample, broadcast over all remaining dimensions, so tensors of any
    # rank are supported (not just 4D tensors).
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    # floor(keep_prob + U[0, 1)) is 1 with probability keep_prob and 0 otherwise.
    mask = (keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)).floor()
    if keep_prob == 0.0:
        # Every path is dropped; skip the rescale to avoid x / 0 -> inf and inf * 0 -> NaN.
        return x * mask
    return x.div(keep_prob) * mask
class DropPath(nn.Module):
    """Module wrapper around ``drop_path`` (Stochastic Depth per sample).

    Whether paths are actually dropped is decided by the module's ``training`` flag,
    so the layer is a no-op after ``.eval()``.
    """

    def __init__(self, drop_prob: float = 0.1):
        """Store the drop probability.

        Args:
            drop_prob (float): Probability of the path to be zeroed. Default: 0.1
        """
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply stochastic depth to ``x`` according to the current train/eval mode."""
        return drop_path(x, drop_prob=self.drop_prob, training=self.training)
class Mlp(BaseModule):
    """Convolutional feed-forward (MLP) block.

    Layer order: 1x1 conv (``fc1``) -> 3x3 depth-wise conv (``dwconv``) -> activation
    -> dropout -> 1x1 conv (``fc2``) -> dropout.

    Args:
        in_features (int): Number of input channels.
        hidden_features (int | None): Number of hidden channels; falls back to
            ``in_features`` when falsy. Defaults: None.
        out_features (int | None): Number of output channels; falls back to
            ``in_features`` when falsy. Defaults: None.
        activation (Callable[..., nn.Module]): Factory for the activation layer.
            Defaults to ``nn.GELU``.
        drop (float): Dropout probability, applied after the activation and after ``fc2``.
            Defaults: 0.0.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: int | None = None,
        out_features: int | None = None,
        activation: Callable[..., nn.Module] = nn.GELU,
        drop: float = 0.0,
    ) -> None:
        """Initializes the MLP module."""
        super().__init__()
        hidden_dim = hidden_features or in_features
        output_dim = out_features or in_features
        self.fc1 = nn.Conv2d(in_features, hidden_dim, 1)
        # Depth-wise (groups == channels) 3x3 conv with padding 1 keeps the spatial size.
        self.dwconv = nn.Conv2d(hidden_dim, hidden_dim, 3, 1, 1, bias=True, groups=hidden_dim)
        self.act = activation()
        self.fc2 = nn.Conv2d(hidden_dim, output_dim, 1)
        # A single dropout module is reused at both dropout positions.
        self.drop = nn.Dropout(drop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the feed-forward block on a feature map."""
        hidden = self.act(self.dwconv(self.fc1(x)))
        return self.drop(self.fc2(self.drop(hidden)))
class StemConv(BaseModule):
    """Stem Block at the beginning of Semantic Branch.

    Two 3x3 convolutions, each with stride 2, each followed by a normalization layer,
    with the activation between them. The output feature map is returned flattened into
    a token sequence together with its spatial size.

    Args:
        in_channels (int): The dimension of input channels.
        out_channels (int): The dimension of output channels; the intermediate
            layer uses ``out_channels // 2`` channels.
        activation (Callable[..., nn.Module]): Activation layer module.
            Defaults to ``nn.GELU``.
        normalization (Callable[..., nn.Module]): Normalization layer module.
            Defaults to ``partial(build_norm_layer, SyncBatchNorm, requires_grad=True)``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        activation: Callable[..., nn.Module] = nn.GELU,
        normalization: Callable[..., nn.Module] = partial(build_norm_layer, SyncBatchNorm, requires_grad=True),
    ) -> None:
        """Build the two-convolution stem."""
        super().__init__()
        mid_channels = out_channels // 2
        layers = [
            nn.Conv2d(in_channels, mid_channels, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
            build_norm_layer(normalization, num_features=mid_channels)[1],
            activation(),
            nn.Conv2d(mid_channels, out_channels, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
            build_norm_layer(normalization, num_features=out_channels)[1],
        ]
        self.proj = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, int, int]:
        """Project ``x`` and flatten the spatial grid into a token sequence.

        Returns:
            tuple[Tensor, int, int]: Tokens of shape (B, H * W, C) followed by the
            height ``H`` and width ``W`` of the projected feature map.
        """
        feats = self.proj(x)
        _, _, height, width = feats.size()
        tokens = feats.flatten(2).transpose(1, 2)
        return tokens, height, width
class MSCAAttention(BaseModule):
    """Attention Module in Multi-Scale Convolutional Attention Module (MSCA)."""

    def __init__(
        self,
        channels: int,
        kernel_sizes: list[Any] = [5, [1, 7], [1, 11], [1, 21]],  # noqa: B006
        paddings: list[Any] = [2, [0, 3], [0, 5], [0, 10]],  # noqa: B006
    ) -> None:
        """Attention Module in Multi-Scale Convolutional Attention Module (MSCA).

        Args:
            channels (int): The dimension of channels.
            kernel_sizes (List[Union[int, List[int]]]): The size of attention kernel.
                The first entry is the kernel of the initial depth-wise conv. Each following
                ``[kh, kw]`` entry defines one branch made of a ``(kh, kw)`` depth-wise conv
                followed by a ``(kw, kh)`` depth-wise conv.
                Defaults: [5, [1, 7], [1, 11], [1, 21]].
            paddings (List[Union[int, List[int]]]): The number of
                corresponding padding value in attention module.
                Defaults: [2, [0, 3], [0, 5], [0, 10]].
        """
        super().__init__()
        self.conv0 = nn.Conv2d(channels, channels, kernel_size=kernel_sizes[0], padding=paddings[0], groups=channels)
        branch_specs = list(zip(kernel_sizes[1:], paddings[1:]))
        # Number of strip-convolution branches actually built; forward() iterates over exactly these.
        self.num_branches = len(branch_specs)
        for i, (kernel_size, padding) in enumerate(branch_specs):
            kernel_size_ = [kernel_size, kernel_size[::-1]]
            padding_ = [padding, padding[::-1]]
            # Module names ``conv{i}_1`` / ``conv{i}_2`` determine the state_dict keys,
            # so renaming them would break loading existing checkpoints.
            conv_name = [f"conv{i}_1", f"conv{i}_2"]
            for i_kernel, i_pad, i_conv in zip(kernel_size_, padding_, conv_name):
                self.add_module(i_conv, nn.Conv2d(channels, channels, tuple(i_kernel), padding=i_pad, groups=channels))
        self.conv3 = nn.Conv2d(channels, channels, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward function.

        Args:
            x (Tensor): Input feature map.

        Returns:
            Tensor: ``x`` reweighted element-wise by the computed attention map.
        """
        attn = self.conv0(x)
        # Multi-Scale Feature extraction: base response plus every strip-conv branch,
        # accumulated left to right (same order as attn + attn_0 + attn_1 + ...).
        out = attn
        for i in range(self.num_branches):
            branch = getattr(self, f"conv{i}_1")(attn)
            branch = getattr(self, f"conv{i}_2")(branch)
            out = out + branch
        # Channel Mixing
        out = self.conv3(out)
        # Convolutional Attention. ``x`` is never modified in place above, so no copy is needed.
        return out * x
class MSCASpatialAttention(BaseModule):
    """Spatial Attention Module in Multi-Scale Convolutional Attention Module (MSCA).

    Computes ``proj_2(spatial_gating_unit(activation(proj_1(x)))) + x``: a residual
    branch made of two 1x1 convolutions around an ``MSCAAttention`` unit.

    Args:
        in_channels (int): The number of input channels (also the output channels).
        attention_kernel_sizes (List[Union[int, List[int]]]): The size of attention kernels.
        attention_kernel_paddings (List[Union[int, List[int]]]): The paddings of attention kernels.
        activation (Callable[..., nn.Module]): Activation layer module.
            Defaults to ``nn.GELU``.
    """

    def __init__(
        self,
        in_channels: int,
        attention_kernel_sizes: list[int | list[int]] = [5, [1, 7], [1, 11], [1, 21]],  # noqa: B006
        attention_kernel_paddings: list[int | list[int]] = [2, [0, 3], [0, 5], [0, 10]],  # noqa: B006
        activation: Callable[..., nn.Module] = nn.GELU,
    ) -> None:
        """Init the MSCASpatialAttention module."""
        super().__init__()
        self.proj_1: nn.Conv2d = nn.Conv2d(in_channels, in_channels, 1)
        self.activation: nn.Module = activation()
        self.spatial_gating_unit: MSCAAttention = MSCAAttention(
            in_channels,
            attention_kernel_sizes,
            attention_kernel_paddings,
        )
        self.proj_2: nn.Conv2d = nn.Conv2d(in_channels, in_channels, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated spatial attention and add the input back as a residual."""
        shortcut = x.clone()
        projected = self.activation(self.proj_1(x))
        gated = self.spatial_gating_unit(projected)
        return self.proj_2(gated) + shortcut
class MSCABlock(BaseModule):
    """Basic Multi-Scale Convolutional Attention Block.

    It leverages the large kernel attention (LKA) mechanism to build both channel and spatial
    attention. In each branch, it uses two depth-wise strip convolutions to
    approximate standard depth-wise convolutions with large kernels. With the default
    ``attention_kernel_sizes`` the branch kernel sizes are 7, 11, and 21, respectively.

    Both sub-layers (attention and MLP) are pre-norm residual branches, each scaled by a
    learnable per-channel factor (layer scale) and passed through stochastic depth.

    Args:
        channels (int): The number of input channels.
        attention_kernel_sizes (List[Union[int, List[int]]]): The size of attention kernels.
        attention_kernel_paddings (List[Union[int, List[int]]]): The paddings of attention kernels.
        mlp_ratio (float): The ratio of the number of hidden units in the MLP to the number of input channels.
            Defaults to 4.0.
        drop (float): The dropout rate inside the MLP. Defaults to 0.0.
        drop_path (float): The stochastic depth rate; 0 disables it. Defaults to 0.0.
        activation (Callable[..., nn.Module]): Activation layer module.
            Defaults to ``nn.GELU``.
        normalization (Callable[..., nn.Module]): Normalization layer module.
            Defaults to ``partial(build_norm_layer, SyncBatchNorm, requires_grad=True)``.
        layer_scale_init_value (float): Initial value of both per-channel layer-scale
            parameters. Defaults to 1e-2.
    """

    def __init__(
        self,
        channels: int,
        attention_kernel_sizes: list[int | list[int]] = [5, [1, 7], [1, 11], [1, 21]],  # noqa: B006
        attention_kernel_paddings: list[int | list[int]] = [2, [0, 3], [0, 5], [0, 10]],  # noqa: B006
        mlp_ratio: float = 4.0,
        drop: float = 0.0,
        drop_path: float = 0.0,
        activation: Callable[..., nn.Module] = nn.GELU,
        normalization: Callable[..., nn.Module] = partial(build_norm_layer, SyncBatchNorm, requires_grad=True),
        layer_scale_init_value: float = 1e-2,
    ) -> None:
        """Initialize a MSCABlock."""
        super().__init__()
        self.norm1 = build_norm_layer(normalization, num_features=channels)[1]  # type: nn.Module
        self.attn = MSCASpatialAttention(
            channels,
            attention_kernel_sizes,
            attention_kernel_paddings,
            activation,
        )  # type: MSCASpatialAttention
        # A zero rate uses a plain identity instead of a DropPath module.
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()  # type: nn.Module
        self.norm2 = build_norm_layer(normalization, num_features=channels)[1]  # type: nn.Module
        mlp_hidden_channels = int(channels * mlp_ratio)  # type: int
        self.mlp = Mlp(
            in_features=channels,
            hidden_features=mlp_hidden_channels,
            activation=activation,
            drop=drop,
        )  # type: Mlp
        self.layer_scale_1 = nn.Parameter(layer_scale_init_value * torch.ones(channels), requires_grad=True)  # type: nn.Parameter
        self.layer_scale_2 = nn.Parameter(layer_scale_init_value * torch.ones(channels), requires_grad=True)  # type: nn.Parameter

    def forward(self, x: torch.Tensor, h: int, w: int) -> torch.Tensor:
        """Forward function.

        Args:
            x (Tensor): Token sequence of shape (B, N, C) with ``N == h * w``.
            h (int): Height of the feature map the tokens were flattened from.
            w (int): Width of the feature map the tokens were flattened from.

        Returns:
            Tensor: Token sequence of the same shape (B, N, C).
        """
        b, n, c = x.shape
        # Tokens -> feature map (B, C, H, W) for the convolutional sub-layers.
        x = x.permute(0, 2, 1).view(b, c, h, w)
        # Per-channel layer scale is broadcast over the spatial dimensions.
        x = x + self.drop_path(self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) * self.attn(self.norm1(x)))
        x = x + self.drop_path(self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) * self.mlp(self.norm2(x)))
        # Feature map -> tokens (B, N, C).
        return x.view(b, c, n).permute(0, 2, 1)
class OverlapPatchEmbed(BaseModule):
    """Image to Patch Embedding with overlapping patches.

    A strided convolution with padding ``patch_size // 2`` followed by normalization.
    The result is flattened into a token sequence.

    Args:
        patch_size (int, optional): Kernel size of the projection conv. Defaults to 7.
        stride (int, optional): Stride of the convolutional layer. Defaults to 4.
        in_channels (int, optional): The number of input channels. Defaults to 3.
        embed_dim (int, optional): The dimensions of embedding. Defaults to 768.
        normalization (Callable[..., nn.Module]): Normalization layer module.
            Defaults to ``partial(build_norm_layer, SyncBatchNorm, requires_grad=True)``.
    """

    def __init__(
        self,
        patch_size: int = 7,
        stride: int = 4,
        in_channels: int = 3,
        embed_dim: int = 768,
        normalization: Callable[..., nn.Module] = partial(build_norm_layer, SyncBatchNorm, requires_grad=True),
    ):
        """Initializes the OverlapPatchEmbed module."""
        super().__init__()
        self.proj = nn.Conv2d(
            in_channels,
            embed_dim,
            kernel_size=patch_size,
            stride=stride,
            padding=patch_size // 2,
        )
        self.norm = build_norm_layer(normalization, num_features=embed_dim)[1]

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, int, int]:
        """Embed ``x`` and return tokens of shape (B, H * W, C) plus ``H`` and ``W``."""
        feats = self.proj(x)
        _, _, height, width = feats.shape
        tokens = self.norm(feats).flatten(2).transpose(1, 2)
        return tokens, height, width
class MSCANModule(nn.Module):
    """SegNeXt Multi-Scale Convolutional Attention Network (MSCAN) backbone.

    This backbone is the implementation of `SegNeXt: Rethinking
    Convolutional Attention Design for Semantic
    Segmentation <https://arxiv.org/abs/2209.08575>`_.
    Inspiration from https://github.com/visual-attention-network/segnext.

    Args:
        in_channels (int): The number of input channels. Defaults to 3.
        embed_dims (List[int]): Embedding dimension of each stage. Defaults to [64, 128, 320, 512].
        mlp_ratios (List[int]): Ratio of mlp hidden dim to embedding dim for each stage.
            Defaults to [8, 8, 4, 4].
        drop_rate (float): Dropout rate. Defaults to 0.0.
        drop_path_rate (float): Stochastic depth rate. Per-block rates increase linearly
            from 0 to this value across all blocks of all stages. Defaults to 0.1.
        depths (List[int]): Number of MSCA blocks in each stage. Defaults to [3, 4, 6, 3].
        num_stages (int): MSCAN stages. Defaults to 4.
        attention_kernel_sizes (List[Union[int, List[int]]]): Size of attention kernel in
            Attention Module (Figure 2(b) of original paper). Defaults to [5, [1, 7], [1, 11], [1, 21]].
        attention_kernel_paddings (List[Union[int, List[int]]]): Size of attention paddings
            in Attention Module (Figure 2(b) of original paper). Defaults to [2, [0, 3], [0, 5], [0, 10]].
        activation (Callable[..., nn.Module]): Activation layer module used in the stem and
            in every block. Defaults to ``nn.GELU``.
        normalization (Callable[..., nn.Module]): Normalization layer module.
            Defaults to ``partial(build_norm_layer, nn.BatchNorm2d, requires_grad=True)``.
        pretrained_weights (str | None): Local checkpoint path or URL loaded right after
            construction. Defaults to None (no loading).
    """

    def __init__(
        self,
        in_channels: int = 3,
        embed_dims: list[int] = [64, 128, 320, 512],  # noqa: B006
        mlp_ratios: list[int] = [8, 8, 4, 4],  # noqa: B006
        drop_rate: float = 0.0,
        drop_path_rate: float = 0.1,
        depths: list[int] = [3, 4, 6, 3],  # noqa: B006
        num_stages: int = 4,
        attention_kernel_sizes: list[int | list[int]] = [5, [1, 7], [1, 11], [1, 21]],  # noqa: B006
        attention_kernel_paddings: list[int | list[int]] = [2, [0, 3], [0, 5], [0, 10]],  # noqa: B006
        activation: Callable[..., nn.Module] = nn.GELU,
        normalization: Callable[..., nn.Module] = partial(build_norm_layer, nn.BatchNorm2d, requires_grad=True),
        pretrained_weights: str | None = None,
    ) -> None:
        """Initialize a MSCAN backbone."""
        super().__init__()
        self.depths = depths
        self.num_stages = num_stages
        # Stochastic depth decay rule: one linearly increasing rate per block over all stages.
        dpr = torch.linspace(0, drop_path_rate, sum(depths)).tolist()
        cur = 0
        for i in range(num_stages):
            if i == 0:
                # The first stage uses a convolutional stem instead of an overlapping patch embedding.
                patch_embed = StemConv(
                    in_channels,
                    embed_dims[0],
                    activation=activation,
                    normalization=normalization,
                )
            else:
                # Later stages downsample the previous stage's output by 2.
                patch_embed = OverlapPatchEmbed(
                    patch_size=3,
                    stride=2,
                    in_channels=embed_dims[i - 1],
                    embed_dim=embed_dims[i],
                    normalization=normalization,
                )
            block = nn.ModuleList(
                [
                    MSCABlock(
                        channels=embed_dims[i],
                        attention_kernel_sizes=attention_kernel_sizes,
                        attention_kernel_paddings=attention_kernel_paddings,
                        mlp_ratio=mlp_ratios[i],
                        drop=drop_rate,
                        drop_path=dpr[cur + j],
                        activation=activation,
                        normalization=normalization,
                    )
                    for j in range(depths[i])
                ],
            )
            norm = nn.LayerNorm(embed_dims[i])
            cur += depths[i]
            # Attribute names (patch_embed1, block1, norm1, ...) define the state_dict keys.
            setattr(self, f"patch_embed{i + 1}", patch_embed)
            setattr(self, f"block{i + 1}", block)
            setattr(self, f"norm{i + 1}", norm)
        if pretrained_weights is not None:
            self.load_pretrained_weights(pretrained_weights)

    def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
        """Forward function.

        Args:
            x (Tensor): Input image batch.

        Returns:
            list[Tensor]: One feature map per stage, each of shape (B, C_i, H_i, W_i).
        """
        outs = []
        for i in range(self.num_stages):
            patch_embed = getattr(self, f"patch_embed{i + 1}")
            block = getattr(self, f"block{i + 1}")
            norm = getattr(self, f"norm{i + 1}")
            x, h, w = patch_embed(x)
            for blk in block:
                x = blk(x, h, w)
            x = norm(x)
            # Tokens (B, H*W, C) -> feature map (B, C, H, W); it also feeds the next stage.
            x = x.reshape(x.shape[0], h, w, -1).permute(0, 3, 1, 2).contiguous()
            outs.append(x)
        return outs

    def load_pretrained_weights(self, pretrained: str | None = None, prefix: str = "") -> None:
        """Load pretrained weights into this backbone.

        Args:
            pretrained (str | None): Path of an existing local checkpoint file. Any other
                non-None value is downloaded with ``load_from_http`` into
                ``~/.cache/torch/hub/checkpoints``. Nothing is loaded when None.
            prefix (str): Forwarded unchanged to ``load_checkpoint_to_model``; presumably a
                checkpoint key prefix to select — TODO confirm.
        """
        checkpoint = None
        if isinstance(pretrained, str) and Path(pretrained).exists():
            # NOTE(review): torch.load unpickles the file; only load checkpoints from trusted sources.
            checkpoint = torch.load(pretrained, "cpu")
            print(f"init weight - {pretrained}")
        elif pretrained is not None:
            cache_dir = Path.home() / ".cache" / "torch" / "hub" / "checkpoints"
            checkpoint = load_from_http(filename=pretrained, map_location="cpu", model_dir=cache_dir)
            print(f"init weight - {pretrained}")
        if checkpoint is not None:
            load_checkpoint_to_model(self, checkpoint, prefix=prefix)
class MSCAN:
    """MSCAN backbone factory.

    Calling ``MSCAN(model_name)`` returns a configured ``MSCANModule`` (not an ``MSCAN``
    instance). The presets override ``depths`` and, for the tiny variant, ``embed_dims``;
    every other argument uses the ``MSCANModule`` defaults. Each preset also sets
    ``pretrained_weights`` to the URL of its pretrained checkpoint.
    """

    MSCAN_CFG: ClassVar[dict[str, Any]] = {
        "segnext_tiny": {
            "depths": [3, 3, 5, 2],
            "embed_dims": [32, 64, 160, 256],
            "pretrained_weights": "https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/segnext/mscan_t_20230227-119e8c9f.pth",
        },
        "segnext_small": {
            "depths": [2, 2, 4, 2],
            "pretrained_weights": "https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/segnext/mscan_s_20230227-f33ccdf2.pth",
        },
        "segnext_base": {
            "depths": [3, 3, 12, 3],
            "pretrained_weights": "https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/segnext/mscan_b_20230227-3ab7d230.pth",
        },
    }

    def __new__(cls, model_name: str) -> MSCANModule:
        """Constructor for MSCAN backbone.

        Args:
            model_name (str): One of the keys of ``MSCAN_CFG``.

        Returns:
            MSCANModule: The backbone built from the selected preset.

        Raises:
            KeyError: If ``model_name`` is not a supported preset.
        """
        if model_name not in cls.MSCAN_CFG:
            msg = f"model type '{model_name}' is not supported"
            raise KeyError(msg)
        return MSCANModule(**cls.MSCAN_CFG[model_name])