"""HRNet network modules for base backbone.

Modified from:
- https://github.com/HRNet/Lite-HRNet
"""

# Copyright (c) 2018-2020 Open-MMLab.
# SPDX-License-Identifier: Apache-2.0
#
# Copyright (c) 2021 DeLightCMU
# SPDX-License-Identifier: Apache-2.0
#
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#


import mmcv
import torch
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from mmcv.cnn import (
    ConvModule,
    build_conv_layer,
    build_norm_layer,
    constant_init,
    normal_init,
)
from mmcv.runner import BaseModule, load_checkpoint
from mmcv.utils.parrots_wrapper import _BatchNorm
from mmseg.models.backbones.resnet import BasicBlock, Bottleneck
from mmseg.models.builder import BACKBONES
from torch import nn

from otx.algorithms.segmentation.adapters.mmseg.models.utils import (
    AsymmetricPositionAttentionModule,
    IterativeAggregator,
    LocalAttentionModule,
    channel_shuffle,
)
from otx.utils.logger import get_logger

logger = get_logger()


# pylint: disable=invalid-name, too-many-lines, too-many-instance-attributes, too-many-locals, too-many-arguments
# pylint: disable=unused-argument, consider-using-enumerate
class NeighbourSupport(nn.Module):
    """Neighbour support module."""

    def __init__(
        self,
        channels,
        kernel_size=3,
        key_ratio=8,
        value_ratio=8,
        conv_cfg=None,
        norm_cfg=None,
    ):
        super().__init__()

        self.in_channels = channels
        self.key_channels = int(channels / key_ratio)
        self.value_channels = int(channels / value_ratio)
        self.kernel_size = kernel_size

        self.key = nn.Sequential(
            ConvModule(
                in_channels=self.in_channels,
                out_channels=self.key_channels,
                kernel_size=1,
                stride=1,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=dict(type="ReLU"),
            ),
            ConvModule(
                self.key_channels,
                self.key_channels,
                kernel_size=self.kernel_size,
                stride=1,
                padding=(self.kernel_size - 1) // 2,
                groups=self.key_channels,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=None,
            ),
            ConvModule(
                in_channels=self.key_channels,
                out_channels=self.kernel_size * self.kernel_size,
                kernel_size=1,
                stride=1,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=None,
            ),
        )
        self.value = nn.Sequential(
            ConvModule(
                in_channels=self.in_channels,
                out_channels=self.value_channels,
                kernel_size=1,
                stride=1,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=None,
            ),
            nn.Unfold(kernel_size=self.kernel_size, stride=1, padding=1),
        )
        self.out_conv = ConvModule(
            in_channels=self.value_channels,
            out_channels=self.in_channels,
            kernel_size=1,
            stride=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=None,
        )

    def forward(self, x):
        """Forward."""
        h, w = [int(_) for _ in x.size()[-2:]]

        key = self.key(x).view(-1, 1, self.kernel_size**2, h, w)
        weights = torch.softmax(key, dim=2)

        value = self.value(x).view(-1, self.value_channels, self.kernel_size**2, h, w)
        y = torch.sum(weights * value, dim=2)
        y = self.out_conv(y)

        out = x + y

        return out

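# Illustrative shape walk-through for NeighbourSupport (added as a sketch, not part of the
# original implementation). For an input x of shape (N, C, H, W) with kernel_size k:
#
#   key(x)   -> (N, k*k, H, W), viewed as (N, 1, k*k, H, W) and softmax-normalized over dim=2,
#               i.e. one weight per neighbour of every spatial position;
#   value(x) -> 1x1 conv to C/value_ratio channels, then nn.Unfold gathers the k*k neighbours,
#               viewed as (N, C/value_ratio, k*k, H, W);
#   the weighted sum over dim=2 gives (N, C/value_ratio, H, W), which out_conv projects back
#   to C channels before the residual addition with x.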

class CrossResolutionWeighting(nn.Module):
    """Cross resolution weighting."""

    def __init__(
        self,
        channels,
        ratio=16,
        conv_cfg=None,
        norm_cfg=None,
        act_cfg=(dict(type="ReLU"), dict(type="Sigmoid")),
    ):
        super().__init__()

        if isinstance(act_cfg, dict):
            act_cfg = (act_cfg, act_cfg)
        assert len(act_cfg) == 2
        assert mmcv.is_tuple_of(act_cfg, dict)

        self.channels = channels
        total_channel = sum(channels)

        self.conv1 = ConvModule(
            in_channels=total_channel,
            out_channels=int(total_channel / ratio),
            kernel_size=1,
            stride=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg[0],
        )
        self.conv2 = ConvModule(
            in_channels=int(total_channel / ratio),
            out_channels=total_channel,
            kernel_size=1,
            stride=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg[1],
        )

    def forward(self, x):
        """Forward."""
        min_size = [int(_) for _ in x[-1].size()[-2:]]

        out = [F.adaptive_avg_pool2d(s, min_size) for s in x[:-1]] + [x[-1]]
        out = torch.cat(out, dim=1)
        out = self.conv1(out)
        out = self.conv2(out)
        out = torch.split(out, self.channels, dim=1)
        out = [s * F.interpolate(a, size=s.size()[-2:], mode="nearest") for s, a in zip(x, out)]

        return out

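# Minimal usage sketch for CrossResolutionWeighting (illustrative channel counts and sizes,
# not taken from any shipped config). Each branch is average-pooled to the smallest branch
# resolution, concatenated, squeezed and re-expanded by two 1x1 convs, split back per branch
# and applied as a channel-wise gate after nearest-neighbour upsampling:
#
#   crw = CrossResolutionWeighting(channels=[40, 80], ratio=8)
#   x = [torch.rand(1, 40, 32, 32), torch.rand(1, 80, 16, 16)]
#   y = crw(x)  # two tensors with the same shapes as the inputs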

class SpatialWeighting(nn.Module):
    """Spatial weighting."""

    def __init__(
        self,
        channels,
        ratio=16,
        conv_cfg=None,
        act_cfg=(dict(type="ReLU"), dict(type="Sigmoid")),
        **kwargs,
    ):
        super().__init__()

        if isinstance(act_cfg, dict):
            act_cfg = (act_cfg, act_cfg)
        assert len(act_cfg) == 2
        assert mmcv.is_tuple_of(act_cfg, dict)

        self.global_avgpool = nn.AdaptiveAvgPool2d(1)
        self.conv1 = ConvModule(
            in_channels=channels,
            out_channels=int(channels / ratio),
            kernel_size=1,
            stride=1,
            conv_cfg=conv_cfg,
            act_cfg=act_cfg[0],
        )
        self.conv2 = ConvModule(
            in_channels=int(channels / ratio),
            out_channels=channels,
            kernel_size=1,
            stride=1,
            conv_cfg=conv_cfg,
            act_cfg=act_cfg[1],
        )

    def forward(self, x):
        """Forward."""
        out = self.global_avgpool(x)
        out = self.conv1(out)
        out = self.conv2(out)

        return x * out


class SpatialWeightingV2(nn.Module):
    """The original repo: https://github.com/DeLightCMU/PSA."""

    def __init__(
        self,
        channels,
        ratio=16,
        conv_cfg=None,
        norm_cfg=None,
        enable_norm=False,
        **kwargs,
    ):
        super().__init__()

        self.in_channels = channels
        self.internal_channels = int(channels / ratio)

        # channel-only branch
        self.v_channel = ConvModule(
            in_channels=self.in_channels,
            out_channels=self.internal_channels,
            kernel_size=1,
            stride=1,
            bias=False,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg if enable_norm else None,
            act_cfg=None,
        )
        self.q_channel = ConvModule(
            in_channels=self.in_channels,
            out_channels=1,
            kernel_size=1,
            stride=1,
            bias=False,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg if enable_norm else None,
            act_cfg=None,
        )
        self.out_channel = ConvModule(
            in_channels=self.internal_channels,
            out_channels=self.in_channels,
            kernel_size=1,
            stride=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=dict(type="Sigmoid"),
        )

        # spatial-only branch
        self.v_spatial = ConvModule(
            in_channels=self.in_channels,
            out_channels=self.internal_channels,
            kernel_size=1,
            stride=1,
            bias=False,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg if enable_norm else None,
            act_cfg=None,
        )
        self.q_spatial = ConvModule(
            in_channels=self.in_channels,
            out_channels=self.internal_channels,
            kernel_size=1,
            stride=1,
            bias=False,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg if enable_norm else None,
            act_cfg=None,
        )
        self.global_avgpool = nn.AdaptiveAvgPool2d(1)

    def _channel_weighting(self, x):
        h, w = [int(_) for _ in x.size()[-2:]]

        v = self.v_channel(x).view(-1, self.internal_channels, h * w)

        q = self.q_channel(x).view(-1, h * w, 1)
        q = torch.softmax(q, dim=1)

        y = torch.matmul(v, q)
        y = y.view(-1, self.internal_channels, 1, 1)
        y = self.out_channel(y)

        out = x * y

        return out

    def _spatial_weighting(self, x):
        h, w = [int(_) for _ in x.size()[-2:]]

        v = self.v_spatial(x)
        v = v.view(-1, self.internal_channels, h * w)

        q = self.q_spatial(x)
        q = self.global_avgpool(q)
        q = torch.softmax(q, dim=1)
        q = q.view(-1, 1, self.internal_channels)

        y = torch.matmul(q, v)
        y = y.view(-1, 1, h, w)
        y = torch.sigmoid(y)

        out = x * y

        return out

    def forward(self, x):
        """Forward."""
        y_channel = self._channel_weighting(x)
        y_spatial = self._spatial_weighting(x)
        out = y_channel + y_spatial

        return out

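# Illustrative summary of SpatialWeightingV2 (a sketch added here; see the repository above
# for the reference implementation). Two attention branches are computed and summed:
#
#   channel-only: q_channel(x) produces one softmax weight per spatial position and
#       v_channel(x) is reduced to C/ratio channels; their matmul yields a (C/ratio, 1, 1)
#       descriptor that out_channel maps to a per-channel sigmoid gate multiplied with x;
#   spatial-only: q_spatial(x) is globally pooled and softmax-normalized over its C/ratio
#       channels while v_spatial(x) keeps the spatial grid; their matmul gives a (1, H, W)
#       map that is passed through a sigmoid and multiplied with x.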

class ConditionalChannelWeighting(nn.Module):
    """Conditional channel weighting module."""

    def __init__(
        self,
        in_channels,
        stride,
        reduce_ratio,
        conv_cfg=None,
        norm_cfg=None,
        with_cp=False,
        dropout=None,
        weighting_module_version="v1",
        neighbour_weighting=False,
        dw_ksize=3,
    ):
        super().__init__()

        if norm_cfg is None:
            norm_cfg = dict(type="BN")

        self.with_cp = with_cp
        self.stride = stride
        assert stride in [1, 2]

        spatial_weighting_module = SpatialWeighting if weighting_module_version == "v1" else SpatialWeightingV2
        branch_channels = [channel // 2 for channel in in_channels]

        self.cross_resolution_weighting = CrossResolutionWeighting(
            branch_channels, ratio=reduce_ratio, conv_cfg=conv_cfg, norm_cfg=norm_cfg
        )
        self.depthwise_convs = nn.ModuleList(
            [
                ConvModule(
                    channel,
                    channel,
                    kernel_size=dw_ksize,
                    stride=self.stride,
                    padding=dw_ksize // 2,
                    groups=channel,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=None,
                )
                for channel in branch_channels
            ]
        )
        self.spatial_weighting = nn.ModuleList(
            [
                spatial_weighting_module(
                    channels=channel,
                    ratio=4,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    enable_norm=True,
                )
                for channel in branch_channels
            ]
        )

        self.neighbour_weighting = None
        if neighbour_weighting:
            self.neighbour_weighting = nn.ModuleList(
                [
                    NeighbourSupport(
                        channel,
                        kernel_size=3,
                        key_ratio=8,
                        value_ratio=4,
                        conv_cfg=conv_cfg,
                        norm_cfg=norm_cfg,
                    )
                    for channel in branch_channels
                ]
            )

        self.dropout = None
        if dropout is not None and dropout > 0:
            self.dropout = nn.ModuleList([nn.Dropout(p=dropout) for _ in branch_channels])

    def _inner_forward(self, x):
        x = [s.chunk(2, dim=1) for s in x]
        x1 = [s[0] for s in x]
        x2 = [s[1] for s in x]

        x2 = self.cross_resolution_weighting(x2)
        x2 = [dw(s) for s, dw in zip(x2, self.depthwise_convs)]

        if self.neighbour_weighting is not None:
            x2 = [nw(s) for s, nw in zip(x2, self.neighbour_weighting)]

        x2 = [sw(s) for s, sw in zip(x2, self.spatial_weighting)]

        if self.dropout is not None:
            x2 = [dropout(s) for s, dropout in zip(x2, self.dropout)]

        out = [torch.cat([s1, s2], dim=1) for s1, s2 in zip(x1, x2)]
        out = [channel_shuffle(s, 2) for s in out]

        return out

    def forward(self, x):
        """Forward."""
        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(self._inner_forward, x)
        else:
            out = self._inner_forward(x)

        return out

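# Illustrative data flow for ConditionalChannelWeighting (sketch only, with made-up sizes).
# For every branch the channels are split in half; the first halves pass through unchanged,
# while the second halves are jointly re-weighted across resolutions, processed by a
# depthwise conv (optionally a NeighbourSupport block), gated by the per-branch spatial
# weighting module, optionally dropped out, concatenated back and channel-shuffled:
#
#   ccw = ConditionalChannelWeighting(in_channels=[40, 80], stride=1, reduce_ratio=8)
#   x = [torch.rand(1, 40, 32, 32), torch.rand(1, 80, 16, 16)]
#   y = ccw(x)  # two tensors with the same shapes as the inputs (stride 1)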

class Stem(nn.Module):
    """Stem."""

    def __init__(
        self,
        in_channels,
        stem_channels,
        out_channels,
        expand_ratio,
        conv_cfg=None,
        norm_cfg=None,
        with_cp=False,
        strides=(2, 2),
        extra_stride=False,
        input_norm=False,
    ):
        super().__init__()

        if norm_cfg is None:
            norm_cfg = dict(type="BN")

        assert isinstance(strides, (tuple, list))
        assert len(strides) == 2

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.with_cp = with_cp

        self.input_norm = None
        if input_norm:
            self.input_norm = nn.InstanceNorm2d(in_channels)

        self.conv1 = ConvModule(
            in_channels=in_channels,
            out_channels=stem_channels,
            kernel_size=3,
            stride=strides[0],
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=dict(type="ReLU"),
        )

        self.conv2 = None
        if extra_stride:
            self.conv2 = ConvModule(
                in_channels=stem_channels,
                out_channels=stem_channels,
                kernel_size=3,
                stride=2,
                padding=1,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=dict(type="ReLU"),
            )

        mid_channels = int(round(stem_channels * expand_ratio))
        branch_channels = stem_channels // 2
        if stem_channels == self.out_channels:
            inc_channels = self.out_channels - branch_channels
        else:
            inc_channels = self.out_channels - stem_channels

        self.branch1 = nn.Sequential(
            ConvModule(
                branch_channels,
                branch_channels,
                kernel_size=3,
                stride=strides[1],
                padding=1,
                groups=branch_channels,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=None,
            ),
            ConvModule(
                branch_channels,
                inc_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=dict(type="ReLU"),
            ),
        )

        self.expand_conv = ConvModule(
            branch_channels,
            mid_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=dict(type="ReLU"),
        )
        self.depthwise_conv = ConvModule(
            mid_channels,
            mid_channels,
            kernel_size=3,
            stride=strides[1],
            padding=1,
            groups=mid_channels,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=None,
        )
        self.linear_conv = ConvModule(
            mid_channels,
            branch_channels if stem_channels == self.out_channels else stem_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=dict(type="ReLU"),
        )

    def _inner_forward(self, x):
        if self.input_norm is not None:
            x = self.input_norm(x)

        x = self.conv1(x)
        if self.conv2 is not None:
            x = self.conv2(x)

        x1, x2 = x.chunk(2, dim=1)

        x1 = self.branch1(x1)

        x2 = self.expand_conv(x2)
        x2 = self.depthwise_conv(x2)
        x2 = self.linear_conv(x2)

        out = torch.cat((x1, x2), dim=1)
        out = channel_shuffle(out, 2)

        return out

    def forward(self, x):
        """Forward."""
        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(self._inner_forward, x)
        else:
            out = self._inner_forward(x)

        return out

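# Usage sketch for Stem (illustrative sizes). With the default strides (2, 2) and no extra
# stride, the stem reduces the spatial resolution by 4x and ends with a channel shuffle of
# the two halves:
#
#   stem = Stem(in_channels=3, stem_channels=32, out_channels=32, expand_ratio=1)
#   y = stem(torch.rand(1, 3, 224, 224))  # -> (1, 32, 56, 56)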

class StemV2(nn.Module):
    """StemV2."""

    def __init__(
        self,
        in_channels,
        stem_channels,
        out_channels,
        expand_ratio,
        conv_cfg=None,
        norm_cfg=None,
        with_cp=False,
        num_stages=1,
        strides=(2, 2),
        extra_stride=False,
        input_norm=False,
    ):
        super().__init__()

        if norm_cfg is None:
            norm_cfg = dict(type="BN")

        assert num_stages > 0
        assert isinstance(strides, (tuple, list))
        assert len(strides) == 1 + num_stages

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.with_cp = with_cp
        self.num_stages = num_stages

        self.input_norm = None
        if input_norm:
            self.input_norm = nn.InstanceNorm2d(in_channels)

        self.conv1 = ConvModule(
            in_channels=in_channels,
            out_channels=stem_channels,
            kernel_size=3,
            stride=strides[0],
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=dict(type="ReLU"),
        )

        self.conv2 = None
        if extra_stride:
            self.conv2 = ConvModule(
                in_channels=stem_channels,
                out_channels=stem_channels,
                kernel_size=3,
                stride=2,
                padding=1,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=dict(type="ReLU"),
            )

        mid_channels = int(round(stem_channels * expand_ratio))
        internal_branch_channels = stem_channels // 2
        out_branch_channels = self.out_channels // 2

        self.branch1, self.branch2 = nn.ModuleList(), nn.ModuleList()
        for stage in range(1, num_stages + 1):
            self.branch1.append(
                nn.Sequential(
                    ConvModule(
                        internal_branch_channels,
                        internal_branch_channels,
                        kernel_size=3,
                        stride=strides[stage],
                        padding=1,
                        groups=internal_branch_channels,
                        conv_cfg=conv_cfg,
                        norm_cfg=norm_cfg,
                        act_cfg=None,
                    ),
                    ConvModule(
                        internal_branch_channels,
                        out_branch_channels if stage == num_stages else internal_branch_channels,
                        kernel_size=1,
                        stride=1,
                        padding=0,
                        conv_cfg=conv_cfg,
                        norm_cfg=norm_cfg,
                        act_cfg=dict(type="ReLU"),
                    ),
                )
            )

            self.branch2.append(
                nn.Sequential(
                    ConvModule(
                        internal_branch_channels,
                        mid_channels,
                        kernel_size=1,
                        stride=1,
                        padding=0,
                        conv_cfg=conv_cfg,
                        norm_cfg=norm_cfg,
                        act_cfg=dict(type="ReLU"),
                    ),
                    ConvModule(
                        mid_channels,
                        mid_channels,
                        kernel_size=3,
                        stride=strides[stage],
                        padding=1,
                        groups=mid_channels,
                        conv_cfg=conv_cfg,
                        norm_cfg=norm_cfg,
                        act_cfg=None,
                    ),
                    ConvModule(
                        mid_channels,
                        out_branch_channels if stage == num_stages else internal_branch_channels,
                        kernel_size=1,
                        stride=1,
                        padding=0,
                        conv_cfg=conv_cfg,
                        norm_cfg=norm_cfg,
                        act_cfg=dict(type="ReLU"),
                    ),
                )
            )

    def _inner_forward(self, x):
        if self.input_norm is not None:
            x = self.input_norm(x)

        y = self.conv1(x)
        if self.conv2 is not None:
            y = self.conv2(y)

        out_list = [y]
        for stage in range(self.num_stages):
            y1, y2 = y.chunk(2, dim=1)

            y1 = self.branch1[stage](y1)
            y2 = self.branch2[stage](y2)

            y = torch.cat((y1, y2), dim=1)
            y = channel_shuffle(y, 2)
            out_list.append(y)

        return out_list

    def forward(self, x):
        """Forward."""
        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(self._inner_forward, x)
        else:
            out = self._inner_forward(x)

        return out


class ShuffleUnit(nn.Module):
    """InvertedResidual block for ShuffleNetV2 backbone.

    Args:
        in_channels (int): The input channels of the block.
        out_channels (int): The output channels of the block.
        stride (int): Stride of the 3x3 convolution layer. Default: 1
        conv_cfg (dict): Config dict for convolution layer.
            Default: None, which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU').
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        stride=1,
        conv_cfg=None,
        norm_cfg=None,
        act_cfg=None,
        with_cp=False,
    ):
        super().__init__()

        if norm_cfg is None:
            norm_cfg = dict(type="BN")
        if act_cfg is None:
            act_cfg = dict(type="ReLU")

        self.stride = stride
        self.with_cp = with_cp

        branch_features = out_channels // 2
        if self.stride == 1:
            assert in_channels == branch_features * 2, (
                f"in_channels ({in_channels}) should equal "
                f"branch_features * 2 ({branch_features * 2}) "
                "when stride is 1"
            )

        if in_channels != branch_features * 2:
            assert self.stride != 1, (
                f"stride ({self.stride}) should not equal 1 when "
                "in_channels != branch_features * 2"
            )

        if self.stride > 1:
            self.branch1 = nn.Sequential(
                ConvModule(
                    in_channels,
                    in_channels,
                    kernel_size=3,
                    stride=self.stride,
                    padding=1,
                    groups=in_channels,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=None,
                ),
                ConvModule(
                    in_channels,
                    branch_features,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg,
                ),
            )

        self.branch2 = nn.Sequential(
            ConvModule(
                in_channels if (self.stride > 1) else branch_features,
                branch_features,
                kernel_size=1,
                stride=1,
                padding=0,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg,
            ),
            ConvModule(
                branch_features,
                branch_features,
                kernel_size=3,
                stride=self.stride,
                padding=1,
                groups=branch_features,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=None,
            ),
            ConvModule(
                branch_features,
                branch_features,
                kernel_size=1,
                stride=1,
                padding=0,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg,
            ),
        )

    def _inner_forward(self, x):
        if self.stride > 1:
            out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)
        else:
            x1, x2 = x.chunk(2, dim=1)
            out = torch.cat((x1, self.branch2(x2)), dim=1)

        out = channel_shuffle(out, 2)

        return out

    def forward(self, x):
        """Forward."""
        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(self._inner_forward, x)
        else:
            out = self._inner_forward(x)

        return out

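# Usage sketch for ShuffleUnit (illustrative sizes). With stride=1 only half of the channels
# are transformed and the resolution is preserved; with stride>1 both branches downsample
# and their outputs are concatenated:
#
#   unit = ShuffleUnit(in_channels=64, out_channels=64, stride=1)
#   y = unit(torch.rand(1, 64, 28, 28))   # -> (1, 64, 28, 28)
#
#   down = ShuffleUnit(in_channels=64, out_channels=128, stride=2)
#   y = down(torch.rand(1, 64, 28, 28))   # -> (1, 128, 14, 14)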

class LiteHRModule(nn.Module):
    """LiteHR module."""

    def __init__(
        self,
        num_branches,
        num_blocks,
        in_channels,
        reduce_ratio,
        module_type,
        multiscale_output=False,
        with_fuse=True,
        conv_cfg=None,
        norm_cfg=None,
        with_cp=False,
        dropout=None,
        weighting_module_version="v1",
        neighbour_weighting=False,
    ):
        super().__init__()

        if norm_cfg is None:
            norm_cfg = dict(type="BN")
        self._check_branches(num_branches, in_channels)

        self.in_channels = in_channels
        self.num_branches = num_branches

        self.module_type = module_type
        self.multiscale_output = multiscale_output
        self.with_fuse = with_fuse
        self.norm_cfg = norm_cfg
        self.conv_cfg = conv_cfg
        self.with_cp = with_cp
        self.weighting_module_version = weighting_module_version
        self.neighbour_weighting = neighbour_weighting

        if self.module_type == "LITE":
            self.layers = self._make_weighting_blocks(num_blocks, reduce_ratio, dropout=dropout)
        elif self.module_type == "NAIVE":
            self.layers = self._make_naive_branches(num_branches, num_blocks)

        if self.with_fuse:
            self.fuse_layers = self._make_fuse_layers()
            self.relu = nn.ReLU()

    @staticmethod
    def _check_branches(num_branches, in_channels):
        """Check input to avoid ValueError."""

        if num_branches != len(in_channels):
            error_msg = f"NUM_BRANCHES({num_branches}) != NUM_INCHANNELS({len(in_channels)})"
            raise ValueError(error_msg)

    def _make_weighting_blocks(self, num_blocks, reduce_ratio, stride=1, dropout=None):
        layers = []
        for _ in range(num_blocks):
            layers.append(
                ConditionalChannelWeighting(
                    self.in_channels,
                    stride=stride,
                    reduce_ratio=reduce_ratio,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    with_cp=self.with_cp,
                    dropout=dropout,
                    weighting_module_version=self.weighting_module_version,
                    neighbour_weighting=self.neighbour_weighting,
                )
            )

        return nn.Sequential(*layers)

    def _make_one_branch(self, branch_index, num_blocks, stride=1):
        """Make one branch."""

        layers = [
            ShuffleUnit(
                self.in_channels[branch_index],
                self.in_channels[branch_index],
                stride=stride,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=dict(type="ReLU"),
                with_cp=self.with_cp,
            )
        ]
        for _ in range(1, num_blocks):
            layers.append(
                ShuffleUnit(
                    self.in_channels[branch_index],
                    self.in_channels[branch_index],
                    stride=1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    act_cfg=dict(type="ReLU"),
                    with_cp=self.with_cp,
                )
            )

        return nn.Sequential(*layers)

    def _make_naive_branches(self, num_branches, num_blocks):
        """Make branches."""

        branches = []
        for i in range(num_branches):
            branches.append(self._make_one_branch(i, num_blocks))

        return nn.ModuleList(branches)

    def _make_fuse_layers(self):
        """Make fuse layer."""

        if self.num_branches == 1:
            return None

        num_branches = self.num_branches
        in_channels = self.in_channels
        num_out_branches = num_branches if self.multiscale_output else 1

        fuse_layers = []
        for i in range(num_out_branches):
            fuse_layer = []
            for j in range(num_branches):
                if j > i:
                    fuse_layer.append(
                        nn.Sequential(
                            build_conv_layer(
                                self.conv_cfg,
                                in_channels[j],
                                in_channels[i],
                                kernel_size=1,
                                stride=1,
                                padding=0,
                                bias=False,
                            ),
                            build_norm_layer(self.norm_cfg, in_channels[i])[1],
                        )
                    )
                elif j == i:
                    fuse_layer.append(None)
                else:
                    conv_downsamples = []
                    for k in range(i - j):
                        if k == i - j - 1:
                            conv_downsamples.append(
                                nn.Sequential(
                                    build_conv_layer(
                                        self.conv_cfg,
                                        in_channels[j],
                                        in_channels[j],
                                        kernel_size=3,
                                        stride=2,
                                        padding=1,
                                        groups=in_channels[j],
                                        bias=False,
                                    ),
                                    build_norm_layer(self.norm_cfg, in_channels[j])[1],
                                    build_conv_layer(
                                        self.conv_cfg,
                                        in_channels[j],
                                        in_channels[i],
                                        kernel_size=1,
                                        stride=1,
                                        padding=0,
                                        bias=False,
                                    ),
                                    build_norm_layer(self.norm_cfg, in_channels[i])[1],
                                )
                            )
                        else:
                            conv_downsamples.append(
                                nn.Sequential(
                                    build_conv_layer(
                                        self.conv_cfg,
                                        in_channels[j],
                                        in_channels[j],
                                        kernel_size=3,
                                        stride=2,
                                        padding=1,
                                        groups=in_channels[j],
                                        bias=False,
                                    ),
                                    build_norm_layer(self.norm_cfg, in_channels[j])[1],
                                    build_conv_layer(
                                        self.conv_cfg,
                                        in_channels[j],
                                        in_channels[j],
                                        kernel_size=1,
                                        stride=1,
                                        padding=0,
                                        bias=False,
                                    ),
                                    build_norm_layer(self.norm_cfg, in_channels[j])[1],
                                    nn.ReLU(inplace=True),
                                )
                            )
                    fuse_layer.append(nn.Sequential(*conv_downsamples))
            fuse_layers.append(nn.ModuleList(fuse_layer))

        return nn.ModuleList(fuse_layers)

    def forward(self, x):
        """Forward function."""

        if self.num_branches == 1:
            return [self.layers[0](x[0])]

        if self.module_type == "LITE":
            out = self.layers(x)
        elif self.module_type == "NAIVE":
            for i in range(self.num_branches):
                x[i] = self.layers[i](x[i])
            out = x

        if self.with_fuse:
            out_fuse = []
            for i in range(len(self.fuse_layers)):
                y = out[0] if i == 0 else self.fuse_layers[i][0](out[0])
                for j in range(self.num_branches):
                    if i == j:
                        fuse_y = out[j]
                    else:
                        fuse_y = self.fuse_layers[i][j](out[j])

                    if fuse_y.size()[-2:] != y.size()[-2:]:
                        fuse_y = F.interpolate(fuse_y, size=y.size()[-2:], mode="nearest")

                    y += fuse_y

                out_fuse.append(self.relu(y))

            out = out_fuse
        elif not self.multiscale_output:
            out = [out[0]]

        return out

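# Brief note on LiteHRModule (added for clarity). module_type="LITE" stacks
# ConditionalChannelWeighting blocks shared across all branches, while module_type="NAIVE"
# builds an independent ShuffleUnit branch per resolution; with_fuse=True appends HRNet-style
# fuse layers that upsample (1x1 conv + nearest interpolation) or downsample (strided
# depthwise + 1x1 convs) every branch onto every output branch before a final ReLU.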

@BACKBONES.register_module()
class LiteHRNet(BaseModule):
    """Lite-HRNet backbone.

    `High-Resolution Representations for Labeling Pixels and Regions
    <https://arxiv.org/abs/1904.04514>`_

    Args:
        extra (dict): detailed configuration for each stage of HRNet.
        in_channels (int): Number of input image channels. Default: 3.
        conv_cfg (dict): dictionary to construct and config conv layer.
        norm_cfg (dict): dictionary to construct and config norm layer.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed.
        zero_init_residual (bool): whether to use zero init for last norm layer
            in resblocks to let them behave as identity.
    """

    def __init__(
        self,
        extra,
        in_channels=3,
        conv_cfg=None,
        norm_cfg=None,
        norm_eval=False,
        with_cp=False,
        zero_init_residual=False,
        dropout=None,
        init_cfg=None,
    ):
        super().__init__(init_cfg=init_cfg)

        if norm_cfg is None:
            norm_cfg = dict(type="BN")

        self.extra = extra
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.norm_eval = norm_eval
        self.with_cp = with_cp
        self.zero_init_residual = zero_init_residual

        self.stem = Stem(
            in_channels,
            input_norm=self.extra["stem"]["input_norm"],
            stem_channels=self.extra["stem"]["stem_channels"],
            out_channels=self.extra["stem"]["out_channels"],
            expand_ratio=self.extra["stem"]["expand_ratio"],
            strides=self.extra["stem"]["strides"],
            extra_stride=self.extra["stem"]["extra_stride"],
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
        )

        self.enable_stem_pool = self.extra["stem"].get("out_pool", False)
        if self.enable_stem_pool:
            self.stem_pool = nn.AvgPool2d(kernel_size=3, stride=2)

        self.num_stages = self.extra["num_stages"]
        self.stages_spec = self.extra["stages_spec"]

        num_channels_last = [
            self.stem.out_channels,
        ]
        for i in range(self.num_stages):
            num_channels = self.stages_spec["num_channels"][i]
            num_channels = [num_channels[i] for i in range(len(num_channels))]
            setattr(
                self,
                f"transition{i}",
                self._make_transition_layer(num_channels_last, num_channels),
            )

            stage, num_channels_last = self._make_stage(
                self.stages_spec,
                i,
                num_channels,
                multiscale_output=True,
                dropout=dropout,
            )
            setattr(self, f"stage{i}", stage)

        self.out_modules = None
        if self.extra.get("out_modules") is not None:
            out_modules = []
            in_modules_channels, out_modules_channels = num_channels_last[-1], None
            if self.extra["out_modules"]["conv"]["enable"]:
                out_modules_channels = self.extra["out_modules"]["conv"]["channels"]
                out_modules.append(
                    ConvModule(
                        in_channels=in_modules_channels,
                        out_channels=out_modules_channels,
                        kernel_size=1,
                        stride=1,
                        padding=0,
                        conv_cfg=self.conv_cfg,
                        norm_cfg=self.norm_cfg,
                        act_cfg=dict(type="ReLU"),
                    )
                )
                in_modules_channels = out_modules_channels
            if self.extra["out_modules"]["position_att"]["enable"]:
                out_modules.append(
                    AsymmetricPositionAttentionModule(
                        in_channels=in_modules_channels,
                        key_channels=self.extra["out_modules"]["position_att"]["key_channels"],
                        value_channels=self.extra["out_modules"]["position_att"]["value_channels"],
                        psp_size=self.extra["out_modules"]["position_att"]["psp_size"],
                        conv_cfg=self.conv_cfg,
                        norm_cfg=self.norm_cfg,
                    )
                )
            if self.extra["out_modules"]["local_att"]["enable"]:
                out_modules.append(
                    LocalAttentionModule(
                        num_channels=in_modules_channels,
                        conv_cfg=self.conv_cfg,
                        norm_cfg=self.norm_cfg,
                    )
                )

            if len(out_modules) > 0:
                self.out_modules = nn.Sequential(*out_modules)
                num_channels_last.append(in_modules_channels)

        self.add_stem_features = self.extra.get("add_stem_features", False)
        if self.add_stem_features:
            self.stem_transition = nn.Sequential(
                ConvModule(
                    self.stem.out_channels,
                    self.stem.out_channels,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    groups=self.stem.out_channels,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=None,
                ),
                ConvModule(
                    self.stem.out_channels,
                    num_channels_last[0],
                    kernel_size=1,
                    stride=1,
                    padding=0,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=dict(type="ReLU"),
                ),
            )

            num_channels_last = [num_channels_last[0]] + num_channels_last

        self.with_aggregator = self.extra.get("out_aggregator") and self.extra["out_aggregator"]["enable"]
        if self.with_aggregator:
            self.aggregator = IterativeAggregator(
                in_channels=num_channels_last,
                min_channels=self.extra["out_aggregator"].get("min_channels", None),
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
            )

    def _make_transition_layer(self, num_channels_pre_layer, num_channels_cur_layer):
        """Make transition layer."""

        num_branches_cur = len(num_channels_cur_layer)
        num_branches_pre = len(num_channels_pre_layer)

        transition_layers = []
        for i in range(num_branches_cur):
            if i < num_branches_pre:
                if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
                    transition_layers.append(
                        nn.Sequential(
                            build_conv_layer(
                                self.conv_cfg,
                                num_channels_pre_layer[i],
                                num_channels_pre_layer[i],
                                kernel_size=3,
                                stride=1,
                                padding=1,
                                groups=num_channels_pre_layer[i],
                                bias=False,
                            ),
                            build_norm_layer(self.norm_cfg, num_channels_pre_layer[i])[1],
                            build_conv_layer(
                                self.conv_cfg,
                                num_channels_pre_layer[i],
                                num_channels_cur_layer[i],
                                kernel_size=1,
                                stride=1,
                                padding=0,
                                bias=False,
                            ),
                            build_norm_layer(self.norm_cfg, num_channels_cur_layer[i])[1],
                            nn.ReLU(),
                        )
                    )
                else:
                    transition_layers.append(None)
            else:
                conv_downsamples = []
                for j in range(i + 1 - num_branches_pre):
                    in_channels = num_channels_pre_layer[-1]
                    out_channels = num_channels_cur_layer[i] if j == i - num_branches_pre else in_channels
                    conv_downsamples.append(
                        nn.Sequential(
                            build_conv_layer(
                                self.conv_cfg,
                                in_channels,
                                in_channels,
                                kernel_size=3,
                                stride=2,
                                padding=1,
                                groups=in_channels,
                                bias=False,
                            ),
                            build_norm_layer(self.norm_cfg, in_channels)[1],
                            build_conv_layer(
                                self.conv_cfg,
                                in_channels,
                                out_channels,
                                kernel_size=1,
                                stride=1,
                                padding=0,
                                bias=False,
                            ),
                            build_norm_layer(self.norm_cfg, out_channels)[1],
                            nn.ReLU(),
                        )
                    )
                transition_layers.append(nn.Sequential(*conv_downsamples))

        return nn.ModuleList(transition_layers)

    def _make_stage(
        self,
        stages_spec,
        stage_index,
        in_channels,
        multiscale_output=True,
        dropout=None,
    ):
        num_modules = stages_spec["num_modules"][stage_index]
        num_branches = stages_spec["num_branches"][stage_index]
        num_blocks = stages_spec["num_blocks"][stage_index]
        reduce_ratio = stages_spec["reduce_ratios"][stage_index]
        with_fuse = stages_spec["with_fuse"][stage_index]
        module_type = stages_spec["module_type"][stage_index]
        weighting_module_version = stages_spec.get("weighting_module_version", "v1")
        neighbour_weighting = stages_spec.get("neighbour_weighting", False)

        modules = []
        for i in range(num_modules):
            # multi_scale_output is only used by the last module
            if not multiscale_output and i == num_modules - 1:
                reset_multiscale_output = False
            else:
                reset_multiscale_output = True

            modules.append(
                LiteHRModule(
                    num_branches,
                    num_blocks,
                    in_channels,
                    reduce_ratio,
                    module_type,
                    multiscale_output=reset_multiscale_output,
                    with_fuse=with_fuse,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    with_cp=self.with_cp,
                    dropout=dropout,
                    weighting_module_version=weighting_module_version,
                    neighbour_weighting=neighbour_weighting,
                )
            )
            in_channels = modules[-1].in_channels

        return nn.Sequential(*modules), in_channels

    def init_weights(self, pretrained=None):
        """Initialize the weights in backbone.

        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Defaults to None.
        """
        if isinstance(pretrained, str):
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    normal_init(m, std=0.001)
                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
                    constant_init(m, 1)

            if self.zero_init_residual:
                for m in self.modules():
                    if isinstance(m, Bottleneck):
                        constant_init(m.norm3, 0)
                    elif isinstance(m, BasicBlock):
                        constant_init(m.norm2, 0)
        else:
            raise TypeError("pretrained must be a str or None")

    def forward(self, x):
        """Forward function."""
        stem_outputs = self.stem(x)
        y_x2 = y_x4 = stem_outputs  # y_x2, y_x4 = stem_outputs[-2:]
        y = y_x4

        if self.enable_stem_pool:
            y = self.stem_pool(y)

        y_list = [y]
        for i in range(self.num_stages):
            transition_modules = getattr(self, f"transition{i}")

            stage_inputs = []
            for j in range(self.stages_spec["num_branches"][i]):
                if transition_modules[j]:
                    if j >= len(y_list):
                        stage_inputs.append(transition_modules[j](y_list[-1]))
                    else:
                        stage_inputs.append(transition_modules[j](y_list[j]))
                else:
                    stage_inputs.append(y_list[j])

            stage_module = getattr(self, f"stage{i}")
            y_list = stage_module(stage_inputs)

        if self.out_modules is not None:
            y_list.append(self.out_modules(y_list[-1]))

        if self.add_stem_features:
            y_stem = self.stem_transition(y_x2)
            y_list = [y_stem] + y_list

        out = y_list
        if self.with_aggregator:
            out = self.aggregator(out)

        if self.extra.get("add_input", False):
            out = [x] + out

        return out

    def train(self, mode=True):
        """Convert the model into training mode."""
        super().train(mode)
        if mode and self.norm_eval:
            for m in self.modules():
                if isinstance(m, _BatchNorm):
                    m.eval()
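
# Hypothetical configuration sketch for LiteHRNet (the keys mirror what __init__ reads from
# `extra`; the concrete values below are illustrative and do not correspond to a shipped OTX
# model recipe):
#
#   extra = dict(
#       stem=dict(
#           stem_channels=32,
#           out_channels=32,
#           expand_ratio=1,
#           strides=(2, 2),
#           extra_stride=False,
#           input_norm=False,
#       ),
#       num_stages=3,
#       stages_spec=dict(
#           num_modules=(2, 4, 2),
#           num_branches=(2, 3, 4),
#           num_blocks=(2, 2, 2),
#           module_type=("LITE", "LITE", "LITE"),
#           with_fuse=(True, True, True),
#           reduce_ratios=(8, 8, 8),
#           num_channels=((40, 80), (40, 80, 160), (40, 80, 160, 320)),
#       ),
#       out_aggregator=dict(enable=False),
#       add_input=False,
#   )
#   backbone = LiteHRNet(extra=extra, norm_cfg=dict(type="BN"))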