# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) OpenMMLab. All rights reserved.
"""X3D backbone implementation."""
from __future__ import annotations
import math
from typing import Callable
import torch.utils.checkpoint as cp
from torch import Tensor, nn
from torch.nn.modules.batchnorm import _BatchNorm
from otx.algo.modules.activation import Swish, build_activation_layer
from otx.algo.modules.conv_module import Conv3dModule
from otx.algo.modules.norm import build_norm_layer
from otx.algo.utils.mmengine_utils import load_checkpoint
from otx.algo.utils.weight_init import constant_init, kaiming_init
class SEModule(nn.Module):
"""Implementation of SqueezeExcitation module."""
def __init__(self, channels: int, reduction: float):
super().__init__()
self.avg_pool = nn.AdaptiveAvgPool3d(1)
self.bottleneck = self._round_width(channels, reduction)
self.fc1 = nn.Conv3d(channels, self.bottleneck, kernel_size=1, padding=0)
self.relu = nn.ReLU()
self.fc2 = nn.Conv3d(self.bottleneck, channels, kernel_size=1, padding=0)
self.sigmoid = nn.Sigmoid()
@staticmethod
def _round_width(width: int, multiplier: float, min_width: int = 8, divisor: int = 8) -> int:
"""Round width of filters based on width multiplier."""
width = int(width * multiplier)
min_width = min_width or divisor
width_out = max(min_width, int(width + divisor / 2) // divisor * divisor)
if width_out < 0.9 * width:
width_out += divisor
return int(width_out)
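    # Worked example (illustrative): _round_width(48, 1 / 16) computes
    # int(48 / 16) = 3, then rounds to max(8, (3 + 4) // 8 * 8) = 8, so a
    # 48-channel input gets an 8-channel SE bottleneck.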
def forward(self, x: Tensor) -> Tensor:
"""Defines the computation performed at every call.
Args:
x (Tensor): The input data.
Returns:
Tensor: The output of the module.
"""
module_input = x
x = self.avg_pool(x)
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
x = self.sigmoid(x)
return module_input * x
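# Minimal SEModule sanity check (illustrative, assuming an NCTHW input):
#   se = SEModule(channels=48, reduction=1 / 16)  # bottleneck rounded to 8 channels
#   out = se(torch.randn(1, 48, 4, 7, 7))         # shape preserved: (1, 48, 4, 7, 7)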
class BlockX3D(nn.Module):
"""BlockX3D 3d building block for X3D.
Args:
        inplanes (int): Number of channels of the input to the first conv3d layer.
        planes (int): Number of channels for the intermediate conv3d layers.
        outplanes (int): Number of channels produced by the final conv3d layer.
spatial_stride (int): Spatial stride in the conv3d layer. Default: 1.
downsample (nn.Module | None): Downsample layer. Default: None.
se_ratio (float | None): The reduction ratio of squeeze and excitation
unit. If set as None, it means not using SE unit. Default: None.
use_swish (bool): Whether to use swish as the activation function
before and after the 3x3x3 conv. Default: True.
normalization (Callable[..., nn.Module] | None): Normalization layer module.
Defaults to None.
activation (Callable[..., nn.Module] | None): Activation layer module.
Defaults to ``nn.ReLU``.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed. Default: False.
"""
def __init__(
self,
inplanes: int,
planes: int,
outplanes: int,
spatial_stride: int = 1,
downsample: nn.Module | None = None,
se_ratio: float | None = None,
use_swish: bool = True,
normalization: Callable[..., nn.Module] | None = None,
activation: Callable[..., nn.Module] | None = nn.ReLU,
with_cp: bool = False,
):
super().__init__()
self.inplanes = inplanes
self.planes = planes
self.outplanes = outplanes
self.spatial_stride = spatial_stride
self.downsample = downsample
self.se_ratio = se_ratio
self.use_swish = use_swish
self.normalization = normalization
self.activation = activation
self.with_cp = with_cp
self.conv1 = Conv3dModule(
in_channels=inplanes,
out_channels=planes,
kernel_size=1,
stride=1,
padding=0,
bias=False,
normalization=build_norm_layer(normalization, num_features=planes),
activation=build_activation_layer(activation),
)
        # conv2 is a channel-wise (depthwise) conv: groups == planes
self.conv2 = Conv3dModule(
in_channels=planes,
out_channels=planes,
kernel_size=3,
stride=(1, self.spatial_stride, self.spatial_stride),
padding=1,
groups=planes,
bias=False,
normalization=build_norm_layer(normalization, num_features=planes),
activation=None,
)
self.swish = Swish()
self.conv3 = Conv3dModule(
in_channels=planes,
out_channels=outplanes,
kernel_size=1,
stride=1,
padding=0,
bias=False,
normalization=build_norm_layer(normalization, num_features=outplanes),
activation=None,
)
if self.se_ratio is not None:
self.se_module = SEModule(planes, self.se_ratio)
self.relu = self.activation() if self.activation else nn.ReLU(inplace=True)
def forward(self, x: Tensor) -> Tensor:
"""Defines the computation performed at every call."""
def _inner_forward(x: Tensor) -> Tensor:
"""Forward wrapper for utilizing checkpoint."""
identity = x
out = self.conv1(x)
out = self.conv2(out)
if self.se_ratio is not None:
out = self.se_module(out)
out = self.swish(out)
out = self.conv3(out)
if self.downsample is not None:
identity = self.downsample(x)
return out + identity
out = cp.checkpoint(_inner_forward, x) if self.with_cp and x.requires_grad else _inner_forward(x)
return self.relu(out)
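# Illustrative BlockX3D shape check (assumed toy sizes): with inplanes == outplanes
# and spatial_stride=1 the residual needs no downsample branch:
#   block = BlockX3D(24, 54, 24, se_ratio=1 / 16, normalization=nn.BatchNorm3d)
#   out = block(torch.randn(1, 24, 4, 56, 56))  # out.shape == (1, 24, 4, 56, 56)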
# We do not support initializing X3D with 2D pretrained weights.
class X3DBackbone(nn.Module):
"""X3D backbone. https://arxiv.org/pdf/2004.04730.pdf.
Args:
gamma_w (float): Global channel width expansion factor. Default: 1.
gamma_b (float): Bottleneck channel width expansion factor. Default: 1.
gamma_d (float): Network depth expansion factor. Default: 1.
        pretrained (str | None): Path or URL of a pretrained checkpoint. Default: None.
in_channels (int): Channel num of input features. Default: 3.
        num_stages (int): Number of residual stages. Default: 4.
        spatial_strides (Sequence[int]):
            Spatial strides of residual blocks of each stage.
            Default: ``(2, 2, 2, 2)``.
frozen_stages (int): Stages to be frozen (all param fixed). If set to
-1, it means not freezing any parameters. Default: -1.
        se_style (str): The style of inserting SE modules into BlockX3D:
            'half' inserts them into every other block (the 1st, 3rd, ...),
            while 'all' inserts them into every block. Default: 'half'.
se_ratio (float | None): The reduction ratio of squeeze and excitation
unit. If set as None, it means not using SE unit. Default: 1 / 16.
use_swish (bool): Whether to use swish as the activation function
before and after the 3x3x3 conv. Default: True.
normalization (Callable[..., nn.Module] | None): Normalization layer module.
Defaults to None.
activation (Callable[..., nn.Module] | None): Activation layer module.
Defaults to ``nn.ReLU``.
norm_eval (bool): Whether to set BN layers to eval mode, namely, freeze
running stats (mean and var). Default: False.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed. Default: False.
zero_init_residual (bool):
Whether to use zero initialization for residual block,
Default: True.
        kwargs (dict, optional): Keyword arguments for ``make_res_layer``.
"""
def __init__(
self,
gamma_w: float = 1.0,
gamma_b: float = 1.0,
gamma_d: float = 1.0,
pretrained: str | None = None,
in_channels: int = 3,
num_stages: int = 4,
spatial_strides: tuple[int, int, int, int] = (2, 2, 2, 2),
frozen_stages: int = -1,
se_style: str = "half",
        se_ratio: float | None = 1 / 16,
use_swish: bool = True,
normalization: Callable[..., nn.Module] | None = None,
activation: Callable[..., nn.Module] | None = nn.ReLU,
norm_eval: bool = False,
with_cp: bool = False,
zero_init_residual: bool = True,
**kwargs,
):
super().__init__()
self.gamma_w = gamma_w
self.gamma_b = gamma_b
self.gamma_d = gamma_d
self.pretrained = pretrained
self.in_channels = in_channels
        # Hard-coded base width; scaled below by gamma_w
self.base_channels = 24
self.stage_blocks = [1, 2, 5, 3]
# apply parameters gamma_w and gamma_d
self.base_channels = self._round_width(self.base_channels, self.gamma_w)
self.stage_blocks = [self._round_repeats(x, self.gamma_d) for x in self.stage_blocks]
self.num_stages = num_stages
if num_stages < 1 or num_stages > 4:
msg = "num_stages for X3DBackbone should be 1<=num_stages<=4."
raise ValueError(msg)
self.spatial_strides = spatial_strides
if len(spatial_strides) != num_stages:
msg = "number of spatial_strides should be same to num_stages."
raise ValueError(msg)
self.frozen_stages = frozen_stages
self.se_style = se_style
if self.se_style not in ["all", "half"]:
msg = f"se_style should be 'all' or 'half', but got {self.se_style}."
raise ValueError(msg)
self.se_ratio = se_ratio
        if self.se_ratio is not None and self.se_ratio <= 0:
msg = f"se_ratio should be larger than 0, but got {self.se_ratio}."
raise ValueError(msg)
self.use_swish = use_swish
self.normalization = normalization
self.activation = activation
self.norm_eval = norm_eval
self.with_cp = with_cp
self.zero_init_residual = zero_init_residual
self.block = BlockX3D
self.stage_blocks = self.stage_blocks[:num_stages]
self.layer_inplanes = self.base_channels
self._make_stem_layer()
self.res_layers = []
for i, num_blocks in enumerate(self.stage_blocks):
spatial_stride = spatial_strides[i]
inplanes = self.base_channels * 2**i
planes = int(inplanes * self.gamma_b)
res_layer = self.make_res_layer(
self.block,
self.layer_inplanes,
inplanes,
planes,
num_blocks,
spatial_stride=spatial_stride,
se_style=self.se_style,
se_ratio=self.se_ratio,
use_swish=self.use_swish,
normalization=self.normalization,
activation=self.activation,
with_cp=with_cp,
**kwargs,
)
self.layer_inplanes = inplanes
layer_name = f"layer{i + 1}"
self.add_module(layer_name, res_layer)
self.res_layers.append(layer_name)
self.feat_dim = self.base_channels * 2 ** (len(self.stage_blocks) - 1)
self.conv5 = Conv3dModule(
self.feat_dim,
int(self.feat_dim * self.gamma_b),
kernel_size=1,
stride=1,
padding=0,
bias=False,
normalization=build_norm_layer(self.normalization, num_features=int(self.feat_dim * self.gamma_b)),
activation=build_activation_layer(self.activation),
)
self.feat_dim = int(self.feat_dim * self.gamma_b)
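    # Illustrative construction (factors are assumptions, roughly X3D-M-like):
    #   backbone = X3DBackbone(gamma_w=1.0, gamma_b=2.25, gamma_d=2.2,
    #                          normalization=nn.BatchNorm3d)
    #   feats = backbone(torch.randn(1, 3, 16, 224, 224))  # -> (1, 432, 16, 7, 7)
    # Spatial dims shrink by 2 in the stem and in each of the 4 stages
    # (224 / 2**5 = 7); the temporal dim is never strided.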
@staticmethod
def _round_width(width: int, multiplier: float, min_depth: int = 8, divisor: int = 8) -> int:
"""Round width of filters based on width multiplier."""
if not multiplier:
return width
width = int(width * multiplier)
min_depth = min_depth or divisor
new_filters = max(min_depth, int(width + divisor / 2) // divisor * divisor)
if new_filters < 0.9 * width:
new_filters += divisor
return int(new_filters)
@staticmethod
def _round_repeats(repeats: int, multiplier: float) -> int:
"""Round number of layers based on depth multiplier."""
if not multiplier:
return repeats
return int(math.ceil(multiplier * repeats))
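    # Worked example: with gamma_d = 2.2 the base stage_blocks [1, 2, 5, 3]
    # become [ceil(2.2), ceil(4.4), ceil(11.0), ceil(6.6)] = [3, 5, 11, 7].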
    # The res layers are parameterized by gamma_b; X3D uses no temporal stride.
def make_res_layer(
self,
        block: type[nn.Module],
layer_inplanes: int,
inplanes: int,
planes: int,
blocks: int,
spatial_stride: int = 1,
se_style: str = "half",
se_ratio: float | None = None,
use_swish: bool = True,
normalization: Callable[..., nn.Module] | None = None,
activation: Callable[..., nn.Module] | None = nn.ReLU,
with_cp: bool = False,
**kwargs,
) -> nn.Module:
"""Build residual layer for ResNet3D.
Args:
            block (type[nn.Module]): Residual block class to be stacked.
layer_inplanes (int): Number of channels for the input feature
of the res layer.
            inplanes (int): Number of channels for the input feature in each
                block, which equals base_channels * gamma_w.
            planes (int): Number of channels for the output feature in each
                block, which equals base_channels * gamma_w * gamma_b.
blocks (int): Number of residual blocks.
spatial_stride (int): Spatial strides in residual and conv layers.
Default: 1.
            se_style (str): The style of inserting SE modules into BlockX3D:
                'half' inserts them into every other block, while 'all'
                inserts them into every block. Default: 'half'.
se_ratio (float | None): The reduction ratio of squeeze and
excitation unit. If set as None, it means not using SE unit.
Default: None.
use_swish (bool): Whether to use swish as the activation function
before and after the 3x3x3 conv. Default: True.
normalization (Callable[..., nn.Module] | None): Normalization layer module.
Defaults to None.
activation (Callable[..., nn.Module] | None): Activation layer module.
Defaults to ``nn.ReLU``.
            with_cp (bool): Use checkpoint or not. Using checkpoint
                will save some memory while slowing down the training speed.
                Default: False.
Returns:
nn.Module: A residual layer for the given config.
"""
downsample = None
if spatial_stride != 1 or layer_inplanes != inplanes:
downsample = Conv3dModule(
layer_inplanes,
inplanes,
kernel_size=1,
stride=(1, spatial_stride, spatial_stride),
padding=0,
bias=False,
normalization=build_norm_layer(normalization, num_features=inplanes),
activation=None,
)
        use_se = [False] * blocks
        if se_style == "all":
            use_se = [True] * blocks
        elif se_style == "half":
            # insert SE into every other block, starting with the first
            use_se = [i % 2 == 0 for i in range(blocks)]
        else:
            raise NotImplementedError
layers = []
layers.append(
block(
layer_inplanes,
planes,
inplanes,
spatial_stride=spatial_stride,
downsample=downsample,
se_ratio=se_ratio if use_se[0] else None,
use_swish=use_swish,
normalization=normalization,
activation=activation,
with_cp=with_cp,
**kwargs,
),
)
for i in range(1, blocks):
layers.append( # noqa: PERF401
block(
inplanes,
planes,
inplanes,
spatial_stride=1,
se_ratio=se_ratio if use_se[i] else None,
use_swish=use_swish,
normalization=normalization,
activation=activation,
with_cp=with_cp,
**kwargs,
),
)
return nn.Sequential(*layers)
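    # Example (assumed values): with base_channels=24 and gamma_b=2.25, stage 1
    # builds blocks with channels 24 -> 54 -> 24; only the first block carries
    # the spatial stride and (if needed) the 1x1x1 downsample on the identity path.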
def _make_stem_layer(self) -> None:
"""Construct the stem layers consists of a conv+norm+act module and a pooling layer."""
self.conv1_s = Conv3dModule(
self.in_channels,
self.base_channels,
kernel_size=(1, 3, 3),
stride=(1, 2, 2),
padding=(0, 1, 1),
bias=False,
normalization=None,
activation=None,
)
self.conv1_t = Conv3dModule(
self.base_channels,
self.base_channels,
kernel_size=(5, 1, 1),
stride=(1, 1, 1),
padding=(2, 0, 0),
groups=self.base_channels,
bias=False,
normalization=build_norm_layer(self.normalization, num_features=self.base_channels),
activation=build_activation_layer(self.activation),
)
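    # The stem factorizes the usual 3D stem conv into a spatial 1x3x3 conv
    # (conv1_s) and a channel-wise temporal 5x1x1 conv (conv1_t), following
    # the X3D paper; only conv1_t carries normalization and activation.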
def _freeze_stages(self) -> None:
"""Prevent all the parameters from being optimized before ``self.frozen_stages``."""
if self.frozen_stages >= 0:
self.conv1_s.eval()
self.conv1_t.eval()
for param in self.conv1_s.parameters():
param.requires_grad = False
for param in self.conv1_t.parameters():
param.requires_grad = False
for i in range(1, self.frozen_stages + 1):
m = getattr(self, f"layer{i}")
m.eval()
for param in m.parameters():
param.requires_grad = False
def init_weights(self) -> None:
"""Initiate the parameters either from existing checkpoint or from scratch."""
if isinstance(self.pretrained, str):
load_checkpoint(self, self.pretrained, strict=False)
elif self.pretrained is None:
for m in self.modules():
if isinstance(m, nn.Conv3d):
kaiming_init(m)
elif isinstance(m, _BatchNorm):
constant_init(m, 1)
if self.zero_init_residual:
for m in self.modules():
if isinstance(m, BlockX3D):
constant_init(m.conv3.bn, 0)
else:
msg = "pretrained must be a str or None"
raise TypeError(msg)
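    # Zero-initializing the last BN of each block (zero_init_residual) makes every
    # residual branch start as an identity mapping, which tends to stabilize early
    # training of deep residual networks.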
def forward(self, x: Tensor) -> Tensor:
"""Defines the computation performed at every call.
Args:
x (torch.Tensor): The input data.
Returns:
torch.Tensor: The feature of the input
samples extracted by the backbone.
"""
x = self.conv1_s(x)
x = self.conv1_t(x)
for layer_name in self.res_layers:
res_layer = getattr(self, layer_name)
x = res_layer(x)
return self.conv5(x)
def train(self, mode: bool = True) -> None:
"""Set the optimization status when training."""
super().train(mode)
self._freeze_stages()
if mode and self.norm_eval:
for m in self.modules():
if isinstance(m, _BatchNorm):
m.eval()