# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
"""Custom FCNHead modules for OTX segmentation model."""

from __future__ import annotations

from functools import partial
from typing import TYPE_CHECKING, Any, Callable, ClassVar

import torch
from torch import Tensor, nn

from otx.algo.modules import Conv2dModule, build_activation_layer
from otx.algo.modules.norm import build_norm_layer
from otx.algo.segmentation.modules import IterativeAggregator

from .base_segm_head import BaseSegmentationHead

if TYPE_CHECKING:
    from pathlib import Path


class FCNHeadModule(BaseSegmentationHead):
    """Fully Convolution Networks for Semantic Segmentation with aggregation.

    This head is implemented of `FCNNet <https://arxiv.org/abs/1411.4038>`_.

    Args:
        normalization (Callable[..., nn.Module] | None): Normalization layer module.
            Defaults to None.
        num_convs (int): Number of convs in the head. Default: 2.
        kernel_size (int): The kernel size for convs in the head. Default: 3.
        concat_input (bool): Whether concat the input and output of convs
            before classification layer.
        dilation (int): The dilation rate for convs in the head. Default: 1.
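
    Example:
        A minimal sketch mirroring the ``lite_hrnet_18`` preset defined later
        in this file; channel counts are illustrative only.

        >>> head = FCNHeadModule(
        ...     in_channels=[40, 80, 160, 320],
        ...     in_index=[0, 1, 2, 3],
        ...     channels=40,
        ...     input_transform="multiple_select",
        ...     enable_aggregator=True,
        ...     num_classes=2,
        ... )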
    """

    def __init__(
        self,
        in_channels: list[int] | int,
        in_index: list[int] | int,
        channels: int,
        normalization: Callable[..., nn.Module] = partial(build_norm_layer, nn.BatchNorm2d, requires_grad=True),
        input_transform: str | None = None,
        num_classes: int = 80,
        num_convs: int = 1,
        kernel_size: int = 1,
        concat_input: bool = False,
        dilation: int = 1,
        enable_aggregator: bool = False,
        aggregator_min_channels: int = 0,
        aggregator_merge_norm: str | None = None,
        aggregator_use_concat: bool = False,
        align_corners: bool = False,
        dropout_ratio: float = -1,
        activation: Callable[..., nn.Module] | None = nn.ReLU,
        pretrained_weights: Path | str | None = None,
    ) -> None:
        """Initialize a Fully Convolution Networks head."""
        if not isinstance(dilation, int):
            msg = f"dilation should be int, but got {type(dilation)}"
            raise TypeError(msg)
        if num_convs < 0 or dilation <= 0:
            msg = "num_convs should be non-negative and dilation should be positive"
            raise ValueError(msg)

        self.num_convs = num_convs
        self.concat_input = concat_input
        self.kernel_size = kernel_size

        if enable_aggregator:  # Lite-HRNet aggregator
            if in_channels is None or isinstance(in_channels, int):
                msg = "'in_channels' should be list[int]."
                raise ValueError(msg)
            aggregator_min_channels = aggregator_min_channels if aggregator_min_channels is not None else 0
            aggregator = IterativeAggregator(
                in_channels=in_channels,
                min_channels=aggregator_min_channels,
                normalization=normalization,
                merge_norm=aggregator_merge_norm,
                use_concat=aggregator_use_concat,
            )

            # the aggregator outputs a single feature map, so rewrite the
            # arguments the base class sees accordingly
            in_channels = max(in_channels[0], aggregator_min_channels)
            input_transform = None
            if isinstance(in_index, list):
                in_index = in_index[0]
        else:
            aggregator = None

        super().__init__(
            in_index=in_index,
            normalization=normalization,
            input_transform=input_transform,
            in_channels=in_channels,
            align_corners=align_corners,
            dropout_ratio=dropout_ratio,
            channels=channels,
            num_classes=num_classes,
            activation=activation,
            pretrained_weights=pretrained_weights,
        )

        self.aggregator = aggregator

        if num_convs == 0 and (self.in_channels != self.channels):
            msg = "in_channels and channels should be equal when num_convs is 0"
            raise ValueError(msg)

        # "same"-style padding so stride-1 convs preserve spatial size
        conv_padding = (kernel_size // 2) * dilation
        if num_convs == 0:
            # validated above: in_channels == channels, so a pass-through is safe
            self.convs: nn.Module = nn.Identity()
        else:
            # the first conv maps in_channels -> channels; subsequent convs keep channels
            self.convs = nn.Sequential(
                *[
                    Conv2dModule(
                        self.in_channels if i == 0 else self.channels,
                        self.channels,
                        kernel_size=kernel_size,
                        padding=conv_padding,
                        dilation=dilation,
                        normalization=build_norm_layer(self.normalization, num_features=self.channels),
                        activation=build_activation_layer(self.activation),
                    )
                    for i in range(num_convs)
                ],
            )
        if self.concat_input:
            self.conv_cat = Conv2dModule(
                self.in_channels + self.channels,
                self.channels,
                kernel_size=kernel_size,
                padding=kernel_size // 2,
                normalization=build_norm_layer(self.normalization, num_features=self.channels),
                activation=build_activation_layer(self.activation),
            )

        if self.activation and num_convs > 0:
            # remove the activation of the last conv block so the classification
            # layer receives pre-activation features
            self.convs[-1].with_activation = False
            delattr(self.convs[-1], "activation")

    def _forward_feature(self, inputs: list[Tensor]) -> Tensor:
        """Forward function for feature maps.

        Args:
            inputs (list[Tensor]): List of multi-level img features.

        Returns:
            feats (Tensor): A tensor of shape (batch_size, self.channels,
                H, W), the feature map of the last layer of the decoder head.
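
        Example:
            A minimal shape sketch (illustrative, assuming the base head simply
            selects ``inputs[in_index]`` when ``input_transform`` is ``None``):

            >>> head = FCNHeadModule(in_channels=40, in_index=0, channels=40, num_classes=2)
            >>> head._forward_feature([torch.rand(1, 40, 32, 32)]).shape
            torch.Size([1, 40, 32, 32])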
        """
        x = self._transform_inputs(inputs)
        feats = self.convs(x)
        if self.concat_input:
            feats = self.conv_cat(torch.cat([x, feats], dim=1))
        return feats

    def forward(self, inputs: list[Tensor]) -> Tensor:
        """Forward function."""
        output = self._forward_feature(inputs)
        return self.cls_seg(output)

    def _transform_inputs(self, inputs: list[Tensor]) -> Tensor | list:
        """Transform inputs for decoder.

        Args:
            inputs (list[Tensor]): List of multi-level img features.

        Returns:
            Tensor | list: The transformed inputs.
        """
        return self.aggregator(inputs)[0] if self.aggregator is not None else super()._transform_inputs(inputs)


class FCNHead:
    """FCNHead factory for segmentation."""

    FCNHEAD_CFG: ClassVar[dict[str, Any]] = {
        "lite_hrnet_s": {
            "in_channels": [60, 120, 240],
            "in_index": [0, 1, 2],
            "input_transform": "multiple_select",
            "channels": 60,
            "enable_aggregator": True,
            "aggregator_merge_norm": "None",
            "aggregator_use_concat": False,
        },
        "lite_hrnet_18": {
            "in_channels": [40, 80, 160, 320],
            "in_index": [0, 1, 2, 3],
            "input_transform": "multiple_select",
            "channels": 40,
            "enable_aggregator": True,
        },
        "lite_hrnet_x": {
            "in_channels": [18, 60, 80, 160, 320],
            "in_index": [0, 1, 2, 3, 4],
            "input_transform": "multiple_select",
            "channels": 60,
            "enable_aggregator": True,
            "aggregator_min_channels": 60,
            "aggregator_merge_norm": "None",
            "aggregator_use_concat": False,
        },
        "dinov2-small-seg": {
            "in_channels": [384, 384, 384, 384],
            "in_index": [0, 1, 2, 3],
            "input_transform": "resize_concat",
            "channels": 1536,
            "pretrained_weights": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_ade20k_linear_head.pth",
        },
    }

    def __new__(cls, model_name: str, num_classes: int) -> FCNHeadModule:
        """Constructor for FCNHead."""
        if model_name not in cls.FCNHEAD_CFG:
            msg = f"model type '{model_name}' is not supported"
            raise KeyError(msg)

        normalization = (
            partial(build_norm_layer, nn.SyncBatchNorm, requires_grad=True)
            if model_name == "dinov2-small-seg"
            else partial(build_norm_layer, nn.BatchNorm2d, requires_grad=True)
        )

        return FCNHeadModule(**cls.FCNHEAD_CFG[model_name], num_classes=num_classes, normalization=normalization)
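

if __name__ == "__main__":
    # Usage sketch (illustrative only): build a head from the ``lite_hrnet_18``
    # preset and run it on random multi-level features. Channel counts come from
    # FCNHEAD_CFG above; the halving spatial sizes are an assumption about the
    # backbone pyramid made for this demo.
    head = FCNHead(model_name="lite_hrnet_18", num_classes=2)
    features = [torch.rand(1, c, 64 // 2**i, 64 // 2**i) for i, c in enumerate([40, 80, 160, 320])]
    print(head(features).shape)  # expected: (1, num_classes, h, w)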