# Copyright (C) 2023-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
"""HRNet network modules for base backbone.
Modified from:
- https://github.com/HRNet/Lite-HRNet
"""
from __future__ import annotations
from functools import partial
from pathlib import Path
from typing import Any, Callable, ClassVar
import torch
import torch.utils.checkpoint as cp
from torch import nn
from torch.nn import functional
from otx.algo.modules import Conv2dModule, build_activation_layer, build_norm_layer
from otx.algo.segmentation.modules import (
channel_shuffle,
)
from otx.algo.utils.mmengine_utils import load_checkpoint_to_model, load_from_http
class NeighbourSupport(nn.Module):
"""Neighbour support module.
Args:
channels (int): Number of input channels.
kernel_size (int): Kernel size for convolutional layers. Default is 3.
key_ratio (int): Ratio of input channels to key channels. Default is 8.
value_ratio (int): Ratio of input channels to value channels. Default is 8.
normalization (Callable[..., nn.Module] | None): Normalization layer module.
Defaults to None.
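Example:
    A minimal usage sketch (assuming ``nn.BatchNorm2d`` as the normalization
    factory, as used elsewhere in this module); the output keeps the input shape:

    >>> module = NeighbourSupport(channels=64, normalization=nn.BatchNorm2d)
    >>> module(torch.rand(2, 64, 32, 32)).shape
    torch.Size([2, 64, 32, 32])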
"""
def __init__(
self,
channels: int,
kernel_size: int = 3,
key_ratio: int = 8,
value_ratio: int = 8,
normalization: Callable[..., nn.Module] | None = None,
) -> None:
super().__init__()
self.in_channels = channels
self.key_channels = int(channels / key_ratio)
self.value_channels = int(channels / value_ratio)
self.kernel_size = kernel_size
self.key = nn.Sequential(
Conv2dModule(
in_channels=self.in_channels,
out_channels=self.key_channels,
kernel_size=1,
stride=1,
normalization=build_norm_layer(normalization, num_features=self.key_channels),
activation=build_activation_layer(nn.ReLU),
),
Conv2dModule(
self.key_channels,
self.key_channels,
kernel_size=self.kernel_size,
stride=1,
padding=(self.kernel_size - 1) // 2,
groups=self.key_channels,
normalization=build_norm_layer(normalization, num_features=self.key_channels),
activation=None,
),
Conv2dModule(
in_channels=self.key_channels,
out_channels=self.kernel_size * self.kernel_size,
kernel_size=1,
stride=1,
normalization=build_norm_layer(
normalization,
num_features=self.kernel_size * self.kernel_size,
),
activation=None,
),
)
self.value = nn.Sequential(
Conv2dModule(
in_channels=self.in_channels,
out_channels=self.value_channels,
kernel_size=1,
stride=1,
normalization=build_norm_layer(normalization, num_features=self.value_channels),
activation=None,
),
nn.Unfold(kernel_size=self.kernel_size, stride=1, padding=1),
)
self.out_conv = Conv2dModule(
in_channels=self.value_channels,
out_channels=self.in_channels,
kernel_size=1,
stride=1,
normalization=build_norm_layer(normalization, num_features=self.in_channels),
activation=None,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward."""
h, w = (int(_) for _ in x.size()[-2:])
key = self.key(x).view(-1, 1, self.kernel_size**2, h, w)
weights = torch.softmax(key, dim=2)
value = self.value(x).view(-1, self.value_channels, self.kernel_size**2, h, w)
y = torch.sum(weights * value, dim=2)
y = self.out_conv(y)
return x + y
class CrossResolutionWeighting(nn.Module):
"""Cross resolution weighting.
Args:
channels (list[int]): Number of channels for each stage.
ratio (int): Reduction ratio of the bottleneck block.
normalization (Callable[..., nn.Module] | None): Normalization layer module.
Defaults to None.
activation (Callable[..., nn.Module] | tuple[Callable[..., nn.Module], Callable[..., nn.Module]]): \
Activation layer module or a tuple of activation layer modules.
Defaults to ``(nn.ReLU, nn.Sigmoid)``.
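Example:
    An illustrative sketch with two branches at different resolutions
    (assuming ``nn.BatchNorm2d`` as the normalization factory); each branch
    keeps its own shape:

    >>> crw = CrossResolutionWeighting(channels=[40, 80], normalization=nn.BatchNorm2d)
    >>> feats = [torch.rand(2, 40, 32, 32), torch.rand(2, 80, 16, 16)]
    >>> [f.shape for f in crw(feats)]
    [torch.Size([2, 40, 32, 32]), torch.Size([2, 80, 16, 16])]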
"""
def __init__(
self,
channels: list[int],
ratio: int = 16,
normalization: Callable[..., nn.Module] | None = None,
activation: Callable[..., nn.Module] | tuple[Callable[..., nn.Module], Callable[..., nn.Module]] = (
nn.ReLU,
nn.Sigmoid,
),
) -> None:
super().__init__()
if callable(activation):
activation = (activation, activation)
if len(activation) != 2:
msg = "activation must be a callable or a tuple of callables of length 2."
raise ValueError(msg)
self.channels = channels
total_channel = sum(channels)
self.conv1 = Conv2dModule(
in_channels=total_channel,
out_channels=int(total_channel / ratio),
kernel_size=1,
stride=1,
normalization=build_norm_layer(normalization, num_features=int(total_channel / ratio)),
activation=build_activation_layer(activation[0]),
)
self.conv2 = Conv2dModule(
in_channels=int(total_channel / ratio),
out_channels=total_channel,
kernel_size=1,
stride=1,
normalization=build_norm_layer(normalization, num_features=total_channel),
activation=build_activation_layer(activation[1]),
)
def forward(self, x: list[torch.Tensor]) -> list[torch.Tensor]:
"""Forward."""
min_size = [int(_) for _ in x[-1].size()[-2:]]
out = [functional.adaptive_avg_pool2d(s, min_size) for s in x[:-1]] + [x[-1]]
out = torch.cat(out, dim=1)
out = self.conv1(out)
out = self.conv2(out)
out = torch.split(out, self.channels, dim=1)
return [s * functional.interpolate(a, size=s.size()[-2:], mode="nearest") for s, a in zip(x, out)]
class SpatialWeighting(nn.Module):
"""Spatial weighting.
Args:
channels (int): Number of input channels.
ratio (int): Reduction ratio for the bottleneck block. Default: 16.
activation (Callable[..., nn.Module] | tuple[Callable[..., nn.Module], Callable[..., nn.Module]]): \
Activation layer module or a tuple of activation layer modules.
If a single module is provided, it will be used for both activation layers.
Defaults to ``(nn.ReLU, nn.Sigmoid)``.
Raises:
ValueError: activation must be a callable or a tuple of callables of length 2.
TypeError: If activation is not a callable or a tuple of callables.
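Example:
    A minimal squeeze-and-excitation style sketch using the default
    activations; the output keeps the input shape:

    >>> sw = SpatialWeighting(channels=64, ratio=4)
    >>> sw(torch.rand(2, 64, 16, 16)).shape
    torch.Size([2, 64, 16, 16])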
"""
def __init__(
self,
channels: int,
ratio: int = 16,
activation: Callable[..., nn.Module] | tuple[Callable[..., nn.Module], Callable[..., nn.Module]] = (
nn.ReLU,
nn.Sigmoid,
),
**kwargs,
) -> None:
super().__init__()
if callable(activation):
activation = (activation, activation)
if len(activation) != 2:
msg = "activation must be a callable or a tuple of callables of length 2."
raise ValueError(msg)
self.global_avgpool = nn.AdaptiveAvgPool2d(1)
self.conv1 = Conv2dModule(
in_channels=channels,
out_channels=int(channels / ratio),
kernel_size=1,
stride=1,
activation=build_activation_layer(activation[0]),
)
self.conv2 = Conv2dModule(
in_channels=int(channels / ratio),
out_channels=channels,
kernel_size=1,
stride=1,
activation=build_activation_layer(activation[1]),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward."""
out = self.global_avgpool(x)
out = self.conv1(out)
out = self.conv2(out)
return x * out
class SpatialWeightingV2(nn.Module):
"""SpatialWeightingV2.
The original repo: https://github.com/DeLightCMU/PSA.
Args:
channels (int): Number of input channels.
ratio (int): Reduction ratio of internal channels.
normalization (Callable[..., nn.Module] | None): Normalization layer module.
Defaults to None.
enable_norm (bool): Whether to enable normalization layers.
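Example:
    An illustrative sketch (assuming ``nn.BatchNorm2d`` as the normalization
    factory); the channel and spatial attention branches are summed, so the
    output keeps the input shape:

    >>> swv2 = SpatialWeightingV2(channels=64, normalization=nn.BatchNorm2d, enable_norm=True)
    >>> swv2(torch.rand(2, 64, 16, 16)).shape
    torch.Size([2, 64, 16, 16])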
"""
def __init__(
self,
channels: int,
ratio: int = 16,
normalization: Callable[..., nn.Module] | None = None,
enable_norm: bool = False,
) -> None:
super().__init__()
self.in_channels = channels
self.internal_channels = int(channels / ratio)
# channel-only branch
self.v_channel = Conv2dModule(
in_channels=self.in_channels,
out_channels=self.internal_channels,
kernel_size=1,
stride=1,
bias=False,
normalization=build_norm_layer(normalization, num_features=self.internal_channels) if enable_norm else None,
activation=None,
)
self.q_channel = Conv2dModule(
in_channels=self.in_channels,
out_channels=1,
kernel_size=1,
stride=1,
bias=False,
normalization=build_norm_layer(normalization, num_features=1) if enable_norm else None,
activation=None,
)
self.out_channel = Conv2dModule(
in_channels=self.internal_channels,
out_channels=self.in_channels,
kernel_size=1,
stride=1,
normalization=build_norm_layer(normalization, num_features=self.in_channels),
activation=build_activation_layer(nn.Sigmoid),
)
# spatial-only branch
self.v_spatial = Conv2dModule(
in_channels=self.in_channels,
out_channels=self.internal_channels,
kernel_size=1,
stride=1,
bias=False,
normalization=build_norm_layer(normalization, num_features=self.internal_channels) if enable_norm else None,
activation=None,
)
self.q_spatial = Conv2dModule(
in_channels=self.in_channels,
out_channels=self.internal_channels,
kernel_size=1,
stride=1,
bias=False,
normalization=build_norm_layer(normalization, num_features=self.internal_channels) if enable_norm else None,
activation=None,
)
self.global_avgpool = nn.AdaptiveAvgPool2d(1)
def _channel_weighting(self, x: torch.Tensor) -> torch.Tensor:
"""_channel_weighting.
Args:
x (torch.Tensor): input tensor.
Returns:
torch.Tensor: output tensor.
"""
h, w = (int(_) for _ in x.size()[-2:])
v = self.v_channel(x).view(-1, self.internal_channels, h * w)
q = self.q_channel(x).view(-1, h * w, 1)
q = torch.softmax(q, dim=1)
y = torch.matmul(v, q)
y = y.view(-1, self.internal_channels, 1, 1)
y = self.out_channel(y)
return x * y
def _spatial_weighting(self, x: torch.Tensor) -> torch.Tensor:
"""_spatial_weighting.
Args:
x (torch.Tensor): input tensor.
Returns:
torch.Tensor: output tensor.
"""
h, w = (int(_) for _ in x.size()[-2:])
v = self.v_spatial(x)
v = v.view(-1, self.internal_channels, h * w)
q = self.q_spatial(x)
q = self.global_avgpool(q)
q = torch.softmax(q, dim=1)
q = q.view(-1, 1, self.internal_channels)
y = torch.matmul(q, v)
y = y.view(-1, 1, h, w)
y = torch.sigmoid(y)
return x * y
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward."""
y_channel = self._channel_weighting(x)
y_spatial = self._spatial_weighting(x)
return y_channel + y_spatial
class ConditionalChannelWeighting(nn.Module):
"""Conditional channel weighting module.
Args:
in_channels (list[int]): Number of input channels for each input feature map.
stride (int): Stride used in the first convolutional layer.
reduce_ratio (int): Reduction ratio used in the cross-resolution weighting module.
normalization (Callable[..., nn.Module]): Normalization layer module.
Defaults to ``nn.BatchNorm2d``.
with_cp (bool): Whether to use checkpointing to save memory.
dropout (float | None): Dropout probability applied to the weighted branch features. Defaults to None.
weighting_module_version (str): Version of the spatial weighting module to use.
neighbour_weighting (bool): Whether to use the neighbour support module.
dw_ksize (int): Kernel size used in the depthwise convolutional layers.
Raises:
ValueError: If stride is not 1 or 2.
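Example:
    An illustrative sketch with two branches, the default ``nn.BatchNorm2d``
    normalization and the default ``v1`` weighting module; with ``stride=1``
    every branch keeps its shape:

    >>> ccw = ConditionalChannelWeighting(in_channels=[40, 80], stride=1, reduce_ratio=8)
    >>> feats = [torch.rand(2, 40, 32, 32), torch.rand(2, 80, 16, 16)]
    >>> [f.shape for f in ccw(feats)]
    [torch.Size([2, 40, 32, 32]), torch.Size([2, 80, 16, 16])]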
"""
def __init__(
self,
in_channels: list[int],
stride: int,
reduce_ratio: int,
normalization: Callable[..., nn.Module] = nn.BatchNorm2d,
with_cp: bool = False,
dropout: float | None = None,
weighting_module_version: str = "v1",
neighbour_weighting: bool = False,
dw_ksize: int = 3,
) -> None:
super().__init__()
self.with_cp = with_cp
self.stride = stride
if stride not in [1, 2]:
msg = "stride must be 1 or 2."
raise ValueError(msg)
spatial_weighting_module = SpatialWeighting if weighting_module_version == "v1" else SpatialWeightingV2
branch_channels = [channel // 2 for channel in in_channels]
self.cross_resolution_weighting = CrossResolutionWeighting(
branch_channels,
ratio=reduce_ratio,
normalization=normalization,
)
self.depthwise_convs = nn.ModuleList(
[
Conv2dModule(
channel,
channel,
kernel_size=dw_ksize,
stride=self.stride,
padding=dw_ksize // 2,
groups=channel,
normalization=build_norm_layer(normalization, num_features=channel),
activation=None,
)
for channel in branch_channels
],
)
self.spatial_weighting = nn.ModuleList(
[
spatial_weighting_module( # type: ignore[call-arg]
channels=channel,
ratio=4,
normalization=normalization,
enable_norm=True,
)
for channel in branch_channels
],
)
self.neighbour_weighting = None
if neighbour_weighting:
self.neighbour_weighting = nn.ModuleList(
[
NeighbourSupport(
channel,
kernel_size=3,
key_ratio=8,
value_ratio=4,
normalization=normalization,
)
for channel in branch_channels
],
)
self.dropout = None
if dropout is not None and dropout > 0.0:
self.dropout = nn.ModuleList([nn.Dropout(p=dropout) for _ in branch_channels])
def _inner_forward(self, x: list[torch.Tensor]) -> list[torch.Tensor]:
"""_inner_forward.
Args:
x (list[torch.Tensor]): Input feature maps, one per branch.
Returns:
list[torch.Tensor]: Output feature maps, one per branch.
"""
x = [s.chunk(2, dim=1) for s in x]
x1 = [s[0] for s in x]
x2 = [s[1] for s in x]
x2 = self.cross_resolution_weighting(x2)
x2 = [dw(s) for s, dw in zip(x2, self.depthwise_convs)]
if self.neighbour_weighting is not None:
x2 = [nw(s) for s, nw in zip(x2, self.neighbour_weighting)]
x2 = [sw(s) for s, sw in zip(x2, self.spatial_weighting)]
if self.dropout is not None:
x2 = [dropout(s) for s, dropout in zip(x2, self.dropout)]
out = [torch.cat([s1, s2], dim=1) for s1, s2 in zip(x1, x2)]
return [channel_shuffle(s, 2) for s in out]
def forward(self, x: list[torch.Tensor]) -> list[torch.Tensor]:
"""Forward."""
return cp.checkpoint(self._inner_forward, x) if self.with_cp and x.requires_grad else self._inner_forward(x)
class Stem(nn.Module):
"""Stem.
Args:
in_channels (int): Number of input image channels. Typically 3.
stem_channels (int): Number of output channels of the stem layer.
out_channels (int): Number of output channels of the backbone network.
expand_ratio (int): Expansion ratio of the internal channels.
normalization (Callable[..., nn.Module] | None): Normalization layer module.
Defaults to ``nn.BatchNorm2d``.
with_cp (bool): Use checkpointing to save memory during forward pass.
strides (tuple[int, int]): Strides of the first convolution and of the depthwise stage.
extra_stride (bool): Use an extra stride in the second stage.
input_norm (bool): Use instance normalization on the input image.
Raises:
TypeError: If strides is not a tuple or list.
ValueError: If len(strides) is not equal to 2.
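Example:
    A minimal usage sketch with the default configuration (both strides equal
    to 2, no extra stride), so the spatial resolution is reduced by a factor
    of 4:

    >>> stem = Stem(in_channels=3, stem_channels=32, out_channels=32)
    >>> stem(torch.rand(2, 3, 64, 64)).shape
    torch.Size([2, 32, 16, 16])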
"""
def __init__(
self,
in_channels: int,
stem_channels: int = 32,
out_channels: int = 32,
expand_ratio: int = 1,
normalization: Callable[..., nn.Module] = partial(build_norm_layer, nn.BatchNorm2d, requires_grad=True),
with_cp: bool = False,
strides: tuple[int, int] = (2, 2),
extra_stride: bool = False,
input_norm: bool = False,
) -> None:
"""Stem initialization."""
super().__init__()
if not isinstance(strides, (tuple, list)):
msg = "strides must be tuple or list."
raise TypeError(msg)
if len(strides) != 2:
msg = "len(strides) must equal to 2."
raise ValueError(msg)
self.in_channels = in_channels
self.out_channels = out_channels
self.normalization = normalization
self.with_cp = with_cp
self.input_norm = None
if input_norm:
self.input_norm = nn.InstanceNorm2d(in_channels)
self.conv1 = Conv2dModule(
in_channels=in_channels,
out_channels=stem_channels,
kernel_size=3,
stride=strides[0],
padding=1,
normalization=build_norm_layer(self.normalization, num_features=stem_channels),
activation=build_activation_layer(nn.ReLU),
)
self.conv2 = None
if extra_stride:
self.conv2 = Conv2dModule(
in_channels=stem_channels,
out_channels=stem_channels,
kernel_size=3,
stride=2,
padding=1,
normalization=build_norm_layer(self.normalization, num_features=stem_channels),
activation=build_activation_layer(nn.ReLU),
)
mid_channels = int(round(stem_channels * expand_ratio))
branch_channels = stem_channels // 2
if stem_channels == self.out_channels:
inc_channels = self.out_channels - branch_channels
else:
inc_channels = self.out_channels - stem_channels
self.branch1 = nn.Sequential(
Conv2dModule(
branch_channels,
branch_channels,
kernel_size=3,
stride=strides[1],
padding=1,
groups=branch_channels,
normalization=build_norm_layer(normalization, num_features=branch_channels),
activation=None,
),
Conv2dModule(
branch_channels,
inc_channels,
kernel_size=1,
stride=1,
padding=0,
normalization=build_norm_layer(normalization, num_features=inc_channels),
activation=build_activation_layer(nn.ReLU),
),
)
self.expand_conv = Conv2dModule(
branch_channels,
mid_channels,
kernel_size=1,
stride=1,
padding=0,
normalization=build_norm_layer(normalization, num_features=mid_channels),
activation=build_activation_layer(nn.ReLU),
)
self.depthwise_conv = Conv2dModule(
mid_channels,
mid_channels,
kernel_size=3,
stride=strides[1],
padding=1,
groups=mid_channels,
normalization=build_norm_layer(normalization, num_features=mid_channels),
activation=None,
)
self.linear_conv = Conv2dModule(
mid_channels,
branch_channels if stem_channels == self.out_channels else stem_channels,
kernel_size=1,
stride=1,
padding=0,
normalization=build_norm_layer(
normalization,
num_features=branch_channels if stem_channels == self.out_channels else stem_channels,
),
activation=build_activation_layer(nn.ReLU),
)
def _inner_forward(self, x: torch.Tensor) -> torch.Tensor:
"""_inner_forward.
Args:
x (torch.Tensor): Input tensor.
Returns:
torch.Tensor: Output tensor.
"""
if self.input_norm is not None:
x = self.input_norm(x)
x = self.conv1(x)
if self.conv2 is not None:
x = self.conv2(x)
x1, x2 = x.chunk(2, dim=1)
x1 = self.branch1(x1)
x2 = self.expand_conv(x2)
x2 = self.depthwise_conv(x2)
x2 = self.linear_conv(x2)
out = torch.cat((x1, x2), dim=1)
return channel_shuffle(out, 2)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward."""
return cp.checkpoint(self._inner_forward, x) if self.with_cp and x.requires_grad else self._inner_forward(x)
class ShuffleUnit(nn.Module):
"""InvertedResidual block for ShuffleNetV2 backbone.
Args:
in_channels (int): The input channels of the block.
out_channels (int): The output channels of the block.
stride (int): Stride of the 3x3 convolution layer. Default: 1
normalization (Callable[..., nn.Module]): Normalization layer module.
Defaults to ``nn.BatchNorm2d``.
activation (Callable[..., nn.Module]): Activation layer module.
Defaults to ``nn.ReLU``.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed. Default: False.
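Example:
    A minimal sketch of a stride-1 unit (the input channels must equal the
    output channels in this case), which keeps the input shape:

    >>> unit = ShuffleUnit(in_channels=64, out_channels=64, stride=1)
    >>> unit(torch.rand(2, 64, 16, 16)).shape
    torch.Size([2, 64, 16, 16])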
"""
def __init__(
self,
in_channels: int,
out_channels: int,
stride: int = 1,
normalization: Callable[..., nn.Module] = nn.BatchNorm2d,
activation: Callable[..., nn.Module] = nn.ReLU,
with_cp: bool = False,
) -> None:
super().__init__()
self.stride = stride
self.with_cp = with_cp
branch_features = out_channels // 2
if self.stride == 1 and in_channels != branch_features * 2:
msg = (
f"in_channels ({in_channels}) should equal to "
f"branch_features * 2 ({branch_features * 2}) "
"when stride is 1"
)
raise ValueError(msg)
if self.stride > 1:
self.branch1 = nn.Sequential(
Conv2dModule(
in_channels,
in_channels,
kernel_size=3,
stride=self.stride,
padding=1,
groups=in_channels,
normalization=build_norm_layer(normalization, num_features=in_channels),
activation=None,
),
Conv2dModule(
in_channels,
branch_features,
kernel_size=1,
stride=1,
padding=0,
normalization=build_norm_layer(normalization, num_features=branch_features),
activation=build_activation_layer(activation),
),
)
self.branch2 = nn.Sequential(
Conv2dModule(
in_channels if (self.stride > 1) else branch_features,
branch_features,
kernel_size=1,
stride=1,
padding=0,
normalization=build_norm_layer(normalization, num_features=branch_features),
activation=build_activation_layer(activation),
),
Conv2dModule(
branch_features,
branch_features,
kernel_size=3,
stride=self.stride,
padding=1,
groups=branch_features,
normalization=build_norm_layer(normalization, num_features=branch_features),
activation=None,
),
Conv2dModule(
branch_features,
branch_features,
kernel_size=1,
stride=1,
padding=0,
normalization=build_norm_layer(normalization, num_features=branch_features),
activation=build_activation_layer(activation),
),
)
def _inner_forward(self, x: torch.Tensor) -> torch.Tensor:
"""_inner_forward."""
if self.stride > 1:
out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)
else:
x1, x2 = x.chunk(2, dim=1)
out = torch.cat((x1, self.branch2(x2)), dim=1)
return channel_shuffle(out, 2)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward."""
return cp.checkpoint(self._inner_forward, x) if self.with_cp and x.requires_grad else self._inner_forward(x)
class LiteHRModule(nn.Module):
"""LiteHR module.
Args:
num_branches (int): Number of branches in the network.
num_blocks (int): Number of blocks in each branch.
in_channels (list[int]): List of input channels for each branch.
reduce_ratio (int): Reduction ratio for the weighting module.
module_type (str): Type of module to use for the network. Can be "LITE" or "NAIVE".
multiscale_output (bool, optional): Whether to output features from all branches. Defaults to False.
with_fuse (bool, optional): Whether to use the fuse layer. Defaults to True.
normalization (Callable[..., nn.Module]): Normalization layer module.
Defaults to ``nn.BatchNorm2d``.
with_cp (bool, optional): Whether to use checkpointing. Defaults to False.
dropout (float, optional): Dropout rate. Defaults to None.
weighting_module_version (str, optional): Version of the weighting module to use. Defaults to "v1".
neighbour_weighting (bool, optional): Whether to use neighbour weighting. Defaults to False.
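Example:
    An illustrative two-branch module in ``LITE`` mode with fusion enabled
    and the default ``nn.BatchNorm2d`` normalization; with multi-scale output
    every branch keeps its shape:

    >>> module = LiteHRModule(
    ...     num_branches=2,
    ...     num_blocks=2,
    ...     in_channels=[40, 80],
    ...     reduce_ratio=8,
    ...     module_type="LITE",
    ...     multiscale_output=True,
    ... )
    >>> feats = [torch.rand(2, 40, 32, 32), torch.rand(2, 80, 16, 16)]
    >>> [f.shape for f in module(feats)]
    [torch.Size([2, 40, 32, 32]), torch.Size([2, 80, 16, 16])]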
"""
def __init__(
self,
num_branches: int,
num_blocks: int,
in_channels: list[int],
reduce_ratio: int,
module_type: str,
multiscale_output: bool = False,
with_fuse: bool = True,
normalization: Callable[..., nn.Module] = nn.BatchNorm2d,
with_cp: bool = False,
dropout: float | None = None,
weighting_module_version: str = "v1",
neighbour_weighting: bool = False,
) -> None:
super().__init__()
self._check_branches(num_branches, in_channels)
self.in_channels = in_channels
self.num_branches = num_branches
self.module_type = module_type
self.multiscale_output = multiscale_output
self.with_fuse = with_fuse
self.normalization = normalization
self.with_cp = with_cp
self.weighting_module_version = weighting_module_version
self.neighbour_weighting = neighbour_weighting
if self.module_type == "LITE":
self.layers = self._make_weighting_blocks(num_blocks, reduce_ratio, dropout=dropout)
elif self.module_type == "NAIVE":
self.layers = self._make_naive_branches(num_branches, num_blocks)
if self.with_fuse:
self.fuse_layers = self._make_fuse_layers()
self.relu = nn.ReLU()
@staticmethod
def _check_branches(num_branches: int, in_channels: list[int]) -> None:
"""Check input to avoid ValueError."""
if num_branches != len(in_channels):
error_msg = f"NUM_BRANCHES({num_branches}) != NUM_INCHANNELS({len(in_channels)})"
raise ValueError(error_msg)
def _make_weighting_blocks(
self,
num_blocks: int,
reduce_ratio: int,
stride: int = 1,
dropout: float | None = None,
) -> nn.Sequential:
layers = [
ConditionalChannelWeighting(
self.in_channels,
stride=stride,
reduce_ratio=reduce_ratio,
normalization=self.normalization,
with_cp=self.with_cp,
dropout=dropout,
weighting_module_version=self.weighting_module_version,
neighbour_weighting=self.neighbour_weighting,
)
for _ in range(num_blocks)
]
return nn.Sequential(*layers)
def _make_one_branch(self, branch_index: int, num_blocks: int, stride: int = 1) -> nn.Sequential:
"""Make one branch."""
layers = [
ShuffleUnit(
self.in_channels[branch_index],
self.in_channels[branch_index],
stride=stride,
normalization=self.normalization,
activation=nn.ReLU,
with_cp=self.with_cp,
),
] + [
ShuffleUnit(
self.in_channels[branch_index],
self.in_channels[branch_index],
stride=1,
normalization=self.normalization,
activation=nn.ReLU,
with_cp=self.with_cp,
)
for _ in range(1, num_blocks)
]
return nn.Sequential(*layers)
def _make_naive_branches(self, num_branches: int, num_blocks: int) -> nn.ModuleList:
"""Make branches."""
branches = [self._make_one_branch(i, num_blocks) for i in range(num_branches)]
return nn.ModuleList(branches)
def _make_fuse_layers(self) -> nn.ModuleList | None:
"""Make fuse layer."""
if self.num_branches == 1:
return None
num_branches = self.num_branches
in_channels = self.in_channels
num_out_branches = num_branches if self.multiscale_output else 1
fuse_layers = []
for i in range(num_out_branches):
fuse_layer = []
for j in range(num_branches):
if j > i:
fuse_layer.append(
nn.Sequential(
nn.Conv2d(
in_channels[j],
in_channels[i],
kernel_size=1,
stride=1,
padding=0,
bias=False,
),
build_norm_layer(self.normalization, in_channels[i])[1],
),
)
elif j == i:
fuse_layer.append(None)
else:
conv_downsamples = []
for k in range(i - j):
if k == i - j - 1:
conv_downsamples.append(
nn.Sequential(
nn.Conv2d(
in_channels[j],
in_channels[j],
kernel_size=3,
stride=2,
padding=1,
groups=in_channels[j],
bias=False,
),
build_norm_layer(self.normalization, in_channels[j])[1],
nn.Conv2d(
in_channels[j],
in_channels[i],
kernel_size=1,
stride=1,
padding=0,
bias=False,
),
build_norm_layer(self.normalization, in_channels[i])[1],
),
)
else:
conv_downsamples.append(
nn.Sequential(
nn.Conv2d(
in_channels[j],
in_channels[j],
kernel_size=3,
stride=2,
padding=1,
groups=in_channels[j],
bias=False,
),
build_norm_layer(self.normalization, in_channels[j])[1],
nn.Conv2d(
in_channels[j],
in_channels[j],
kernel_size=1,
stride=1,
padding=0,
bias=False,
),
build_norm_layer(self.normalization, in_channels[j])[1],
nn.ReLU(inplace=True),
),
)
fuse_layer.append(nn.Sequential(*conv_downsamples))
fuse_layers.append(nn.ModuleList(fuse_layer))
return nn.ModuleList(fuse_layers)
def forward(self, x: list[torch.Tensor]) -> list[torch.Tensor]:
"""Forward function."""
if self.num_branches == 1:
return [self.layers[0](x[0])]
if self.module_type == "LITE":
out = self.layers(x)
elif self.module_type == "NAIVE":
for i in range(self.num_branches):
x[i] = self.layers[i](x[i])
out = x
if self.with_fuse:
out_fuse = []
for i in range(len(self.fuse_layers)):
y = out[0] if i == 0 else self.fuse_layers[i][0](out[0])
for j in range(self.num_branches):
fuse_y = out[j] if i == j else self.fuse_layers[i][j](out[j])
if fuse_y.size()[-2:] != y.size()[-2:]:
fuse_y = functional.interpolate(fuse_y, size=y.size()[-2:], mode="nearest")
y += fuse_y
out_fuse.append(self.relu(y))
out = out_fuse
elif not self.multiscale_output:
out = [out[0]]
return out
class LiteHRNetModule(nn.Module):
"""Lite-HRNet backbone.
`High-Resolution Representations for Labeling Pixels and Regions
<https://arxiv.org/abs/1904.04514>`_
Args:
num_stages (int): Number of stages of the backbone.
stem_configuration (dict[str, Any]): Keyword arguments used to build the ``Stem`` module.
stages_spec (dict[str, Any]): Detailed configuration for each stage of the backbone
(number of modules, branches, blocks, channels, and so on).
in_channels (int): Number of input image channels. Default: 3.
normalization (Callable[..., nn.Module]): Normalization layer module.
Defaults to ``nn.BatchNorm2d``.
norm_eval (bool): Whether to set norm layers to eval mode, namely,
freeze running stats (mean and var). Note: Effect on Batch Norm
and its variants only. Default: False
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed. Default: False
zero_init_residual (bool): Whether to use zero init for the last norm layer
in resblocks to let them behave as identity. Default: False
dropout (float | None): Dropout probability passed to the stages. Defaults to None.
pretrained_weights (str | None): Path or URL of pretrained weights to load.
Defaults to None.
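Example:
    An illustrative construction that mirrors the ``lite_hrnet_s`` entry of
    ``LiteHRNetBackbone.LITEHRNET_CFG`` but without pretrained weights
    (``LiteHRNetBackbone("lite_hrnet_s")`` builds the same network and also
    fetches the weights over HTTP):

    >>> net = LiteHRNetModule(
    ...     num_stages=2,
    ...     stem_configuration={"extra_stride": True},
    ...     stages_spec={
    ...         "num_modules": [4, 4],
    ...         "num_branches": [2, 3],
    ...         "num_blocks": [2, 2],
    ...         "module_type": ["LITE", "LITE"],
    ...         "with_fuse": [True, True],
    ...         "reduce_ratios": [8, 8],
    ...         "num_channels": [[60, 120], [60, 120, 240]],
    ...     },
    ... )
    >>> outputs = net(torch.rand(1, 3, 256, 256))
    >>> [o.shape for o in outputs]
    [torch.Size([1, 60, 32, 32]), torch.Size([1, 120, 16, 16]), torch.Size([1, 240, 8, 8])]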
"""
def __init__(
self,
num_stages: int,
stem_configuration: dict[str, Any],
stages_spec: dict[str, Any],
in_channels: int = 3,
normalization: Callable[..., nn.Module] = partial(build_norm_layer, nn.BatchNorm2d, requires_grad=True),
norm_eval: bool = False,
with_cp: bool = False,
zero_init_residual: bool = False,
dropout: float | None = None,
pretrained_weights: str | None = None,
) -> None:
"""Init."""
super().__init__()
self.normalization = normalization
self.norm_eval = norm_eval
self.with_cp = with_cp
self.zero_init_residual = zero_init_residual
self.stem = Stem(in_channels=in_channels, **stem_configuration, normalization=normalization)
self.num_stages = num_stages
self.stages_spec = stages_spec
num_channels_last = [
self.stem.out_channels,
]
for i in range(self.num_stages):
num_channels = list(self.stages_spec["num_channels"][i])
setattr(
self,
f"transition{i}",
self._make_transition_layer(num_channels_last, num_channels),
)
stage, num_channels_last = self._make_stage(
self.stages_spec,
i,
num_channels,
multiscale_output=True,
dropout=dropout,
)
setattr(self, f"stage{i}", stage)
if pretrained_weights is not None:
self.load_pretrained_weights(pretrained_weights, prefix="backbone")
def _make_transition_layer(
self,
num_channels_pre_layer: list[int],
num_channels_cur_layer: list[int],
) -> nn.ModuleList:
"""Make transition layer."""
num_branches_cur = len(num_channels_cur_layer)
num_branches_pre = len(num_channels_pre_layer)
transition_layers = []
for i in range(num_branches_cur):
if i < num_branches_pre:
if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
transition_layers.append(
nn.Sequential(
nn.Conv2d(
num_channels_pre_layer[i],
num_channels_pre_layer[i],
kernel_size=3,
stride=1,
padding=1,
groups=num_channels_pre_layer[i],
bias=False,
),
build_norm_layer(self.normalization, num_channels_pre_layer[i])[1],
nn.Conv2d(
num_channels_pre_layer[i],
num_channels_cur_layer[i],
kernel_size=1,
stride=1,
padding=0,
bias=False,
),
build_norm_layer(self.normalization, num_channels_cur_layer[i])[1],
nn.ReLU(),
),
)
else:
transition_layers.append(None)
else:
conv_downsamples = []
for j in range(i + 1 - num_branches_pre):
in_channels = num_channels_pre_layer[-1]
out_channels = num_channels_cur_layer[i] if j == i - num_branches_pre else in_channels
conv_downsamples.append(
nn.Sequential(
nn.Conv2d(
in_channels,
in_channels,
kernel_size=3,
stride=2,
padding=1,
groups=in_channels,
bias=False,
),
build_norm_layer(self.normalization, in_channels)[1],
nn.Conv2d(
in_channels,
out_channels,
kernel_size=1,
stride=1,
padding=0,
bias=False,
),
build_norm_layer(self.normalization, out_channels)[1],
nn.ReLU(),
),
)
transition_layers.append(nn.Sequential(*conv_downsamples))
return nn.ModuleList(transition_layers)
def _make_stage(
self,
stages_spec: dict,
stage_index: int,
in_channels: list[int],
multiscale_output: bool = True,
dropout: float | None = None,
) -> tuple[nn.Module, list[int]]:
"""Create a stage of the LiteHRNet backbone.
Args:
stages_spec (dict): Specification of the stages of the backbone.
stage_index (int): Index of the current stage.
in_channels (list[int]): List of input channels for each branch.
multiscale_output (bool, optional): Whether to output features from all branches. Defaults to True.
dropout (float | None, optional): Dropout probability. Defaults to None.
Returns:
tuple[nn.Module, list[int]]: A tuple containing the stage module and the output channels for each branch.
"""
num_modules = stages_spec["num_modules"][stage_index]
num_branches = stages_spec["num_branches"][stage_index]
num_blocks = stages_spec["num_blocks"][stage_index]
reduce_ratio = stages_spec["reduce_ratios"][stage_index]
with_fuse = stages_spec["with_fuse"][stage_index]
module_type = stages_spec["module_type"][stage_index]
weighting_module_version = stages_spec.get("weighting_module_version", "v1")
neighbour_weighting = stages_spec.get("neighbour_weighting", False)
modules = []
for i in range(num_modules):
# multiscale_output only takes effect in the last module of the stage
reset_multiscale_output = not ((not multiscale_output) and i == num_modules - 1)
modules.append(
LiteHRModule(
num_branches,
num_blocks,
in_channels,
reduce_ratio,
module_type,
multiscale_output=reset_multiscale_output,
with_fuse=with_fuse,
normalization=self.normalization,
with_cp=self.with_cp,
dropout=dropout,
weighting_module_version=weighting_module_version,
neighbour_weighting=neighbour_weighting,
),
)
in_channels = modules[-1].in_channels
return nn.Sequential(*modules), in_channels
def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
"""Forward function."""
stem_outputs = self.stem(x)
y = stem_outputs
y_list = [y]
for i in range(self.num_stages):
transition_modules = getattr(self, f"transition{i}")
stage_inputs = []
for j in range(self.stages_spec["num_branches"][i]):
if transition_modules[j]:
if j >= len(y_list):
stage_inputs.append(transition_modules[j](y_list[-1]))
else:
stage_inputs.append(transition_modules[j](y_list[j]))
else:
stage_inputs.append(y_list[j])
stage_module = getattr(self, f"stage{i}")
y_list = stage_module(stage_inputs)
return y_list
def load_pretrained_weights(self, pretrained: str | None = None, prefix: str = "") -> None:
"""Initialize weights."""
checkpoint = None
if isinstance(pretrained, str) and Path(pretrained).exists():
checkpoint = torch.load(pretrained, "cpu")
print(f"init weight - {pretrained}")
elif pretrained is not None:
cache_dir = Path.home() / ".cache" / "torch" / "hub" / "checkpoints"
checkpoint = load_from_http(filename=pretrained, map_location="cpu", model_dir=cache_dir)
print(f"init weight - {pretrained}")
if checkpoint is not None:
load_checkpoint_to_model(self, checkpoint, prefix=prefix)
class LiteHRNetBackbone:
"""LiteHRNet backbone factory."""
LITEHRNET_CFG: ClassVar[dict[str, Any]] = {
"lite_hrnet_s": {
"stem_configuration": {"extra_stride": True},
"num_stages": 2,
"stages_spec": {
"num_modules": [4, 4],
"num_branches": [2, 3],
"num_blocks": [2, 2],
"module_type": ["LITE", "LITE"],
"with_fuse": [True, True],
"reduce_ratios": [8, 8],
"num_channels": [[60, 120], [60, 120, 240]],
},
"pretrained_weights": "https://storage.openvinotoolkit.org/repositories/openvino_training_extensions/models/custom_semantic_segmentation/litehrnetsv2_imagenet1k_rsc.pth",
},
"lite_hrnet_18": {
"stem_configuration": {},
"num_stages": 3,
"stages_spec": {
"num_modules": [2, 4, 2],
"num_branches": [2, 3, 4],
"num_blocks": [2, 2, 2],
"module_type": ["LITE", "LITE", "LITE"],
"with_fuse": [True, True, True],
"reduce_ratios": [8, 8, 8],
"num_channels": [[40, 80], [40, 80, 160], [40, 80, 160, 320]],
},
"pretrained_weights": "https://storage.openvinotoolkit.org/repositories/openvino_training_extensions/models/custom_semantic_segmentation/litehrnet18_imagenet1k_rsc.pth",
},
"lite_hrnet_x": {
"stem_configuration": {"stem_channels": 60, "out_channels": 60, "strides": (2, 1)},
"num_stages": 4,
"stages_spec": {
"weighting_module_version": "v1",
"num_modules": [2, 4, 4, 2],
"num_branches": [2, 3, 4, 5],
"num_blocks": [2, 2, 2, 2],
"module_type": ["LITE", "LITE", "LITE", "LITE"],
"with_fuse": [True, True, True, True],
"reduce_ratios": [2, 4, 8, 8],
"num_channels": [[18, 60], [18, 60, 80], [18, 60, 80, 160], [18, 60, 80, 160, 320]],
},
"pretrained_weights": "https://storage.openvinotoolkit.org/repositories/openvino_training_extensions/models/custom_semantic_segmentation/litehrnetxv3_imagenet1k_rsc.pth",
},
}
def __new__(cls, model_name: str) -> LiteHRNetModule:
"""Constructor for LiteHRNet backbone."""
if model_name not in cls.LITEHRNET_CFG:
msg = f"model type '{model_name}' is not supported"
raise KeyError(msg)
return LiteHRNetModule(**cls.LITEHRNET_CFG[model_name])