# Source code for otx.algo.detection.heads.yolo_head

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
"""Head implementation of YOLOv7 and YOLOv9.

Reference : https://github.com/WongKinYiu/YOLO
"""

from __future__ import annotations  # noqa: I001

from typing import Any, ClassVar, NoReturn

import torch
from einops import rearrange
from torch import Tensor, nn
from torchvision.ops import batched_nms

from otx.algo.common.utils.nms import multiclass_nms
from otx.algo.detection.heads.base_head import BaseDenseHead
from otx.algo.detection.layers import AConv, ADown, Concat, SPPELAN, RepNCSPELAN
from otx.algo.detection.utils.utils import round_up, set_info_into_instance, auto_pad
from otx.algo.modules import Conv2dModule
from otx.algo.utils.mmengine_utils import InstanceData
from otx.core.data.entity.base import OTXBatchDataEntity
from otx.core.data.entity.detection import DetBatchDataEntity


class Anchor2Vec(nn.Module):
    """Turn per-side bin logits into a single expected value per side.

    The channel axis of the input is regrouped into 4 groups of ``reg_max`` bins, a softmax is
    taken over the bins, and a frozen 1x1x1 convolution whose weights are ``0, 1, ..., reg_max - 1``
    collapses the bins into their weighted sum.

    Args:
        reg_max (int): Number of bins per group; the input must carry ``4 * reg_max`` channels.
    """

    def __init__(self, reg_max: int = 16) -> None:
        super().__init__()
        # Fixed projection weights 0..reg_max-1, shaped like a Conv3d kernel (out, in, 1, 1, 1).
        reverse_reg = torch.arange(reg_max, dtype=torch.float32).view(1, reg_max, 1, 1, 1)
        self.anc2vec = nn.Conv3d(in_channels=reg_max, out_channels=1, kernel_size=1, bias=False)
        # Replace the randomly initialised kernel with the frozen ramp so it is never trained.
        self.anc2vec.weight = nn.Parameter(reverse_reg, requires_grad=False)

    def forward(self, anchor_x: Tensor) -> tuple[Tensor, Tensor]:
        """Forward function.

        Args:
            anchor_x (Tensor): Tensor laid out as ``B (4 * R) h w``.

        Returns:
            tuple[Tensor, Tensor]: The regrouped input laid out as ``B R 4 h w`` (before softmax) and
            the result of the frozen projection with its singleton channel removed (``B 4 h w``).
        """
        anchor_x = rearrange(anchor_x, "B (P R) h w -> B R P h w", P=4)
        # Softmax over the bin axis, then reduce the bins with the frozen ramp weights.
        vector_x = anchor_x.softmax(dim=1)
        vector_x = self.anc2vec(vector_x)[:, 0]
        return anchor_x, vector_x


class CBLinear(nn.Module):
    """Single convolution whose output is split into several feature maps along the channel axis.

    Args:
        in_channels (int): Number of input channels.
        out_channels (list[int]): Channel count of each output chunk; the convolution produces
            ``sum(out_channels)`` channels in total.
        kernel_size (int): Size of the convolving kernel.
        **kwargs: Extra keyword arguments forwarded to ``nn.Conv2d``. If ``padding`` is not given,
            it is filled in from ``auto_pad(kernel_size, **kwargs)``.
    """

    def __init__(self, in_channels: int, out_channels: list[int], kernel_size: int = 1, **kwargs) -> None:
        super().__init__()
        # Only applies the automatic padding when the caller did not pass one explicitly.
        kwargs.setdefault("padding", auto_pad(kernel_size, **kwargs))
        self.conv = nn.Conv2d(in_channels, sum(out_channels), kernel_size, **kwargs)
        # Copy so later changes to the caller's list do not alter the split sizes.
        self.out_channels = list(out_channels)

    def forward(self, x: Tensor) -> tuple[Tensor, ...]:
        """Forward function.

        Returns:
            tuple[Tensor, ...]: One tensor per entry of ``out_channels``, split along dim 1.
        """
        x = self.conv(x)
        return x.split(self.out_channels, dim=1)


class CBFuse(nn.Module):
    """Resize selected feature maps to the size of the last input and add everything together.

    Args:
        index (list[int]): For each leading input, the position to pick inside that input.
        mode (str): Interpolation mode handed to ``nn.functional.interpolate``.
    """

    def __init__(self, index: list[int], mode: str = "nearest") -> None:
        super().__init__()
        self.mode = mode
        self.idx = index

    def forward(self, x_list: list[tuple[Tensor, ...] | Tensor]) -> Tensor:
        """Forward function.

        The last element of ``x_list`` is the reference map. Every earlier element is paired with
        an entry of ``self.idx`` (pairing stops at the shorter of the two), indexed with it, and
        interpolated to the reference's spatial size before the element-wise sum.
        """
        reference: Tensor = x_list[-1]
        spatial_size = reference.shape[2:]  # everything after (Batch, Channel)

        stacked_inputs: list[Tensor] = []
        for branch, pick_id in zip(x_list, self.idx):
            resized = nn.functional.interpolate(branch[pick_id], size=spatial_size, mode=self.mode)
            stacked_inputs.append(resized)
        stacked_inputs.append(reference)

        return torch.stack(stacked_inputs).sum(dim=0)


class ImplicitA(nn.Module):
    """YOLOR implicit knowledge, additive variant.

    A learnable ``(1, channel, 1, 1)`` tensor, drawn from a normal distribution, is added to the input.

    paper: https://arxiv.org/abs/2105.04206

    Args:
        channel (int): Number of input channels.
        mean (float): Mean of the normal distribution used for initialisation.
        std (float): Standard deviation of the normal distribution used for initialisation.
    """

    def __init__(self, channel: int, mean: float = 0.0, std: float = 0.02) -> None:
        super().__init__()
        self.channel, self.mean, self.std = channel, mean, std

        offset = nn.Parameter(torch.empty(1, channel, 1, 1))
        nn.init.normal_(offset, mean=self.mean, std=self.std)
        self.implicit = offset

    def forward(self, x: Tensor) -> Tensor:
        """Forward function: add the learned per-channel offset to ``x``."""
        shifted = self.implicit + x
        return shifted


class ImplicitM(nn.Module):
    """YOLOR implicit knowledge, multiplicative variant.

    A learnable ``(1, channel, 1, 1)`` tensor, drawn from a normal distribution, scales the input.

    paper: https://arxiv.org/abs/2105.04206

    Args:
        channel (int): Number of input channels.
        mean (float): Mean of the normal distribution used for initialisation.
        std (float): Standard deviation of the normal distribution used for initialisation.
    """

    def __init__(self, channel: int, mean: float = 1.0, std: float = 0.02) -> None:
        super().__init__()
        self.channel, self.mean, self.std = channel, mean, std

        scale = nn.Parameter(torch.empty(1, channel, 1, 1))
        nn.init.normal_(scale, mean=self.mean, std=self.std)
        self.implicit = scale

    def forward(self, x: Tensor) -> Tensor:
        """Forward function: multiply ``x`` by the learned per-channel factor."""
        scaled = self.implicit * x
        return scaled


class SingleHeadDetectionforYOLOv9(nn.Module):
    """Detection head applied to one feature level of a YOLOv9 model.

    Two parallel convolution stacks consume the same input: the anchor stack ends in
    ``4 * reg_max`` channels and the class stack ends in ``num_classes`` channels. The anchor
    output is additionally passed through ``Anchor2Vec``.

    Args:
        in_channels (tuple[int, int]): Pair ``(first_neck, first_channels)``. The first value only
            sizes the hidden widths; the second is the channel count of the incoming feature map.
        num_classes (int): Number of classes.
        reg_max (int): Forwarded to ``Anchor2Vec``; the anchor stack outputs ``4 * reg_max`` channels.
        use_group (bool): Use 4 convolution groups in the anchor stack instead of 1.
    """

    def __init__(
        self,
        in_channels: tuple[int, int],
        num_classes: int,
        *,
        reg_max: int = 16,
        use_group: bool = True,
    ) -> None:
        super().__init__()

        anchor_channels = reg_max * 4
        groups = 1 if not use_group else 4

        neck_width, feature_channels = in_channels
        anchor_neck = max(round_up(neck_width // 4, groups), anchor_channels, reg_max)
        class_neck = max(neck_width, min(num_classes * 2, 128))

        def conv_bn_silu(c_in: int, c_out: int, **conv_kwargs: Any) -> nn.Module:
            # 3x3 convolution followed by BatchNorm and SiLU, shared by both stacks.
            return Conv2dModule(
                c_in,
                c_out,
                3,
                padding=auto_pad(3),
                **conv_kwargs,
                normalization=nn.BatchNorm2d(c_out, eps=1e-3, momentum=3e-2),
                activation=nn.SiLU(inplace=True),
            )

        # The anchor stack is built before the class stack; keep this order so parameter
        # initialisation draws from the RNG in the same sequence.
        anchor_layers = [
            conv_bn_silu(feature_channels, anchor_neck),
            conv_bn_silu(anchor_neck, anchor_neck, groups=groups),
            nn.Conv2d(anchor_neck, anchor_channels, 1, groups=groups),
        ]
        self.anchor_conv = nn.Sequential(*anchor_layers)

        class_layers = [
            conv_bn_silu(feature_channels, class_neck),
            conv_bn_silu(class_neck, class_neck),
            nn.Conv2d(class_neck, num_classes, 1),
        ]
        self.class_conv = nn.Sequential(*class_layers)

        self.anc2vec = Anchor2Vec(reg_max=reg_max)

        # Constant bias initialisation of the two final 1x1 convolutions.
        anchor_out, class_out = self.anchor_conv[-1], self.class_conv[-1]
        anchor_out.bias.data.fill_(1.0)
        class_out.bias.data.fill_(-10)  # TODO (author): math.log(5 * 4 ** idx / 80 ** 3)

    def forward(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor]:
        """Forward function.

        Returns:
            tuple[Tensor, Tensor, Tensor]: Class output, regrouped anchor output and vector output
            (the latter two as returned by ``Anchor2Vec``).
        """
        anchor_raw = self.anchor_conv(x)
        class_x = self.class_conv(x)
        anchor_x, vector_x = self.anc2vec(anchor_raw)
        return class_x, anchor_x, vector_x


class SingleHeadDetectionforYOLOv7(nn.Module):
    """Detection head applied to one feature level of a YOLOv7 model.

    The input goes through ``ImplicitA``, a 1x1 convolution producing
    ``(num_classes + 5) * anchor_num`` channels, and finally ``ImplicitM``.

    Args:
        in_channels (int | tuple[int, int]): Number of input channels. When a tuple is given,
            its second element is used.
        num_classes (int): Number of classes.
        anchor_num (int): Number of anchors. Default is 3.
    """

    def __init__(
        self,
        in_channels: int | tuple[int, int],
        num_classes: int,
        *args,
        anchor_num: int = 3,
        **kwargs,
    ) -> None:
        super().__init__()

        # Extra positional/keyword arguments are accepted but not used.
        feature_channels = in_channels[1] if isinstance(in_channels, tuple) else in_channels
        total_out_channels = (num_classes + 5) * anchor_num

        self.head_conv = nn.Conv2d(feature_channels, total_out_channels, 1)

        self.implicit_a = ImplicitA(feature_channels)
        self.implicit_m = ImplicitM(total_out_channels)

    def forward(self, x: Tensor) -> Tensor:
        """Forward function: implicit add, 1x1 convolution, implicit multiply."""
        return self.implicit_m(self.head_conv(self.implicit_a(x)))


class MultiheadDetection(nn.Module):
    """Multihead Detection module for Dual detect or Triple detect.

    One single-level head is created per entry of ``in_channels``. Passing ``version="v7"`` in
    ``head_kwargs`` selects the YOLOv7 head; any other value (or none) selects the YOLOv9 head.

    Args:
        in_channels (list[int]): Number of input channels of every feature level.
        num_classes (int): Number of classes.
        **head_kwargs: Remaining keyword arguments, forwarded to every single-level head
            after ``version`` has been removed.
    """

    def __init__(self, in_channels: list[int], num_classes: int, **head_kwargs) -> None:
        super().__init__()
        version = head_kwargs.pop("version", None)
        head_cls: type[nn.Module]
        if version == "v7":
            head_cls = SingleHeadDetectionforYOLOv7
        else:
            head_cls = SingleHeadDetectionforYOLOv9

        # Every head receives the first level's channel count together with its own.
        level_heads = []
        for level_channels in in_channels:
            level_heads.append(head_cls((in_channels[0], level_channels), num_classes, **head_kwargs))
        self.heads = nn.ModuleList(level_heads)

    def forward(self, x_list: list[Tensor]) -> list[Tensor]:
        """Forward function: run the i-th head on the i-th input and collect the results."""
        results = []
        for feature, head in zip(x_list, self.heads):
            results.append(head(feature))
        return results


class YOLOHeadModule(BaseDenseHead):
    """Head for YOLOv7 and v9.

    The head is a flat ``nn.ModuleList``. Layers may carry ``source`` (which earlier outputs to
    read), ``tags`` (name under which the output is stored) and ``output`` attributes, attached via
    ``set_info_into_instance``; :meth:`forward` uses ``source`` and ``tags`` to route tensors.

    Args:
        num_classes (int): Number of classes.
        csp_channels (list[list[int]]): List of channels for CSP blocks.
        concat_sources (list[str | int]): List of sources to concatenate.
        aconv_channels (list[list[int]], optional): List of channels for AConv. Defaults to None.
        adown_channels (list[list[int]], optional): List of channels for ADown. Defaults to None.
        pre_upsample_concat_cfg (dict[str, Any], optional): Configuration for pre-upsampling. Defaults to None.
        csp_args (dict[str, Any], optional): Arguments for CSP blocks. Defaults to None.
        aux_cfg (dict[str, Any], optional): Configuration for auxiliary head. Defaults to None.
        with_nms (bool, optional): Whether to use NMS. Defaults to True.
        min_confidence (float, optional): Minimum confidence for NMS. Defaults to 0.1.
        min_iou (float, optional): Minimum IoU for NMS. Defaults to 0.65.

    Raises:
        ValueError: If ``len(csp_channels) - 1 != len(concat_sources)``, or if neither
            ``aconv_channels`` nor ``adown_channels`` is given (for the main head, and for the
            auxiliary head when ``aux_cfg`` contains ``cblinear_channels``).
    """

    def __init__(
        self,
        num_classes: int,
        csp_channels: list[list[int]],
        concat_sources: list[str | int],
        aconv_channels: list[list[int]] | None = None,
        adown_channels: list[list[int]] | None = None,
        pre_upsample_concat_cfg: dict[str, Any] | None = None,
        csp_args: dict[str, Any] | None = None,
        aux_cfg: dict[str, Any] | None = None,
        with_nms: bool = True,
        min_confidence: float = 0.1,
        min_iou: float = 0.65,
    ) -> None:
        # The first CSP block has no preceding concat, hence the "- 1".
        if len(csp_channels) - 1 != len(concat_sources):
            msg = (
                f"len(csp_channels) - 1 ({len(csp_channels) - 1}) "
                f"and len(concat_sources) ({len(concat_sources)}) should be the same."
            )
            raise ValueError(msg)

        super().__init__()
        self.num_classes = num_classes
        self.csp_channels = csp_channels
        self.aconv_channels = aconv_channels
        # Kept alongside aconv_channels so both alternatives are inspectable on the instance.
        self.adown_channels = adown_channels
        self.concat_sources = concat_sources
        self.pre_upsample_concat_cfg = pre_upsample_concat_cfg
        self.csp_args = csp_args or {}
        self.aux_cfg = aux_cfg
        self.with_nms = with_nms
        self.min_confidence = min_confidence
        self.min_iou = min_iou

        # NOTE: the order in which layers are appended defines their index in `self.module`
        # and therefore the parameter names in the state dict; do not reorder.
        self.module = nn.ModuleList()
        if pre_upsample_concat_cfg:
            # for yolov9_s
            self.module.append(nn.Upsample(scale_factor=2, mode="nearest"))
            self.module.append(
                set_info_into_instance({"module": Concat(), "source": pre_upsample_concat_cfg.get("source")}),
            )

        # Output channel count of every tagged P-level, consumed by the main detection head.
        output_channels: list[int] = []
        self.module.append(
            set_info_into_instance(
                {
                    "module": RepNCSPELAN(
                        csp_channels[0][0],
                        csp_channels[0][1],
                        part_channels=csp_channels[0][2],
                        csp_args=self.csp_args,
                    ),
                    "tags": "P3",
                },
            ),
        )
        output_channels.append(csp_channels[0][1])

        # aconv_channels takes precedence when both are given.
        aconv_adown_channels = aconv_channels or adown_channels
        if aconv_adown_channels is None:
            msg = "Only one of aconv_channels or adown_channels should be provided."
            raise ValueError(msg)
        aconv_adown_object = AConv if aconv_channels else ADown
        # start=4 so the remaining levels are tagged P4, P5, ...
        for idx, (csp_channel, aconv_adown_channel, concat_source) in enumerate(
            zip(csp_channels[1:], aconv_adown_channels, concat_sources),
            start=4,
        ):
            self.module.append(aconv_adown_object(aconv_adown_channel[0], aconv_adown_channel[1]))
            self.module.append(set_info_into_instance({"module": Concat(), "source": concat_source}))
            self.module.append(
                set_info_into_instance(
                    {
                        "module": RepNCSPELAN(
                            csp_channel[0],
                            csp_channel[1],
                            part_channels=csp_channel[2],
                            csp_args=self.csp_args,
                        ),
                        "tags": f"P{idx}",
                    },
                ),
            )
            output_channels.append(csp_channel[1])

        self.module.append(
            set_info_into_instance(
                {
                    "module": MultiheadDetection(output_channels, num_classes),
                    "source": ["P3", "P4", "P5"],
                    "tags": "Main",
                    "output": True,
                },
            ),
        )

        if aux_cfg:
            aux_output_channels: list[int] = []
            if sppelan_channels := aux_cfg.get("sppelan_channels", None):
                # for yolov9_s
                self.module.append(
                    set_info_into_instance(
                        {"module": SPPELAN(sppelan_channels[0], sppelan_channels[1]), "source": "B5", "tags": "A5"},
                    ),
                )
                aux_output_channels.append(sppelan_channels[1])
                # Levels are built top-down here (A5, then A4, A3, ...).
                for idx, csp_channel in enumerate(aux_cfg.get("csp_channels", [])):
                    self.module.append(nn.Upsample(scale_factor=2, mode="nearest"))
                    self.module.append(set_info_into_instance({"module": Concat(), "source": [-1, f"B{4-idx}"]}))
                    self.module.append(
                        set_info_into_instance(
                            {
                                "module": RepNCSPELAN(
                                    csp_channel[0],
                                    csp_channel[1],
                                    part_channels=csp_channel[2],
                                    csp_args=self.csp_args,
                                ),
                                "tags": f"A{4-idx}",
                            },
                        ),
                    )
                    aux_output_channels.append(csp_channel[1])
                # Channels were collected as A5, A4, A3; the aux head reads A3, A4, A5.
                aux_output_channels = aux_output_channels[::-1]  # reverse channels

            elif cblinear_channels := aux_cfg.get("cblinear_channels", None):
                # for yolov9_m, c
                for idx, cblinear_channel in enumerate(cblinear_channels, start=3):
                    self.module.append(
                        set_info_into_instance(
                            {
                                "module": CBLinear(cblinear_channel[0], cblinear_channel[1]),
                                "source": f"B{idx}",
                                "tags": f"R{idx}",
                            },
                        ),
                    )

                aux_aconv_adown_channels = aux_cfg.get("aconv_channels", None) or aux_cfg.get("adown_channels", None)
                if aux_aconv_adown_channels is None:
                    msg = "Only one of aconv_channels or adown_channels should be provided."
                    raise ValueError(msg)

                # Pick the down-sampling layer from the auxiliary config itself so it always matches
                # the channel list selected just above, independently of the main head's choice.
                aux_aconv_adown_object = AConv if aux_cfg.get("aconv_channels", None) else ADown
                for idx, (csp_channel, aux_aconv_adown_channel, cbfuse_index, cbfuse_source) in enumerate(
                    zip(
                        aux_cfg.get("csp_channels", []),
                        aux_aconv_adown_channels,
                        aux_cfg.get("cbfuse_indices", []),
                        aux_cfg.get("cbfuse_sources", []),
                    ),
                ):
                    # An empty first entry marks the stem: two stride-2 convolutions reading the
                    # entry stored under key 0, followed by an untagged RepNCSPELAN.
                    if idx == 0 and len(aux_aconv_adown_channel) == 0 and len(cbfuse_index) == 0:
                        conv_channels: list[list[int]] = aux_cfg.get("conv_channels")  # type: ignore[assignment]
                        self.module.append(
                            set_info_into_instance(
                                {
                                    "module": Conv2dModule(
                                        conv_channels[0][0],
                                        conv_channels[0][1],
                                        3,
                                        stride=2,
                                        padding=auto_pad(3),
                                        normalization=nn.BatchNorm2d(conv_channels[0][1], eps=1e-3, momentum=3e-2),
                                        activation=nn.SiLU(inplace=True),
                                    ),
                                    "source": 0,
                                },
                            ),
                        )
                        self.module.append(
                            Conv2dModule(
                                conv_channels[1][0],
                                conv_channels[1][1],
                                3,
                                stride=2,
                                padding=auto_pad(3),
                                normalization=nn.BatchNorm2d(conv_channels[1][1], eps=1e-3, momentum=3e-2),
                                activation=nn.SiLU(inplace=True),
                            ),
                        )
                        self.module.append(RepNCSPELAN(csp_channel[0], csp_channel[1], part_channels=csp_channel[2]))
                    else:
                        self.module.append(
                            aux_aconv_adown_object(aux_aconv_adown_channel[0], aux_aconv_adown_channel[1]),
                        )
                        self.module.append(
                            set_info_into_instance({"module": CBFuse(cbfuse_index), "source": cbfuse_source}),
                        )
                        self.module.append(
                            set_info_into_instance(
                                {
                                    "module": RepNCSPELAN(csp_channel[0], csp_channel[1], part_channels=csp_channel[2]),
                                    "tags": f"A{idx+2}",
                                },
                            ),
                        )
                        aux_output_channels.append(csp_channel[1])

            self.module.append(
                set_info_into_instance(
                    {
                        "module": MultiheadDetection(aux_output_channels, num_classes),
                        "source": ["A3", "A4", "A5"],
                        "tags": "AUX",
                        "output": True,
                    },
                ),
            )

    @property
    def is_aux(self) -> bool:
        """Check if the head has an auxiliary head."""
        return bool(self.aux_cfg)

    def forward(self, outputs: dict[int | str, Tensor], *args, **kwargs) -> tuple[Tensor, None] | tuple[Tensor, Tensor]:
        """Forward function.

        Runs every layer of ``self.module`` in order. A layer reads ``outputs[layer.source]`` (or a
        list of such entries when ``source`` is a list); layers without ``source`` read
        ``outputs[-1]``, i.e. the previous layer's result. Each result is written back to
        ``outputs[-1]`` and, if the layer has ``tags``, also to ``outputs[layer.tags]``.

        Args:
            outputs (dict[int | str, Tensor]): Tensors produced upstream, keyed by index or tag.
                This dictionary is updated in place.

        Returns:
            The ``"Main"`` head output and the ``"AUX"`` head output (``None`` without aux head).
        """
        for layer in self.module:
            if hasattr(layer, "source") and isinstance(layer.source, list):
                model_input = [outputs[idx] for idx in layer.source]
            else:
                model_input = outputs[getattr(layer, "source", -1)]  # type: ignore[arg-type]
            x = layer(model_input)
            outputs[-1] = x
            if hasattr(layer, "tags"):
                outputs[layer.tags] = x

        if self.is_aux:
            return outputs["Main"], outputs["AUX"]
        return outputs["Main"], None

    def prepare_loss_inputs(self, x: tuple[Tensor], entity: DetBatchDataEntity) -> dict:
        """Perform forward propagation and collect the inputs needed for loss calculation.

        Args:
            x (tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor.
            entity (DetBatchDataEntity): Entity from OTX dataset.

        Returns:
            dict: ``main_preds`` and ``aux_preds`` from :meth:`forward`, and ``targets``, the padded
            labels concatenated with the padded boxes along the last dimension (label first).
        """
        main_preds, aux_preds = self(x)

        padded_bboxes, padded_labels = self.pad_bbox_labels(entity.bboxes, entity.labels)
        merged_padded_labels_bboxes = torch.cat((padded_labels, padded_bboxes), dim=-1)
        return {
            "main_preds": main_preds,
            "aux_preds": aux_preds,
            "targets": merged_padded_labels_bboxes,
        }

    def loss_by_feat(self, *args, **kwargs) -> NoReturn:
        """Calculate the loss based on the features extracted by the detection head.

        Raises:
            NotImplementedError: Always; this head does not implement it.
        """
        raise NotImplementedError

    def predict(
        self,
        x: tuple[Tensor],
        entity: OTXBatchDataEntity,
        rescale: bool = False,
    ) -> list[InstanceData]:
        """Perform forward propagation of the detection head and predict detection results.

        Args:
            x (tuple[Tensor]): Multi-level features from the
                upstream network, each is a 4D-tensor.
            entity (OTXBatchDataEntity): Entity from OTX dataset.
            rescale (bool, optional): Whether to rescale the results.
                Defaults to False.

        Returns:
            list[InstanceData]: Detection results of each image
            after the post process.
        """
        main_preds, _ = self(x)

        pred_bboxes: Tensor | list[Tensor]
        pred_scores: Tensor | list[Tensor]
        pred_labels: Tensor | list[Tensor]

        # NOTE(review): `vec2box` is not defined in this class; presumably it is attached to the
        # instance by the owning model before inference — confirm against the caller.
        prediction = self.vec2box(main_preds)
        pred_classes, _, pred_bboxes = prediction[:3]
        # A fourth element, when present, is multiplied into the class scores.
        pred_scores = pred_classes.sigmoid() * (prediction[3] if len(prediction) == 4 else 1)

        # TODO (sungchul): use otx modules
        # Keep only the best class per prediction (score and label keep a trailing dim of 1).
        pred_scores, pred_labels = pred_scores.max(dim=-1, keepdim=True)
        if rescale:
            # rescale
            # The per-image scale factor is reversed and repeated twice to cover 4 box coordinates.
            scale_factors = [img_info.scale_factor[::-1] for img_info in entity.imgs_info]  # type: ignore[index]
            pred_bboxes /= pred_bboxes.new_tensor(scale_factors).repeat((1, 2)).unsqueeze(1)

        if self.with_nms and pred_bboxes.numel():
            # filter class by confidence
            valid_mask = pred_scores > self.min_confidence
            valid_labels = pred_labels[valid_mask].float()
            valid_scores = pred_scores[valid_mask].float()
            valid_bboxes = pred_bboxes[valid_mask.repeat(1, 1, 4)].view(-1, 4)

            # nms
            # batch_idx maps every surviving prediction back to its image in the batch.
            batch_idx, *_ = torch.where(valid_mask)
            nms_idx = batched_nms(valid_bboxes, valid_scores, valid_labels, self.min_iou)

            def filter_predictions() -> tuple[list[Tensor], list[Tensor], list[Tensor]]:
                # Split the flat NMS result back into one entry per image.
                _pred_bboxes = []
                _pred_scores = []
                _pred_labels = []
                for idx in range(pred_classes.size(0)):
                    instance_idx = nms_idx[idx == batch_idx[nms_idx]]
                    _pred_bboxes.append(valid_bboxes[instance_idx])
                    _pred_scores.append(valid_scores[instance_idx])
                    _pred_labels.append(valid_labels[instance_idx])
                return _pred_bboxes, _pred_scores, _pred_labels

            pred_bboxes, pred_scores, pred_labels = filter_predictions()

        return [
            InstanceData(
                bboxes=pred_bboxes[idx],
                scores=pred_scores[idx],
                labels=pred_labels[idx].type(torch.long),
            )
            for idx in range(pred_classes.size(0))
        ]

    def export(
        self,
        x: tuple[Tensor],
        batch_img_metas: list[dict],
        rescale: bool = False,
    ) -> tuple[Tensor, Tensor] | tuple[Tensor, Tensor, Tensor]:
        """Perform forward propagation of the detection head and predict detection results.

        Args:
            x (tuple[Tensor]): Multi-level features from the upstream network, each is a 4D-tensor.
            batch_img_metas (list[dict]): Meta information of the batch. Not used by this method.
            rescale (bool, optional): Whether to rescale the results. Not used by this method.
                Defaults to False.

        Returns:
            tuple[Tensor, Tensor] | tuple[Tensor, Tensor, Tensor]:
                Detection results of each image after the post process.
        """
        main_preds, _ = self(x)

        # NOTE(review): `vec2box` is expected to exist on the instance; it is not defined here.
        prediction = self.vec2box(main_preds)
        pred_class, _, pred_bbox = prediction[:3]
        pred_conf = prediction[3] if len(prediction) == 4 else None

        scores = pred_class.sigmoid() * (1 if pred_conf is None else pred_conf)

        return multiclass_nms(
            pred_bbox,
            scores,
            max_output_boxes_per_class=200,
            iou_threshold=self.min_iou,
            score_threshold=self.min_confidence,
            pre_top_k=5000,
            keep_top_k=100,
        )

    def pad_bbox_labels(self, bboxes: list[Tensor], labels: list[Tensor]) -> tuple[Tensor, Tensor]:
        """Pad bounding boxes and labels to the same length.

        Every image's boxes and labels are padded along dim 0 up to the largest box count in the
        batch; labels are padded with ``-1`` (after gaining a trailing dim of 1) and boxes with ``0``.

        Returns:
            tuple[Tensor, Tensor]: Stacked padded boxes and stacked padded labels.
        """
        max_len = max(b.shape[0] for b in bboxes)
        padded_labels = torch.stack(
            [nn.functional.pad(label.unsqueeze(1), (0, 0, 0, max_len - label.shape[0]), value=-1) for label in labels],
            dim=0,
        )
        padded_bboxes = torch.stack(
            [nn.functional.pad(box, (0, 0, 0, max_len - box.shape[0]), value=0) for box in bboxes],
            dim=0,
        )
        return padded_bboxes, padded_labels


class YOLOHead:
    """YOLOHead factory for detection.

    Calling ``YOLOHead(model_name, num_classes)`` does not create a ``YOLOHead`` instance; it
    returns a ``YOLOHeadModule`` built from the matching entry of ``YOLOHEAD_CFG``.
    """

    # Keyword arguments for YOLOHeadModule, keyed by model name.
    YOLOHEAD_CFG: ClassVar[dict[str, Any]] = {
        "yolov9_s": {
            "csp_channels": [[320, 128, 128], [288, 192, 192], [384, 256, 256]],
            "aconv_channels": [[128, 96], [192, 128]],
            "concat_sources": [[-1, "N4"], [-1, "N3"]],
            "pre_upsample_concat_cfg": {"source": [-1, "B3"]},
            "csp_args": {"repeat_num": 3},
            "aux_cfg": {
                "sppelan_channels": [256, 256],
                "csp_channels": [[448, 192, 192], [320, 128, 128]],
            },
        },
        "yolov9_m": {
            "csp_channels": [[600, 240, 240], [544, 360, 360], [720, 480, 480]],
            "aconv_channels": [[240, 184], [360, 240]],
            "concat_sources": [[-1, "N4"], [-1, "N3"]],
            "aux_cfg": {
                "cblinear_channels": [[240, [240]], [360, [240, 360]], [480, [240, 360, 480]]],
                "csp_channels": [[64, 128, 128], [240, 240, 240], [360, 360, 360], [480, 480, 480]],
                "conv_channels": [[3, 32], [32, 64]],
                "aconv_channels": [[], [128, 240], [240, 360], [360, 480]],
                "cbfuse_indices": [[], [0, 0, 0], [1, 1], [2]],
                "cbfuse_sources": [[], ["R3", "R4", "R5", -1], ["R4", "R5", -1], ["R5", -1]],
            },
        },
        "yolov9_c": {
            "csp_channels": [[1024, 256, 256], [768, 512, 512], [1024, 512, 512]],
            "adown_channels": [[256, 256], [512, 512]],
            "concat_sources": [[-1, "N4"], [-1, "N3"]],
            "aux_cfg": {
                "cblinear_channels": [[512, [256]], [512, [256, 512]], [512, [256, 512, 512]]],
                "csp_channels": [[128, 256, 128], [256, 512, 256], [512, 512, 512], [512, 512, 512]],
                "conv_channels": [[3, 64], [64, 128]],
                "adown_channels": [[], [256, 256], [512, 512], [512, 512]],
                "cbfuse_indices": [[], [0, 0, 0], [1, 1], [2]],
                "cbfuse_sources": [[], ["R3", "R4", "R5", -1], ["R4", "R5", -1], ["R5", -1]],
            },
        },
    }

    def __new__(cls, model_name: str, num_classes: int) -> YOLOHeadModule:
        """Constructor for YOLOHead for v7 and v9.

        Args:
            model_name (str): Key into ``YOLOHEAD_CFG``.
            num_classes (int): Number of classes, forwarded to ``YOLOHeadModule``.

        Returns:
            YOLOHeadModule: Head built from the configuration registered for ``model_name``.

        Raises:
            KeyError: If ``model_name`` has no entry in ``YOLOHEAD_CFG``.
        """
        if model_name not in cls.YOLOHEAD_CFG:
            msg = f"model type '{model_name}' is not supported"
            raise KeyError(msg)

        return YOLOHeadModule(
            **cls.YOLOHEAD_CFG[model_name],
            num_classes=num_classes,
        )