# Source code for otx.algo.action_classification.backbones.x3d

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) OpenMMLab. All rights reserved.

"""X3D backbone implementation."""

from __future__ import annotations

import math
from typing import Callable

import torch.utils.checkpoint as cp
from torch import Tensor, nn
from torch.nn.modules.batchnorm import _BatchNorm

from otx.algo.modules.activation import Swish, build_activation_layer
from otx.algo.modules.conv_module import Conv3dModule
from otx.algo.modules.norm import build_norm_layer
from otx.algo.utils.mmengine_utils import load_checkpoint
from otx.algo.utils.weight_init import constant_init, kaiming_init


class SEModule(nn.Module):
    """Squeeze-and-excitation gate for 3D feature maps.

    Each channel is squeezed to one value by global average pooling. Two 1x1x1
    convolutions with a ReLU in between turn those values into per-channel logits.
    A sigmoid maps the logits to a gate that rescales the input channel-wise.

    Args:
        channels (int): Number of input (and output) channels.
        reduction (float): Multiplier applied to ``channels`` to obtain the hidden
            width of the gate; the product is rounded by ``_round_width``.
    """

    def __init__(self, channels: int, reduction: float):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool3d(1)
        hidden = self._round_width(channels, reduction)
        self.bottleneck = hidden
        self.fc1 = nn.Conv3d(channels, hidden, kernel_size=1, padding=0)
        self.relu = nn.ReLU()
        self.fc2 = nn.Conv3d(hidden, channels, kernel_size=1, padding=0)
        self.sigmoid = nn.Sigmoid()

    @staticmethod
    def _round_width(width: int, multiplier: float, min_width: int = 8, divisor: int = 8) -> int:
        """Scale ``width`` by ``multiplier`` and round it to a multiple of ``divisor``.

        The scaled width is first truncated to an int. It is then rounded to the nearest
        multiple of ``divisor`` (halves round up) and clamped from below by ``min_width``
        (or by ``divisor`` when ``min_width`` is falsy). Finally, it is bumped by one extra
        ``divisor`` if rounding lost more than 10% of the scaled width.
        """
        scaled = int(width * multiplier)
        lower_bound = min_width if min_width else divisor
        nearest_multiple = (int(scaled + divisor / 2) // divisor) * divisor
        rounded = max(lower_bound, nearest_multiple)
        return int(rounded + divisor if rounded < 0.9 * scaled else rounded)

    def forward(self, x: Tensor) -> Tensor:
        """Rescale ``x`` channel-wise by its learned squeeze-and-excitation gate.

        Args:
            x (Tensor): The input data.

        Returns:
            Tensor: ``x`` multiplied by the sigmoid gate (broadcast over the pooled dims).
        """
        weights = self.sigmoid(self.fc2(self.relu(self.fc1(self.avg_pool(x)))))
        return x * weights


class BlockX3D(nn.Module):
    """Residual bottleneck building block of X3D.

    The residual branch runs, in order:

    1. A 1x1x1 conv with norm and activation.
    2. A depthwise 3x3x3 conv with norm and no activation.
    3. An optional squeeze-and-excitation unit.
    4. Swish.
    5. A 1x1x1 conv with norm and no activation.

    The branch output is added to the shortcut, which is the input itself or
    ``downsample(input)`` when ``downsample`` is given. The sum is then passed
    through ``self.relu``.

    Args:
        inplanes (int): Number of channels for the input in first conv3d layer.
        planes (int): Number of channels produced by the first 1x1x1 conv and the
            depthwise 3x3x3 conv (the bottleneck width).
        outplanes (int): Number of channels produced by the final 1x1x1 conv3d layer.
        spatial_stride (int): Stride applied to the two spatial dimensions of the
            depthwise conv; its temporal stride is always 1. Default: 1.
        downsample (nn.Module | None): Module applied to the block input to form the
            shortcut. Default: None, i.e. the input is added unchanged.
        se_ratio (float | None): The reduction ratio of squeeze and excitation
            unit. If set as None, it means not using SE unit. Default: None.
        use_swish (bool): Stored as ``self.use_swish``. This block does not read it:
            Swish is always applied after the depthwise conv (and SE unit). Default: True.
        normalization (Callable[..., nn.Module] | None): Normalization layer module passed
            to ``build_norm_layer`` for each of the three convs. Defaults to None.
        activation (Callable[..., nn.Module] | None): Activation layer module. It is passed
            to ``build_activation_layer`` for the first conv. It is also called with no
            arguments to create the activation applied after the residual sum; when it is
            falsy, that activation falls back to ``nn.ReLU(inplace=True)``.
            Defaults to ``nn.ReLU``.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. It only takes effect when the
            input requires grad. Default: False.
    """

    def __init__(
        self,
        inplanes: int,
        planes: int,
        outplanes: int,
        spatial_stride: int = 1,
        downsample: nn.Module | None = None,
        se_ratio: float | None = None,
        use_swish: bool = True,
        normalization: Callable[..., nn.Module] | None = None,
        activation: Callable[..., nn.Module] | None = nn.ReLU,
        with_cp: bool = False,
    ):
        super().__init__()

        # Keep the configuration on the instance. ``downsample`` is an nn.Module (or None),
        # so assigning it registers it as the first child, ahead of the convs below.
        self.inplanes, self.planes, self.outplanes = inplanes, planes, outplanes
        self.spatial_stride = spatial_stride
        self.downsample = downsample
        self.se_ratio, self.use_swish = se_ratio, use_swish
        self.normalization, self.activation = normalization, activation
        self.with_cp = with_cp

        # 1x1x1 projection into the bottleneck width (norm + activation).
        self.conv1 = Conv3dModule(
            inplanes,
            planes,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False,
            normalization=build_norm_layer(normalization, num_features=planes),
            activation=build_activation_layer(activation),
        )

        # groups == planes (with in == out == planes) makes this a depthwise conv;
        # only the spatial dimensions are strided.
        depthwise_stride = (1, spatial_stride, spatial_stride)
        self.conv2 = Conv3dModule(
            planes,
            planes,
            kernel_size=3,
            stride=depthwise_stride,
            padding=1,
            groups=planes,
            bias=False,
            normalization=build_norm_layer(normalization, num_features=planes),
            activation=None,
        )

        self.swish = Swish()

        # 1x1x1 projection to the output width (norm only; activation comes after the sum).
        self.conv3 = Conv3dModule(
            planes,
            outplanes,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False,
            normalization=build_norm_layer(normalization, num_features=outplanes),
            activation=None,
        )

        if se_ratio is not None:
            self.se_module = SEModule(planes, se_ratio)

        if self.activation:
            self.relu = self.activation()
        else:
            self.relu = nn.ReLU(inplace=True)

    def forward(self, x: Tensor) -> Tensor:
        """Apply the residual branch, add the shortcut and activate the sum."""

        def _residual_branch(inp: Tensor) -> Tensor:
            """Compute branch + shortcut; wrapped so it can be run under checkpointing."""
            feat = self.conv2(self.conv1(inp))
            if self.se_ratio is not None:
                feat = self.se_module(feat)
            feat = self.conv3(self.swish(feat))
            shortcut = inp if self.downsample is None else self.downsample(inp)
            return feat + shortcut

        use_checkpoint = self.with_cp and x.requires_grad
        merged = cp.checkpoint(_residual_branch, x) if use_checkpoint else _residual_branch(x)
        return self.relu(merged)


# Initializing X3D from 2D pretrained weights is not supported.
class X3DBackbone(nn.Module):
    """X3D backbone.

    https://arxiv.org/pdf/2004.04730.pdf.

    The network is built from three parts:

    - A stem: a spatial 1x3x3 conv followed by a depthwise temporal 5x1x1 conv.
    - ``num_stages`` residual stages of ``BlockX3D``.
    - A final 1x1x1 conv (``conv5``) that expands the channels by ``gamma_b``.

    Args:
        gamma_w (float): Global channel width expansion factor. Default: 1.
        gamma_b (float): Bottleneck channel width expansion factor. Default: 1.
        gamma_d (float): Network depth expansion factor. Default: 1.
        pretrained (str | None): Name of pretrained model. Default: None.
        in_channels (int): Channel num of input features. Default: 3.
        num_stages (int): Number of residual stages, between 1 and 4. Default: 4.
        spatial_strides (Sequence[int]): Spatial strides of residual blocks of each
            stage. Its length must equal ``num_stages``. Default: ``(2, 2, 2, 2)``.
        frozen_stages (int): Stages to be frozen (all param fixed). If set to -1, it
            means not freezing any parameters. Default: -1.
        se_style (str): The style of inserting SE modules into BlockX3D. 'half' inserts
            them into every other block (the 1st, 3rd, ...), while 'all' inserts them
            into all blocks. Default: 'half'.
        se_ratio (float | None): The reduction ratio of squeeze and excitation unit.
            If set as None, it means not using SE unit; otherwise it must be larger
            than 0. Default: 1 / 16.
        use_swish (bool): Forwarded to every ``BlockX3D``. Default: True.
        normalization (Callable[..., nn.Module] | None): Normalization layer module.
            Defaults to None.
        activation (Callable[..., nn.Module] | None): Activation layer module.
            Defaults to ``nn.ReLU``.
        norm_eval (bool): Whether to set BN layers to eval mode, namely, freeze
            running stats (mean and var). Default: False.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some memory
            while slowing down the training speed. Default: False.
        zero_init_residual (bool): When training from scratch, whether to initialize
            the norm layer of the last conv (``conv3``) of every BlockX3D to zero.
            Default: True.
        kwargs (dict, optional): Keyword arguments forwarded to ``make_res_layer`` and
            from there to every block.

    Raises:
        ValueError: If ``num_stages`` is not in ``[1, 4]``, ``len(spatial_strides)``
            differs from ``num_stages``, ``se_style`` is not ``'all'``/``'half'``,
            or ``se_ratio`` is not None and not larger than 0.
    """

    def __init__(
        self,
        gamma_w: float = 1.0,
        gamma_b: float = 1.0,
        gamma_d: float = 1.0,
        pretrained: str | None = None,
        in_channels: int = 3,
        num_stages: int = 4,
        spatial_strides: tuple[int, ...] = (2, 2, 2, 2),
        frozen_stages: int = -1,
        se_style: str = "half",
        se_ratio: float | None = 1 / 16,
        use_swish: bool = True,
        normalization: Callable[..., nn.Module] | None = None,
        activation: Callable[..., nn.Module] | None = nn.ReLU,
        norm_eval: bool = False,
        with_cp: bool = False,
        zero_init_residual: bool = True,
        **kwargs,
    ):
        super().__init__()
        self.gamma_w = gamma_w
        self.gamma_b = gamma_b
        self.gamma_d = gamma_d

        self.pretrained = pretrained
        self.in_channels = in_channels
        # Hard coded base configuration, scaled by gamma_w (width) and gamma_d (depth).
        self.base_channels = 24
        self.stage_blocks = [1, 2, 5, 3]
        self.base_channels = self._round_width(self.base_channels, self.gamma_w)
        self.stage_blocks = [self._round_repeats(x, self.gamma_d) for x in self.stage_blocks]

        self.num_stages = num_stages
        if num_stages < 1 or num_stages > 4:
            msg = "num_stages for X3DBackbone should be 1<=num_stages<=4."
            raise ValueError(msg)

        self.spatial_strides = spatial_strides
        if len(spatial_strides) != num_stages:
            msg = "number of spatial_strides should be same to num_stages."
            raise ValueError(msg)

        self.frozen_stages = frozen_stages

        self.se_style = se_style
        if self.se_style not in ["all", "half"]:
            msg = f"se_style should be 'all' or 'half', but got {self.se_style}."
            raise ValueError(msg)

        self.se_ratio = se_ratio
        # Reject 0 as well: a falsy-but-not-None ratio would otherwise pass this check
        # and still create SE units, because BlockX3D only tests ``se_ratio is not None``.
        if self.se_ratio is not None and self.se_ratio <= 0:
            msg = f"se_ratio should be larger than 0, but got {self.se_ratio}."
            raise ValueError(msg)

        self.use_swish = use_swish
        self.normalization = normalization
        self.activation = activation
        self.norm_eval = norm_eval
        self.with_cp = with_cp
        self.zero_init_residual = zero_init_residual

        self.block = BlockX3D
        self.stage_blocks = self.stage_blocks[:num_stages]
        self.layer_inplanes = self.base_channels
        self._make_stem_layer()

        self.res_layers: list[str] = []
        for i, num_blocks in enumerate(self.stage_blocks):
            # The channel count doubles at every stage; gamma_b sets the bottleneck width.
            inplanes = self.base_channels * 2**i
            planes = int(inplanes * self.gamma_b)
            res_layer = self.make_res_layer(
                self.block,
                self.layer_inplanes,
                inplanes,
                planes,
                num_blocks,
                spatial_stride=spatial_strides[i],
                se_style=self.se_style,
                se_ratio=self.se_ratio,
                use_swish=self.use_swish,
                normalization=self.normalization,
                activation=self.activation,
                with_cp=with_cp,
                **kwargs,
            )
            self.layer_inplanes = inplanes
            layer_name = f"layer{i + 1}"
            self.add_module(layer_name, res_layer)
            self.res_layers.append(layer_name)

        self.feat_dim = self.base_channels * 2 ** (len(self.stage_blocks) - 1)
        out_dim = int(self.feat_dim * self.gamma_b)
        self.conv5 = Conv3dModule(
            self.feat_dim,
            out_dim,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False,
            normalization=build_norm_layer(self.normalization, num_features=out_dim),
            activation=build_activation_layer(self.activation),
        )
        self.feat_dim = out_dim

    @staticmethod
    def _round_width(width: int, multiplier: float, min_depth: int = 8, divisor: int = 8) -> int:
        """Round width of filters based on width multiplier.

        A falsy ``multiplier`` returns ``width`` unchanged. Otherwise the scaled width
        is rounded to a multiple of ``divisor`` (at least ``min_depth``). It gets one
        extra ``divisor`` if rounding lost more than 10% of the scaled width.
        """
        if not multiplier:
            return width
        width = int(width * multiplier)
        min_depth = min_depth or divisor
        new_filters = max(min_depth, int(width + divisor / 2) // divisor * divisor)
        if new_filters < 0.9 * width:
            new_filters += divisor
        return int(new_filters)

    @staticmethod
    def _round_repeats(repeats: int, multiplier: float) -> int:
        """Round number of layers based on depth multiplier (ceil; falsy multiplier keeps it)."""
        if not multiplier:
            return repeats
        return int(math.ceil(multiplier * repeats))

    # The residual layers are parameterized with gamma_b and have no temporal stride.
    def make_res_layer(
        self,
        block: type[nn.Module],
        layer_inplanes: int,
        inplanes: int,
        planes: int,
        blocks: int,
        spatial_stride: int = 1,
        se_style: str = "half",
        se_ratio: float | None = None,
        use_swish: bool = True,
        normalization: Callable[..., nn.Module] | None = None,
        activation: Callable[..., nn.Module] | None = nn.ReLU,
        with_cp: bool = False,
        **kwargs,
    ) -> nn.Module:
        """Build one residual stage of X3D.

        Args:
            block (type[nn.Module]): Residual block class to instantiate.
            layer_inplanes (int): Number of channels of the feature entering the stage,
                i.e. the input width of the first block.
            inplanes (int): Number of output channels of every block in the stage. This is
                also the input width of every block after the first. Equals
                ``base_channels * 2**stage_index`` when called from ``__init__``.
            planes (int): Bottleneck (inner) width of every block, i.e.
                ``int(inplanes * gamma_b)`` when called from ``__init__``.
            blocks (int): Number of residual blocks.
            spatial_stride (int): Spatial stride of the first block and of its shortcut
                projection; the remaining blocks use stride 1. Default: 1.
            se_style (str): The style of inserting SE modules into the blocks. 'half'
                inserts them into every other block (the 1st, 3rd, ...), while 'all'
                inserts them into all blocks. Default: 'half'.
            se_ratio (float | None): The reduction ratio of squeeze and excitation unit.
                If set as None, it means not using SE unit. Default: None.
            use_swish (bool): Forwarded to every block. Default: True.
            normalization (Callable[..., nn.Module] | None): Normalization layer module.
                Defaults to None.
            activation (Callable[..., nn.Module] | None): Activation layer module.
                Defaults to ``nn.ReLU``.
            with_cp (bool): Use checkpoint or not. Using checkpoint will save some
                memory while slowing down the training speed. Default: False.
            kwargs: Extra keyword arguments forwarded to every block.

        Returns:
            nn.Module: An ``nn.Sequential`` of ``blocks`` residual blocks.

        Raises:
            NotImplementedError: If ``se_style`` is neither ``'all'`` nor ``'half'``.
        """
        downsample = None
        # The shortcut needs a 1x1x1 projection when the first block changes the
        # spatial resolution or the channel count.
        if spatial_stride != 1 or layer_inplanes != inplanes:
            downsample = Conv3dModule(
                layer_inplanes,
                inplanes,
                kernel_size=1,
                stride=(1, spatial_stride, spatial_stride),
                padding=0,
                bias=False,
                normalization=build_norm_layer(normalization, num_features=inplanes),
                activation=None,
            )

        # Honour the ``se_style`` argument (previously ``self.se_style`` was read and the
        # documented argument was silently ignored).
        if se_style == "all":
            use_se = [True] * blocks
        elif se_style == "half":
            use_se = [i % 2 == 0 for i in range(blocks)]
        else:
            msg = f"se_style should be 'all' or 'half', but got {se_style}."
            raise NotImplementedError(msg)

        # Only the first block changes resolution/width and owns the shortcut projection.
        layers = [
            block(
                layer_inplanes,
                planes,
                inplanes,
                spatial_stride=spatial_stride,
                downsample=downsample,
                se_ratio=se_ratio if use_se[0] else None,
                use_swish=use_swish,
                normalization=normalization,
                activation=activation,
                with_cp=with_cp,
                **kwargs,
            ),
        ]
        layers.extend(
            block(
                inplanes,
                planes,
                inplanes,
                spatial_stride=1,
                se_ratio=se_ratio if use_se[i] else None,
                use_swish=use_swish,
                normalization=normalization,
                activation=activation,
                with_cp=with_cp,
                **kwargs,
            )
            for i in range(1, blocks)
        )
        return nn.Sequential(*layers)

    def _make_stem_layer(self) -> None:
        """Construct the stem: a spatial 1x3x3 conv and a depthwise temporal 5x1x1 conv (norm + act)."""
        self.conv1_s = Conv3dModule(
            self.in_channels,
            self.base_channels,
            kernel_size=(1, 3, 3),
            stride=(1, 2, 2),
            padding=(0, 1, 1),
            bias=False,
            normalization=None,
            activation=None,
        )
        self.conv1_t = Conv3dModule(
            self.base_channels,
            self.base_channels,
            kernel_size=(5, 1, 1),
            stride=(1, 1, 1),
            padding=(2, 0, 0),
            groups=self.base_channels,
            bias=False,
            normalization=build_norm_layer(self.normalization, num_features=self.base_channels),
            activation=build_activation_layer(self.activation),
        )

    def _freeze_stages(self) -> None:
        """Prevent all the parameters from being optimized before ``self.frozen_stages``.

        ``frozen_stages >= 0`` freezes the stem; each value ``i >= 1`` additionally
        freezes ``layer1`` .. ``layer{i}`` (eval mode and ``requires_grad = False``).
        """
        if self.frozen_stages >= 0:
            for stem_module in (self.conv1_s, self.conv1_t):
                stem_module.eval()
                for param in stem_module.parameters():
                    param.requires_grad = False

        for i in range(1, self.frozen_stages + 1):
            m = getattr(self, f"layer{i}")
            m.eval()
            for param in m.parameters():
                param.requires_grad = False

    def init_weights(self) -> None:
        """Initiate the parameters either from existing checkpoint or from scratch.

        Raises:
            TypeError: If ``self.pretrained`` is neither a str nor None.
        """
        if isinstance(self.pretrained, str):
            load_checkpoint(self, self.pretrained, strict=False)
        elif self.pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv3d):
                    kaiming_init(m)
                elif isinstance(m, _BatchNorm):
                    constant_init(m, 1)

            if self.zero_init_residual:
                # Zeroing the last norm of each block makes the residual branch start out
                # contributing nothing to the sum.
                for m in self.modules():
                    if isinstance(m, BlockX3D):
                        constant_init(m.conv3.bn, 0)
        else:
            msg = "pretrained must be a str or None"
            raise TypeError(msg)

    def forward(self, x: Tensor) -> Tensor:
        """Defines the computation performed at every call.

        Args:
            x (torch.Tensor): The input data.

        Returns:
            torch.Tensor: The feature of the input samples extracted by the backbone.
        """
        x = self.conv1_s(x)
        x = self.conv1_t(x)
        for layer_name in self.res_layers:
            res_layer = getattr(self, layer_name)
            x = res_layer(x)
        return self.conv5(x)

    def train(self, mode: bool = True) -> X3DBackbone:
        """Set the optimization status when training.

        Re-applies stage freezing and, when ``norm_eval`` is set, keeps BN layers in
        eval mode.

        Returns:
            X3DBackbone: ``self``, matching the ``nn.Module.train`` contract.
        """
        super().train(mode)
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                if isinstance(m, _BatchNorm):
                    m.eval()
        return self