class DetrCriterion(nn.Module):
    """Compute the training loss for DETR-style detectors.

    The process happens in two steps:
        1) compute the Hungarian assignment between ground-truth boxes and the
           outputs of the model;
        2) supervise each matched ground-truth / prediction pair (class and box).

    Args:
        weight_dict (dict[str, int | float]): Weights of the loss components.
            The keys ``"loss_bbox"``, ``"loss_giou"`` and ``"loss_vfl"`` are read;
            each falls back to 1.0 when missing.
        alpha (float, optional): Weight applied to the negative-sample term of the
            varifocal loss. Defaults to 0.2.
        gamma (float, optional): Exponent applied to the predicted score in the
            negative-sample term of the varifocal loss. Defaults to 2.0.
        num_classes (int, optional): The number of classes. Defaults to 80.
    """

    def __init__(
        self,
        weight_dict: dict[str, int | float],
        alpha: float = 0.2,
        gamma: float = 2.0,
        num_classes: int = 80,
    ) -> None:
        """Create the criterion."""
        super().__init__()
        self.num_classes = num_classes
        self.matcher = HungarianMatcher(cost_dict={"cost_class": 2, "cost_bbox": 5, "cost_giou": 2})
        loss_bbox_weight = weight_dict.get("loss_bbox", 1.0)
        loss_giou_weight = weight_dict.get("loss_giou", 1.0)
        self.loss_vfl_weight = weight_dict.get("loss_vfl", 1.0)
        self.alpha = alpha
        self.gamma = gamma
        self.lossl1 = L1Loss(loss_weight=loss_bbox_weight)
        self.giou = GIoULoss(loss_weight=loss_giou_weight)

    def loss_labels_vfl(
        self,
        outputs: dict[str, torch.Tensor],
        targets: list[dict[str, torch.Tensor]],
        indices: list[tuple[torch.Tensor, torch.Tensor]],
        num_boxes: int | float,
    ) -> dict[str, torch.Tensor]:
        """Compute the varifocal classification loss.

        Args:
            outputs (dict[str, torch.Tensor]): Model outputs; the keys
                ``"pred_boxes"`` and ``"pred_logits"`` are read.
            targets (list[dict[str, torch.Tensor]]): Per-image targets; the keys
                ``"boxes"`` and ``"labels"`` are read.
            indices (list[tuple[torch.Tensor, torch.Tensor]]): Per-image pairs of
                (prediction indices, target indices) describing the matching.
            num_boxes (int | float): Normalisation factor (number of target boxes).

        Returns:
            dict[str, torch.Tensor]: ``{"loss_vfl": weighted loss}``.
        """
        idx = self._get_src_permutation_idx(indices)

        # IoU between every matched prediction and its target, used as the soft label.
        src_boxes = outputs["pred_boxes"][idx]
        target_boxes = torch.cat([t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0)
        ious = bbox_overlaps(
            box_convert(src_boxes, in_fmt="cxcywh", out_fmt="xyxy"),
            box_convert(target_boxes, in_fmt="cxcywh", out_fmt="xyxy"),
        )
        # Only the diagonal is kept, i.e. prediction k against target k
        # (presumably ``bbox_overlaps`` returns the full pairwise matrix - TODO confirm).
        ious = torch.diag(ious).detach()

        src_logits = outputs["pred_logits"]
        target_classes_o = torch.cat([t["labels"][j] for t, (_, j) in zip(targets, indices)])
        # Unmatched queries get the extra index ``num_classes``, which is sliced
        # away after one-hot encoding so they end up with an all-zero row.
        target_classes = torch.full(
            src_logits.shape[:2],
            self.num_classes,
            dtype=torch.int64,
            device=src_logits.device,
        )
        target_classes[idx] = target_classes_o.long()
        target = nn.functional.one_hot(target_classes, num_classes=self.num_classes + 1)[..., :-1]

        target_score_o = torch.zeros_like(target_classes, dtype=src_logits.dtype)
        target_score_o[idx] = ious.to(target_score_o.dtype)
        target_score = target_score_o.unsqueeze(-1) * target

        # The weight is built from detached scores so it carries no gradient.
        pred_score = torch.sigmoid(src_logits).detach()
        weight = self.alpha * pred_score.pow(self.gamma) * (1 - target) + target_score

        loss = nn.functional.binary_cross_entropy_with_logits(
            src_logits,
            target_score,
            weight=weight,
            reduction="none",
        )
        loss = loss.mean(1).sum() * src_logits.shape[1] / num_boxes
        return {"loss_vfl": loss * self.loss_vfl_weight}

    def loss_boxes(
        self,
        outputs: dict[str, torch.Tensor],
        targets: list[dict[str, torch.Tensor]],
        indices: list[tuple[torch.Tensor, torch.Tensor]],
        num_boxes: int | float,
    ) -> dict[str, torch.Tensor]:
        """Compute the bounding-box losses: the L1 regression loss and the GIoU loss.

        Targets dicts must contain the key "boxes" containing a tensor of dim
        [nb_target_boxes, 4]. The target boxes are expected in format
        (center_x, center_y, w, h), normalized by the image size.

        Args:
            outputs (dict[str, torch.Tensor]): The outputs of the model; the key
                ``"pred_boxes"`` is read.
            targets (list[dict[str, torch.Tensor]]): The targets.
            indices (list[tuple[torch.Tensor, torch.Tensor]]): Per-image pairs of
                (prediction indices, target indices) describing the matching.
            num_boxes (int | float): Normalisation factor passed to both losses
                as ``avg_factor``.

        Returns:
            dict[str, torch.Tensor]: The losses, under ``"loss_giou"`` and ``"loss_bbox"``.
        """
        idx = self._get_src_permutation_idx(indices)
        src_boxes = outputs["pred_boxes"][idx]
        target_boxes = torch.cat([t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0)

        loss_bbox = self.lossl1(src_boxes, target_boxes, avg_factor=num_boxes)
        loss_giou = self.giou(
            box_convert(src_boxes, in_fmt="cxcywh", out_fmt="xyxy"),
            box_convert(target_boxes, in_fmt="cxcywh", out_fmt="xyxy"),
            avg_factor=num_boxes,
        )
        return {"loss_giou": loss_giou, "loss_bbox": loss_bbox}

    def _get_src_permutation_idx(
        self,
        indices: list[tuple[torch.Tensor, torch.Tensor]],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Flatten per-image prediction indices into a (batch index, query index) pair.

        The returned pair can be used directly for advanced indexing into a tensor
        whose first two dimensions are (batch, query).
        """
        batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
        src_idx = torch.cat([src for (src, _) in indices])
        return batch_idx, src_idx

    @property
    def _available_losses(self) -> tuple[Callable, ...]:
        """Loss functions applied to every set of outputs, in evaluation order."""
        return (self.loss_boxes, self.loss_labels_vfl)

    def forward(
        self,
        outputs: dict[str, torch.Tensor],
        targets: list[dict[str, torch.Tensor]],
    ) -> dict[str, torch.Tensor]:
        """Perform the loss computation.

        Args:
            outputs (dict[str, torch.Tensor]): dict of tensors, see the output
                specification of the model for the format. ``"pred_boxes"`` and
                ``"pred_logits"`` are required; ``"aux_outputs"``,
                ``"dn_aux_outputs"`` and ``"dn_meta"`` are used when present.
            targets (list[dict[str, torch.Tensor]]): list of dicts, such that
                len(targets) == batch_size. The expected keys in each dict depend
                on the losses applied, see each loss' doc.

        Returns:
            dict[str, torch.Tensor]: All computed losses. Losses of auxiliary
            outputs are suffixed with ``_aux_{i}``, losses of denoising outputs
            with ``_dn_{i}``.

        Raises:
            ValueError: If ``"pred_boxes"`` or ``"pred_logits"`` is missing, or if
                ``"dn_aux_outputs"`` is given without ``"dn_meta"``.
        """
        if "pred_boxes" not in outputs or "pred_logits" not in outputs:
            msg = "The model should return the predicted boxes and logits"
            raise ValueError(msg)

        outputs_without_aux = {k: v for k, v in outputs.items() if "aux" not in k}

        # Retrieve the matching between the outputs of the last layer and the targets
        indices = self.matcher(outputs_without_aux, targets)

        # Compute the average number of target boxes across all nodes, for normalization purposes.
        # The device is taken from a key that is known to hold a tensor rather than
        # from whichever entry happens to come first in the dict.
        num_boxes = sum(len(t["labels"]) for t in targets)
        num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=outputs["pred_logits"].device)
        world_size = 1
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            torch.distributed.all_reduce(num_boxes)
            world_size = torch.distributed.get_world_size()
        num_boxes = torch.clamp(num_boxes / world_size, min=1).item()

        # Compute all the requested losses
        losses = {}
        for loss in self._available_losses:
            losses.update(loss(outputs, targets, indices, num_boxes))

        # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
        if "aux_outputs" in outputs:
            for i, aux_outputs in enumerate(outputs["aux_outputs"]):
                aux_indices = self.matcher(aux_outputs, targets)
                for loss in self._available_losses:
                    l_dict = loss(aux_outputs, targets, aux_indices, num_boxes)
                    losses.update({f"{k}_aux_{i}": v for k, v in l_dict.items()})

        # In case of cdn auxiliary losses. For rtdetr
        if "dn_aux_outputs" in outputs:
            if "dn_meta" not in outputs:
                msg = "dn_meta is not in outputs"
                raise ValueError(msg)
            # The matching is given by the denoising meta data, so it is the same for every layer.
            dn_indices = self.get_cdn_matched_indices(outputs["dn_meta"], targets)
            dn_num_boxes = num_boxes * outputs["dn_meta"]["dn_num_group"]
            for i, aux_outputs in enumerate(outputs["dn_aux_outputs"]):
                for loss in self._available_losses:
                    l_dict = loss(aux_outputs, targets, dn_indices, dn_num_boxes)
                    losses.update({f"{k}_dn_{i}": v for k, v in l_dict.items()})

        return losses

    @staticmethod
    def get_cdn_matched_indices(
        dn_meta: dict[str, list[torch.Tensor]],
        targets: list[dict[str, torch.Tensor]],
    ) -> list[tuple[torch.Tensor, torch.Tensor]]:
        """Build the (prediction, target) index pairs for the denoising queries.

        Args:
            dn_meta (dict[str, list[torch.Tensor]]): Meta data for cdn; the keys
                ``"dn_positive_idx"`` and ``"dn_num_group"`` are read.
            targets (list[dict[str, torch.Tensor]]): Per-image targets; the key
                ``"labels"`` is read.

        Returns:
            list[tuple[torch.Tensor, torch.Tensor]]: One pair per image. For an
            image without ground truth both tensors are empty.

        Raises:
            ValueError: If the number of positive indices of an image differs from
                its number of ground truths times ``dn_num_group``.
        """
        dn_positive_idx, dn_num_group = dn_meta["dn_positive_idx"], dn_meta["dn_num_group"]

        dn_match_indices = []
        for i, target in enumerate(targets):
            labels = target["labels"]
            num_gt = len(labels)
            device = labels.device
            if num_gt > 0:
                # Every denoising group refers to the same ground truths, in order.
                gt_idx = torch.arange(num_gt, dtype=torch.int64, device=device).tile(dn_num_group)
                if len(dn_positive_idx[i]) != len(gt_idx):
                    msg = f"len(dn_positive_idx[i]) != len(gt_idx), {len(dn_positive_idx[i])} != {len(gt_idx)}"
                    raise ValueError(msg)
                dn_match_indices.append((dn_positive_idx[i], gt_idx))
            else:
                dn_match_indices.append(
                    (
                        torch.zeros(0, dtype=torch.int64, device=device),
                        torch.zeros(0, dtype=torch.int64, device=device),
                    ),
                )
        return dn_match_indices