Source code for otx.algorithms.detection.adapters.mmdet.models.layers.dino

"""Custom DINO transformer for OTX template."""
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#

from typing import Dict, List, Optional, Tuple, Union

import torch
from mmdet.models.utils.builder import TRANSFORMER
from mmdet.models.utils.transformer import DeformableDetrTransformer
from torch import Tensor, nn


@TRANSFORMER.register_module()
class CustomDINOTransformer(DeformableDetrTransformer):
    """Custom DINO transformer.

    Original implementation: mmdet.models.utils.transformer.DeformableDetrTransformer in mmdet 2.x.
    What's changed: The forward function is modified.
    Modified implementations come from mmdet.models.detectors.dino.DINO in mmdet 3.x.
    """

    def init_layers(self):
        """Initialize layers of the DINO.

        Unlike Deformable DETR, DINO does not need pos_trans and pos_trans_norm.
        """
        self.level_embeds = torch.nn.Parameter(torch.Tensor(self.num_feature_levels, self.embed_dims))
        self.enc_output = torch.nn.Linear(self.embed_dims, self.embed_dims)
        self.enc_output_norm = torch.nn.LayerNorm(self.embed_dims)
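
    # NOTE (clarifying comment, not in the upstream source): in Deformable
    # DETR's two-stage variant, pos_trans/pos_trans_norm project the positional
    # embeddings of the top-k encoder proposals into decoder queries. DINO
    # replaces that path with a learnable `query_embed` (plus denoising
    # queries), so only the per-level embeddings and the encoder-output
    # projection (`enc_output` + `enc_output_norm`) are kept here.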

    def forward(
        self,
        batch_info: List[Dict[str, Union[Tuple, Tensor]]],
        mlvl_feats: List[Tensor],
        mlvl_masks: List[Tensor],
        query_embed: Tensor,
        mlvl_pos_embeds: List[Tensor],
        reg_branches: Optional[nn.ModuleList] = None,
        cls_branches: Optional[nn.ModuleList] = None,
        **kwargs
    ):
        """Forward function for `Transformer`.

        What's changed: In mmdet 3.x, the forward pass of the transformer is split into
        pre_transformer() -> forward_encoder() -> pre_decoder() -> forward_decoder().
        In comparison, the mmdet 2.x forward function takes charge of all of the above.
        The differences between Deformable DETR and DINO occur in pre_decoder() and
        forward_decoder(), therefore this function modifies those parts. Modified
        implementations come from pre_decoder() and forward_decoder() of
        mmdet.models.detectors.dino.DINO in mmdet 3.x.

        Args:
            batch_info (list(dict(str, union(tuple, tensor)))): Information about the
                batch, such as image shape and ground-truth information.
            mlvl_feats (list(Tensor)): Input queries from different levels.
                Each element has shape [bs, embed_dims, h, w].
            mlvl_masks (list(Tensor)): The key_padding_mask from different levels
                used for encoder and decoder. Each element has shape [bs, h, w].
            query_embed (Tensor): The query embedding for the decoder,
                with shape [num_query, c].
            mlvl_pos_embeds (list(Tensor)): The positional encodings of feats from
                different levels. Each has shape [bs, embed_dims, h, w].
            reg_branches (obj:`nn.ModuleList`): Regression heads for feature maps
                from each decoder layer. Only passed when `with_box_refine` is True.
                Default to None.
            cls_branches (obj:`nn.ModuleList`): Classification heads for feature maps
                from each decoder layer. Only passed when `as_two_stage` is True.
                Default to None.
            kwargs: Additional arguments for the forward_transformer function.

        Returns:
            tuple[Tensor]: Results of decoder containing the following tensors.

                - inter_states: Outputs from the decoder. If return_intermediate_dec
                  is True, the output has shape (num_dec_layers, bs, num_query,
                  embed_dims), else (1, bs, num_query, embed_dims).
                - inter_references_out: The internal values of reference points in
                  the decoder, with shape (num_dec_layers, bs, num_query, embed_dims).
                - enc_outputs_class: The classification scores of proposals generated
                  from the encoder's feature maps, with shape (batch, h*w, num_classes).
                  Only returned when `as_two_stage` is True, otherwise None.
                - enc_outputs_coord_unact: The regression results generated from the
                  encoder's feature maps, with shape (batch, h*w, 4). Only returned
                  when `as_two_stage` is True, otherwise None.
                - dn_meta (Dict[str, int]): The dictionary saves information about
                  group collation, including 'num_denoising_queries' and
                  'num_denoising_groups'. It will be used for splitting outputs of
                  denoising and matching parts and for loss calculation.
""" feat_flatten: Union[Tensor, List[Tensor]] = [] mask_flatten: Union[Tensor, List[Tensor]] = [] lvl_pos_embed_flatten: Union[Tensor, List[Tensor]] = [] spatial_shapes: Union[Tensor, List[Tensor]] = [] for lvl, (feat, mask, pos_embed) in enumerate(zip(mlvl_feats, mlvl_masks, mlvl_pos_embeds)): bs, c, h, w = feat.shape spatial_shape = (h, w) spatial_shapes.append(spatial_shape) feat = feat.flatten(2).transpose(1, 2) mask = mask.flatten(1) pos_embed = pos_embed.flatten(2).transpose(1, 2) lvl_pos_embed = pos_embed + self.level_embeds[lvl].view(1, 1, -1) lvl_pos_embed_flatten.append(lvl_pos_embed) feat_flatten.append(feat) mask_flatten.append(mask) feat_flatten = torch.cat(feat_flatten, 1) mask_flatten = torch.cat(mask_flatten, 1) lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1) spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=feat_flatten.device) level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])) valid_ratios = torch.stack([self.get_valid_ratio(m) for m in mlvl_masks], 1) reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device=feat.device) feat_flatten = feat_flatten.permute(1, 0, 2) # (H*W, bs, embed_dims) lvl_pos_embed_flatten = lvl_pos_embed_flatten.permute(1, 0, 2) # (H*W, bs, embed_dims) memory = self.encoder( query=feat_flatten, key=None, value=None, query_pos=lvl_pos_embed_flatten, query_key_padding_mask=mask_flatten, spatial_shapes=spatial_shapes, reference_points=reference_points, level_start_index=level_start_index, valid_ratios=valid_ratios, **kwargs ) # pre_decoder part at mmdet 3.x version memory = memory.permute(1, 0, 2) bs, _, c = memory.shape cls_out_features = cls_branches[self.decoder.num_layers].out_features output_memory, output_proposals = self.gen_encoder_output_proposals(memory, mask_flatten, spatial_shapes) enc_outputs_class = cls_branches[self.decoder.num_layers](output_memory) enc_outputs_coord_unact = reg_branches[self.decoder.num_layers](output_memory) + output_proposals topk_indices = torch.topk(enc_outputs_class.max(-1)[0], k=self.two_stage_num_proposals, dim=1)[1] topk_scores = torch.gather(enc_outputs_class, 1, topk_indices.unsqueeze(-1).repeat(1, 1, cls_out_features)) topk_coords_unact = torch.gather(enc_outputs_coord_unact, 1, topk_indices.unsqueeze(-1).repeat(1, 1, 4)) topk_coords = topk_coords_unact.sigmoid() topk_coords_unact = topk_coords_unact.detach() query = query_embed[:, None, :] query = query.repeat(1, bs, 1).transpose(0, 1) if self.training: dn_label_query, dn_bbox_query, dn_mask, dn_meta = self.dn_query_generator(batch_info) query = torch.cat([dn_label_query, query], dim=1) reference_points = torch.cat([dn_bbox_query, topk_coords_unact], dim=1) else: reference_points = topk_coords_unact dn_mask, dn_meta = None, None reference_points = reference_points.sigmoid() # forward_decoder part in mmdet 3.x inter_states, references = self.decoder( query=query, value=memory, key_padding_mask=mask_flatten, self_attn_mask=dn_mask, reference_points=reference_points, spatial_shapes=spatial_shapes, level_start_index=level_start_index, valid_ratios=valid_ratios, reg_branches=reg_branches, ) if len(query) == self.two_stage_num_proposals: # NOTE: This is to make sure label_embeding can be involved to # produce loss even if there is no denoising query (no ground truth # target in this GPU), otherwise, this will raise runtime error in # distributed training. 
            inter_states[0] += self.dn_query_generator.label_embedding.weight[0, 0] * 0.0

        return inter_states, list(references), topk_scores, topk_coords, dn_meta
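
A minimal, hypothetical usage sketch (not part of the OTX source): because the class is registered in mmdet's TRANSFORMER registry, it is normally instantiated from a config dict rather than constructed directly. The numeric values below are illustrative assumptions, and the encoder/decoder sub-configs are deliberately elided since the real values live in the OTX model template. Note also that `self.dn_query_generator`, used in forward(), is not created by this class; it is expected to be attached by the DINO head that owns the transformer.

from mmdet.models.utils.builder import build_transformer

transformer = build_transformer(
    dict(
        type="CustomDINOTransformer",  # resolved via the TRANSFORMER registry
        num_feature_levels=4,          # assumption: must match len(mlvl_feats)
        two_stage_num_proposals=900,   # assumption: k for the encoder top-k selection
        encoder=...,                   # elided: encoder config from the OTX template
        decoder=...,                   # elided: decoder config from the OTX template
    )
)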