# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) OpenMMLab. All rights reserved.
"""Implementation modified from mmdet.models.dense_heads.yolox_head.py.
Reference : https://github.com/open-mmlab/mmdetection/blob/v3.2.0/mmdet/models/dense_heads/yolox_head.py
"""

from __future__ import annotations

import logging
import math
from functools import partial
from typing import Any, Callable, ClassVar, Sequence

import torch
import torch.nn.functional as F # noqa: N812
from torch import Tensor, nn
from torchvision.ops import box_convert

from otx.algo.common.utils.nms import batched_nms, multiclass_nms
from otx.algo.common.utils.prior_generators import MlvlPointGenerator
from otx.algo.common.utils.samplers import PseudoSampler
from otx.algo.common.utils.utils import multi_apply, reduce_mean
from otx.algo.detection.heads.base_head import BaseDenseHead
from otx.algo.modules.activation import Swish, build_activation_layer
from otx.algo.modules.conv_module import Conv2dModule, DepthwiseSeparableConvModule
from otx.algo.modules.norm import build_norm_layer
from otx.algo.utils.mmengine_utils import InstanceData
from otx.core.data.entity.detection import DetBatchDataEntity

logger = logging.getLogger()


class YOLOXHeadModule(BaseDenseHead):
"""YOLOXHead head used in `YOLOX <https://arxiv.org/abs/2107.08430>`_.
Args:
num_classes (int): Number of categories excluding the background category.
in_channels (int): Number of channels in the input feature map.
        feat_channels (int): Number of hidden channels in stacking convs.
            Defaults to 256.
        stacked_convs (int): Number of stacking convs of the head.
            Defaults to 2.
        strides (Sequence[int]): Downsample factor of each feature map.
            Defaults to (8, 16, 32).
        use_depthwise (bool): Whether to use depthwise separable convolution in blocks.
            Defaults to False.
        dcn_on_last_conv (bool): If True, use DCN in the last layer of towers.
            Defaults to False.
conv_bias (bool or str): If specified as `auto`, it will be decided by
the normalization. Bias of conv will be set as True if `normalization` is
None, otherwise False. Defaults to "auto".
normalization (Callable[..., nn.Module]): Normalization layer module.
Defaults to ``partial(nn.BatchNorm2d, momentum=0.03, eps=0.001)``.
activation (Callable[..., nn.Module]): Activation layer module.
Defaults to ``Swish``.
train_cfg (dict, optional): Training config of anchor head.
Defaults to None.
test_cfg (dict, optional): Testing config of anchor head.
Defaults to None.
init_cfg (dict or list[dict], optional): Initialization config dict.
Defaults to None.
use_sigmoid_cls (bool): Whether to use a sigmoid activation function for
classification prediction. Defaults to True.
"""

    def __init__(
self,
num_classes: int,
in_channels: int,
feat_channels: int = 256,
stacked_convs: int = 2,
strides: Sequence[int] = (8, 16, 32),
use_depthwise: bool = False,
dcn_on_last_conv: bool = False,
conv_bias: bool | str = "auto",
normalization: Callable[..., nn.Module] = partial(nn.BatchNorm2d, momentum=0.03, eps=0.001),
activation: Callable[..., nn.Module] = Swish,
train_cfg: dict | None = None,
test_cfg: dict | None = None,
init_cfg: dict | list[dict] | None = None,
use_sigmoid_cls: bool = True,
) -> None:
if init_cfg is None:
init_cfg = {
"type": "Kaiming",
"layer": "Conv2d",
"a": math.sqrt(5),
"distribution": "uniform",
"mode": "fan_in",
"nonlinearity": "leaky_relu",
}
super().__init__(init_cfg=init_cfg, use_sigmoid_cls=use_sigmoid_cls)
self.num_classes = num_classes
self.cls_out_channels = num_classes
self.in_channels = in_channels
self.feat_channels = feat_channels
self.stacked_convs = stacked_convs
self.strides = strides
self.use_depthwise = use_depthwise
self.dcn_on_last_conv = dcn_on_last_conv
        if conv_bias != "auto" and not isinstance(conv_bias, bool):
            msg = f"conv_bias (={conv_bias}) should be a bool or 'auto'."
            raise ValueError(msg)
self.conv_bias = conv_bias
self.normalization = normalization
self.activation = activation
self.use_l1 = False # This flag will be modified by hooks.
self.prior_generator = MlvlPointGenerator(strides, offset=0) # type: ignore[arg-type]
self.test_cfg = test_cfg
self.train_cfg = train_cfg
if self.train_cfg is not None:
self.assigner = self.train_cfg["assigner"]
# YOLOX does not support sampling
self.sampler = PseudoSampler() # type: ignore[no-untyped-call]
self._init_layers()

    def _init_layers(self) -> None:
"""Initialize heads for all level feature maps."""
self.multi_level_cls_convs = nn.ModuleList()
self.multi_level_reg_convs = nn.ModuleList()
self.multi_level_conv_cls = nn.ModuleList()
self.multi_level_conv_reg = nn.ModuleList()
self.multi_level_conv_obj = nn.ModuleList()
for _ in self.strides:
self.multi_level_cls_convs.append(self._build_stacked_convs())
self.multi_level_reg_convs.append(self._build_stacked_convs())
conv_cls, conv_reg, conv_obj = self._build_predictor()
self.multi_level_conv_cls.append(conv_cls)
self.multi_level_conv_reg.append(conv_reg)
self.multi_level_conv_obj.append(conv_obj)

    def _build_stacked_convs(self) -> nn.Sequential:
"""Initialize conv layers of a single level head."""
conv = DepthwiseSeparableConvModule if self.use_depthwise else Conv2dModule
stacked_convs = []
for i in range(self.stacked_convs):
chn = self.in_channels if i == 0 else self.feat_channels
# TODO (sungchul): enable deformable convolution implemented in mmcv
# conv_cfg = {"type": "DCNv2"} if self.dcn_on_last_conv and i == self.stacked_convs - 1 else self.conv_cfg
if self.dcn_on_last_conv and i == self.stacked_convs - 1:
logger.warning(
f"stacked convs[{i}] : Deformable convolution is not supported in YOLOXHead, "
"use normal convolution instead.",
)
stacked_convs.append(
conv(
chn,
self.feat_channels,
3,
stride=1,
padding=1,
normalization=build_norm_layer(self.normalization, num_features=self.feat_channels),
activation=build_activation_layer(self.activation),
bias=self.conv_bias,
),
)
return nn.Sequential(*stacked_convs)

    def _build_predictor(self) -> tuple[nn.Module, nn.Module, nn.Module]:
"""Initialize predictor layers of a single level head."""
conv_cls = nn.Conv2d(self.feat_channels, self.cls_out_channels, 1)
conv_reg = nn.Conv2d(self.feat_channels, 4, 1)
conv_obj = nn.Conv2d(self.feat_channels, 1, 1)
return conv_cls, conv_reg, conv_obj

    def forward_single(
self,
x: Tensor,
cls_convs: nn.Module,
reg_convs: nn.Module,
conv_cls: nn.Module,
conv_reg: nn.Module,
conv_obj: nn.Module,
) -> tuple[Tensor, Tensor, Tensor]:
"""Forward feature of a single scale level."""
cls_feat = cls_convs(x)
reg_feat = reg_convs(x)
cls_score = conv_cls(cls_feat)
bbox_pred = conv_reg(reg_feat)
objectness = conv_obj(reg_feat)
return cls_score, bbox_pred, objectness

    def forward(self, x: tuple[Tensor]) -> tuple:
"""Forward features from the upstream network.
Args:
x (tuple[Tensor]): Features from the upstream network, each is a 4D-tensor.
Returns:
            tuple[list]: A tuple of multi-level classification scores, bbox
                predictions, and objectnesses.
"""
return multi_apply(
self.forward_single,
x,
self.multi_level_cls_convs,
self.multi_level_reg_convs,
self.multi_level_conv_cls,
self.multi_level_conv_reg,
self.multi_level_conv_obj,
)

    def predict_by_feat(  # type: ignore[override]
self,
cls_scores: list[Tensor],
bbox_preds: list[Tensor],
objectnesses: list[Tensor] | None,
batch_img_metas: list[dict] | None = None,
cfg: dict | None = None,
rescale: bool = False,
with_nms: bool = True,
) -> list[InstanceData]:
"""Transform a batch of output features extracted by the head into bbox results.
Args:
cls_scores (list[Tensor]): Classification scores for all
scale levels, each is a 4D-tensor, has shape
(batch_size, num_priors * num_classes, H, W).
bbox_preds (list[Tensor]): Box energies / deltas for all
scale levels, each is a 4D-tensor, has shape
(batch_size, num_priors * 4, H, W).
            objectnesses (list[Tensor], Optional): Score factors for
                all scale levels, each is a 4D-tensor, has shape
                (batch_size, 1, H, W).
batch_img_metas (list[dict], Optional): Batch image meta info.
Defaults to None.
cfg (dict, optional): Test / postprocessing
configuration, if None, test_cfg would be used.
Defaults to None.
rescale (bool): If True, return boxes in original image space.
Defaults to False.
with_nms (bool): If True, do nms before return boxes.
Defaults to True.
Returns:
list[InstanceData]: Object detection results of each image
after the post process. Each item usually contains following keys.
- scores (Tensor): Classification scores, has a shape
(num_instance, )
- labels (Tensor): Labels of bboxes, has a shape
(num_instances, ).
                - bboxes (Tensor): Has a shape (num_instances, 4),
                  the last dimension 4 arranged as (x1, y1, x2, y2).
"""
assert len(cls_scores) == len(bbox_preds) == len(objectnesses) # type: ignore[arg-type] # noqa: S101
cfg = cfg or self.test_cfg
num_imgs = len(batch_img_metas) # type: ignore[arg-type]
featmap_sizes = [cls_score.shape[2:] for cls_score in cls_scores]
mlvl_priors = self.prior_generator.grid_priors(
featmap_sizes,
dtype=cls_scores[0].dtype,
device=cls_scores[0].device,
with_stride=True,
)
# flatten cls_scores, bbox_preds and objectness
flatten_cls_scores = [
cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1, self.cls_out_channels) for cls_score in cls_scores
]
flatten_bbox_preds = [bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4) for bbox_pred in bbox_preds]
flatten_objectness = [objectness.permute(0, 2, 3, 1).reshape(num_imgs, -1) for objectness in objectnesses] # type: ignore[union-attr]
flatten_cls_scores = torch.cat(flatten_cls_scores, dim=1).sigmoid()
flatten_bbox_preds = torch.cat(flatten_bbox_preds, dim=1)
flatten_objectness = torch.cat(flatten_objectness, dim=1).sigmoid()
flatten_priors = torch.cat(mlvl_priors)
flatten_bboxes = self._bbox_decode(flatten_priors, flatten_bbox_preds)
result_list = []
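        # The detection score is the product of objectness and the best class
        # score; priors whose joint score falls below cfg["score_thr"] are
        # dropped before NMS in _bbox_post_process.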
for img_id, img_meta in enumerate(batch_img_metas): # type: ignore[arg-type]
max_scores, labels = torch.max(flatten_cls_scores[img_id], 1)
valid_mask = flatten_objectness[img_id] * max_scores >= cfg["score_thr"] # type: ignore[index]
results = InstanceData(
bboxes=flatten_bboxes[img_id][valid_mask],
scores=max_scores[valid_mask] * flatten_objectness[img_id][valid_mask],
labels=labels[valid_mask],
)
result_list.append(
self._bbox_post_process(
results=results,
cfg=cfg,
rescale=rescale,
with_nms=with_nms,
img_meta=img_meta,
),
)
return result_list

    def export_by_feat(  # type: ignore[override]
self,
cls_scores: list[Tensor],
bbox_preds: list[Tensor],
objectnesses: list[Tensor],
batch_img_metas: list[dict] | None = None,
cfg: dict | None = None,
rescale: bool = False,
with_nms: bool = True,
) -> tuple[Tensor, Tensor] | tuple[Tensor, Tensor, Tensor]:
"""Transform network output for a batch into bbox predictions.
Reference : https://github.com/open-mmlab/mmdeploy/blob/v1.3.1/mmdeploy/codebase/mmdet/models/dense_heads/yolox_head.py#L18-L118
Args:
cls_scores (list[Tensor]): Classification scores for all
scale levels, each is a 4D-tensor, has shape
(batch_size, num_priors * num_classes, H, W).
bbox_preds (list[Tensor]): Box energies / deltas for all
scale levels, each is a 4D-tensor, has shape
(batch_size, num_priors * 4, H, W).
            objectnesses (list[Tensor], Optional): Score factors for
                all scale levels, each is a 4D-tensor, has shape
                (batch_size, 1, H, W).
batch_img_metas (list[dict], Optional): Batch image meta info.
Defaults to None.
cfg (dict, optional): Test / postprocessing
configuration, if None, test_cfg would be used.
Defaults to None.
rescale (bool): If True, return boxes in original image space.
Defaults to False.
with_nms (bool): If True, do nms before return boxes.
Defaults to True.
Returns:
            tuple[Tensor, Tensor]: The first item is an (N, num_box, 5) tensor,
                where 5 represents (tl_x, tl_y, br_x, br_y, score), N is the
                batch size, and the score is between 0 and 1. The second
                tensor has shape (N, num_box); each element is the class
                label of the corresponding box.
"""
device = cls_scores[0].device
cfg = cfg or self.test_cfg
batch_size = bbox_preds[0].shape[0]
featmap_sizes = [cls_score.shape[2:] for cls_score in cls_scores]
mlvl_priors = self.prior_generator.grid_priors(featmap_sizes, device=device, with_stride=True)
flatten_cls_scores = [
cls_score.permute(0, 2, 3, 1).reshape(batch_size, -1, self.cls_out_channels) for cls_score in cls_scores
]
flatten_bbox_preds = [bbox_pred.permute(0, 2, 3, 1).reshape(batch_size, -1, 4) for bbox_pred in bbox_preds]
flatten_objectness = [objectness.permute(0, 2, 3, 1).reshape(batch_size, -1) for objectness in objectnesses]
cls_scores = torch.cat(flatten_cls_scores, dim=1).sigmoid()
score_factor = torch.cat(flatten_objectness, dim=1).sigmoid()
flatten_bbox_preds = torch.cat(flatten_bbox_preds, dim=1)
flatten_priors = torch.cat(mlvl_priors)
bboxes = self._bbox_decode(flatten_priors, flatten_bbox_preds)
# directly multiply score factor and feed to nms
scores = cls_scores * (score_factor.unsqueeze(-1))
if not with_nms:
return bboxes, scores
return multiclass_nms(
bboxes,
scores,
max_output_boxes_per_class=200, # TODO (sungchul): temporarily set to mmdeploy cfg, will be updated
iou_threshold=cfg["nms"]["iou_threshold"], # type: ignore[index]
score_threshold=cfg["score_thr"], # type: ignore[index]
pre_top_k=5000,
keep_top_k=cfg["max_per_img"], # type: ignore[index]
)

    def _bbox_decode(self, priors: Tensor, bbox_preds: Tensor) -> Tensor:
        """Decode regression results (delta_x, delta_y, w, h) to bboxes (tl_x, tl_y, br_x, br_y).
Args:
priors (Tensor): Center priors of an image, has shape (num_instances, 2).
bbox_preds (Tensor): Box energies / deltas for all instances, has shape (batch_size, num_instances, 4).
Returns:
Tensor: Decoded bboxes in (tl_x, tl_y, br_x, br_y) format. Has
shape (batch_size, num_instances, 4).
"""
xys = (bbox_preds[..., :2] * priors[:, 2:]) + priors[:, :2]
whs = bbox_preds[..., 2:].exp() * priors[:, 2:]
tl_x = xys[..., 0] - whs[..., 0] / 2
tl_y = xys[..., 1] - whs[..., 1] / 2
br_x = xys[..., 0] + whs[..., 0] / 2
br_y = xys[..., 1] + whs[..., 1] / 2
return torch.stack([tl_x, tl_y, br_x, br_y], -1)

    def _bbox_post_process(  # type: ignore[override]
self,
results: InstanceData,
cfg: dict | None = None,
rescale: bool = False,
with_nms: bool = True,
img_meta: dict | None = None,
) -> InstanceData:
"""Bbox post-processing method.
        The boxes are rescaled to the original image scale and NMS is applied.
        ``with_nms=False`` is usually used for augmented testing.
Args:
results (InstanceData): Detection instance results,
each item has shape (num_bboxes, ).
cfg (dict): Test / postprocessing configuration,
if None, test_cfg would be used.
rescale (bool): If True, return boxes in original image space.
Default to False.
with_nms (bool): If True, do nms before return boxes.
Default to True.
img_meta (dict, optional): Image meta info. Defaults to None.
Returns:
InstanceData: Detection results of each image
after the post process.
Each item usually contains following keys.
- scores (Tensor): Classification scores, has a shape
(num_instance, )
- labels (Tensor): Labels of bboxes, has a shape
(num_instances, ).
                - bboxes (Tensor): Has a shape (num_instances, 4),
                  the last dimension 4 arranged as (x1, y1, x2, y2).
"""
if rescale:
assert img_meta.get("scale_factor") is not None # type: ignore[union-attr] # noqa: S101
results.bboxes /= results.bboxes.new_tensor(img_meta["scale_factor"][::-1]).repeat((1, 2)) # type: ignore[attr-defined, index]
if with_nms and results.bboxes.numel() > 0: # type: ignore[attr-defined]
det_bboxes, keep_idxs = batched_nms(results.bboxes, results.scores, results.labels, cfg["nms"]) # type: ignore[attr-defined, index]
results = results[keep_idxs]
# some nms would reweight the score, such as softnms
results.scores = det_bboxes[:, -1]
return results

    def prepare_loss_inputs(
self,
x: tuple[Tensor],
entity: DetBatchDataEntity,
) -> dict | tuple:
"""Perform forward propagation of the detection head and prepare for loss calculation.
Args:
x (tuple[Tensor]): Features from the upstream network, each is
a 4D-tensor.
entity (DetBatchDataEntity): Entity from OTX dataset.
Returns:
dict: A dictionary of components for loss calculation.
"""
(cls_scores, bbox_preds, objectnesses), batch_gt_instances, batch_img_metas = super().prepare_loss_inputs(
x,
entity,
)
num_imgs = len(batch_img_metas)
featmap_sizes = [cls_score.shape[2:] for cls_score in cls_scores]
mlvl_priors = self.prior_generator.grid_priors(
featmap_sizes,
dtype=cls_scores[0].dtype,
device=cls_scores[0].device,
with_stride=True,
)
flatten_cls_preds = torch.cat(
[cls_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, self.cls_out_channels) for cls_pred in cls_scores],
dim=1,
)
flatten_bbox_preds = torch.cat(
[bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4) for bbox_pred in bbox_preds],
dim=1,
)
flatten_objectness = torch.cat(
[objectness.permute(0, 2, 3, 1).reshape(num_imgs, -1) for objectness in objectnesses],
dim=1,
)
flatten_priors = torch.cat(mlvl_priors)
flatten_bboxes = self._bbox_decode(flatten_priors, flatten_bbox_preds)
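        # Targets are assigned per image on detached predictions; together with
        # the @torch.no_grad() on _get_targets_single, this keeps the dynamic
        # label assignment out of the backward graph.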
(pos_masks, cls_targets, obj_targets, bbox_targets, l1_targets, num_fg_imgs) = multi_apply(
self._get_targets_single,
flatten_priors.unsqueeze(0).repeat(num_imgs, 1, 1),
flatten_cls_preds.detach(),
flatten_bboxes.detach(),
flatten_objectness.detach(),
batch_gt_instances,
batch_img_metas,
)
# The experimental results show that 'reduce_mean' can improve
# performance on the COCO dataset.
num_pos = torch.tensor(sum(num_fg_imgs), dtype=torch.float, device=flatten_cls_preds.device)
num_total_samples = max(reduce_mean(num_pos), 1.0)
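        # reduce_mean averages the positive-sample count across distributed
        # workers so every rank normalizes its loss by the same value,
        # clamped to at least one sample.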
pos_masks = torch.cat(pos_masks, 0)
cls_targets = torch.cat(cls_targets, 0)
obj_targets = torch.cat(obj_targets, 0)
bbox_targets = torch.cat(bbox_targets, 0)
if self.use_l1:
l1_targets = torch.cat(l1_targets, 0)
return {
"flatten_objectness": flatten_objectness,
"flatten_cls_preds": flatten_cls_preds,
"flatten_bbox_preds": flatten_bbox_preds,
"flatten_bboxes": flatten_bboxes,
"obj_targets": obj_targets,
"cls_targets": cls_targets,
"bbox_targets": bbox_targets,
"l1_targets": l1_targets,
"num_total_samples": num_total_samples,
"num_pos": num_pos,
"pos_masks": pos_masks,
}

    @torch.no_grad()
def _get_targets_single(
self,
priors: Tensor,
cls_preds: Tensor,
decoded_bboxes: Tensor,
objectness: Tensor,
gt_instances: InstanceData,
img_meta: dict,
gt_instances_ignore: InstanceData | None = None,
) -> tuple:
"""Compute classification, regression, and objectness targets for priors in a single image.
Args:
            priors (Tensor): All priors of one image, a 2D-Tensor with shape
                [num_priors, 4] in [cx, cy, stride_w, stride_h] format.
cls_preds (Tensor): Classification predictions of one image,
a 2D-Tensor with shape [num_priors, num_classes]
decoded_bboxes (Tensor): Decoded bboxes predictions of one image,
a 2D-Tensor with shape [num_priors, 4] in [tl_x, tl_y,
br_x, br_y] format.
objectness (Tensor): Objectness predictions of one image,
a 1D-Tensor with shape [num_priors]
            gt_instances (InstanceData): Ground truth of instance
                annotations. It should include ``bboxes`` and ``labels``
                attributes.
img_meta (dict): Meta information for current image.
gt_instances_ignore (InstanceData, optional): Instances
to be ignored during training. It includes ``bboxes`` attribute
data that is ignored during training and testing.
Defaults to None.
Returns:
tuple:
                foreground_mask (Tensor): Binary mask of foreground
                    targets.
                cls_target (Tensor): Classification targets of an image.
                obj_target (Tensor): Objectness targets of an image.
                bbox_target (Tensor): BBox targets of an image.
                l1_target (Tensor): BBox L1 targets of an image.
                num_pos_per_img (int): Number of positive samples in an image.
"""
num_priors = priors.size(0)
num_gts = len(gt_instances)
# No target
if num_gts == 0:
cls_target = cls_preds.new_zeros((0, self.num_classes))
bbox_target = cls_preds.new_zeros((0, 4))
l1_target = cls_preds.new_zeros((0, 4))
obj_target = cls_preds.new_zeros((num_priors, 1))
foreground_mask = cls_preds.new_zeros(num_priors).bool()
return (foreground_mask, cls_target, obj_target, bbox_target, l1_target, 0)
        # YOLOX uses center priors with a 0.5 offset to assign targets,
        # but uses center priors without the offset to regress bboxes.
offset_priors = torch.cat([priors[:, :2] + priors[:, 2:] * 0.5, priors[:, 2:]], dim=-1)
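        # The assigner scores candidates with the geometric mean of the sigmoid
        # class score and sigmoid objectness (sqrt of their product).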
scores = cls_preds.sigmoid() * objectness.unsqueeze(1).sigmoid()
pred_instances = InstanceData(bboxes=decoded_bboxes, scores=scores.sqrt_(), priors=offset_priors)
assign_result = self.assigner.assign(
pred_instances=pred_instances,
gt_instances=gt_instances,
gt_instances_ignore=gt_instances_ignore,
)
sampling_result = self.sampler.sample(assign_result, pred_instances, gt_instances)
pos_inds = sampling_result.pos_inds
num_pos_per_img = pos_inds.size(0)
pos_ious = assign_result.max_overlaps[pos_inds]
# IOU aware classification score
cls_target = F.one_hot(sampling_result.pos_gt_labels, self.num_classes) * pos_ious.unsqueeze(-1)
obj_target = torch.zeros_like(objectness).unsqueeze(-1)
obj_target[pos_inds] = 1
bbox_target = sampling_result.pos_gt_bboxes
l1_target = cls_preds.new_zeros((num_pos_per_img, 4))
if self.use_l1:
l1_target = self._get_l1_target(l1_target, bbox_target, priors[pos_inds])
foreground_mask = torch.zeros_like(objectness).to(torch.bool)
foreground_mask[pos_inds] = 1
return (foreground_mask, cls_target, obj_target, bbox_target, l1_target, num_pos_per_img)

    def _get_l1_target(self, l1_target: Tensor, gt_bboxes: Tensor, priors: Tensor, eps: float = 1e-8) -> Tensor:
"""Convert gt bboxes to center offset and log width height."""
gt_cxcywh = box_convert(gt_bboxes, in_fmt="xyxy", out_fmt="cxcywh")
l1_target[:, :2] = (gt_cxcywh[:, :2] - priors[:, :2]) / priors[:, 2:]
l1_target[:, 2:] = torch.log(gt_cxcywh[:, 2:] / priors[:, 2:] + eps)
return l1_target


class YOLOXHead:
"""YOLOXHead factory for detection."""
YOLOXHEAD_CFG: ClassVar[dict[str, Any]] = {
"yolox_tiny": {
"in_channels": 96,
"feat_channels": 96,
},
"yolox_s": {
"in_channels": 128,
"feat_channels": 128,
},
"yolox_l": {
"in_channels": 256,
"feat_channels": 256,
},
"yolox_x": {
"in_channels": 320,
"feat_channels": 320,
},
}

    def __new__(
cls,
model_name: str,
num_classes: int,
train_cfg: dict,
test_cfg: dict | None = None,
) -> YOLOXHeadModule:
"""Constructor for YOLOXHead."""
if model_name not in cls.YOLOXHEAD_CFG:
msg = f"model type '{model_name}' is not supported"
raise KeyError(msg)
return YOLOXHeadModule(
**cls.YOLOXHEAD_CFG[model_name],
num_classes=num_classes,
train_cfg=train_cfg, # TODO (sungchul, kirill): remove
test_cfg=test_cfg, # TODO (sungchul, kirill): remove
)
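

# Usage sketch (illustrative values, not a tested configuration; the concrete
# assigner and NMS settings come from the model recipe):
#   head = YOLOXHead(
#       model_name="yolox_s",
#       num_classes=80,
#       train_cfg={"assigner": ...},  # an assigner instance, e.g. SimOTA-style
#       test_cfg={"score_thr": 0.01, "nms": {"type": "nms", "iou_threshold": 0.65}, "max_per_img": 100},
#   )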