otx.algorithms.detection.adapters.mmdet.models.losses

Loss functions for the mmdetection adapters.

Classes

CrossSigmoidFocalLoss([use_sigmoid, ...])

Sigmoid focal loss that supports ignored labels.

L2SPLoss(model, model_ckpt[, loss_weight])

L2-SP regularization class for the mmdetection adapter.

OrdinaryFocalLoss([gamma])

Focal loss without balancing.

class otx.algorithms.detection.adapters.mmdet.models.losses.CrossSigmoidFocalLoss(use_sigmoid=True, num_classes=None, gamma=2.0, alpha=0.25, reduction='mean', loss_weight=1.0, ignore_index=None)

Bases: Module

Sigmoid focal loss that supports ignored labels.

Initializes internal Module state, shared by both nn.Module and ScriptModule.

forward(pred, targets, weight=None, reduction_override=None, avg_factor=None, use_vfl=False, valid_label_mask=None, **kwargs)

Forward function of CrossSigmoidFocalLoss.
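The following framework-free sketch shows the idea behind a per-class sigmoid focal loss with ignored labels. It assumes the common alpha-balanced formulation, -alpha_t * (1 - p_t)^gamma * log(p_t), applied independently to each class logit. It also assumes one-hot `targets`, and that a 0 in `valid_label_mask` zeroes the loss of that class entry. The "mean" reduction divides by `avg_factor` when given and otherwise by the number of entries, in the style of mmdetection's weighted reduction. The function name `cross_sigmoid_focal_loss` is hypothetical. The sketch does not model `use_vfl`, `ignore_index`, or `weight`, so it is not the class's actual implementation.

```python
import math
from typing import List, Optional, Sequence, Union


def _sigmoid(x: float) -> float:
    # Numerically stable logistic function.
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def cross_sigmoid_focal_loss(
    pred: Sequence[Sequence[float]],       # logits [num_samples][num_classes]
    targets: Sequence[Sequence[int]],      # one-hot targets (assumption)
    valid_label_mask: Optional[Sequence[Sequence[int]]] = None,  # 1 keep, 0 ignore
    gamma: float = 2.0,
    alpha: float = 0.25,
    reduction: str = "mean",
    avg_factor: Optional[float] = None,
    loss_weight: float = 1.0,
) -> Union[float, List[float]]:
    """Hypothetical sketch: sigmoid focal loss where masked class entries add nothing."""
    losses: List[float] = []
    for i, (row, trow) in enumerate(zip(pred, targets)):
        for j, (logit, t) in enumerate(zip(row, trow)):
            p = _sigmoid(logit)
            p_t = p if t == 1 else 1.0 - p           # probability of the true outcome
            alpha_t = alpha if t == 1 else 1.0 - alpha
            loss = -alpha_t * (1.0 - p_t) ** gamma * math.log(max(p_t, 1e-12))
            if valid_label_mask is not None:
                loss *= valid_label_mask[i][j]       # ignored labels contribute zero
            losses.append(loss)
    if reduction == "none":
        return [loss_weight * v for v in losses]
    total = sum(losses)
    if reduction == "sum":
        return loss_weight * total
    denom = avg_factor if avg_factor is not None else max(len(losses), 1)
    return loss_weight * total / denom
```

Masking after the focal term is computed, rather than dropping entries, keeps the output shape stable for `reduction="none"`. Masked entries still count in the default "mean" denominator, which is one reason an explicit `avg_factor` is often supplied.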

class otx.algorithms.detection.adapters.mmdet.models.losses.L2SPLoss(model, model_ckpt, loss_weight=0.0001)

Bases: Module

L2-SP regularization class for the mmdetection adapter.

L2-SP regularization loss.

Parameters:
  • model (nn.Module) – Input module to regularize.

  • model_ckpt (str) – Path to the starting-point model checkpoint. Parameters of model that match those in the checkpoint are regularized to stay close to their starting-point values.

  • loss_weight (float, optional) – Weight of the loss. Defaults to 0.0001.
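As a reading aid, the parameters above suggest a penalty of the following form. Here S is the set of parameters of model matched in model_ckpt, w_k^0 is the checkpoint value of parameter k, and lambda is loss_weight. This is an assumed form. The general L2-SP formulation can also add a separate L2 term on parameters that have no starting point, and this page does not say whether that term is included.

```latex
\Omega(w) = \lambda \sum_{k \in \mathcal{S}} \left\lVert w_k - w_k^{0} \right\rVert_2^{2}
```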

forward(**kwargs)

Forward function.

Returns:

The calculated loss

Return type:

torch.Tensor
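A minimal sketch of what forward could compute, assuming the penalty is loss_weight times the sum of squared differences between each matched parameter and its checkpoint value. Parameters are modelled as plain dicts mapping names to flat lists of floats instead of tensors. The function `l2sp_loss` is hypothetical, and skipping entries whose sizes differ from the checkpoint is an assumption about how "matched" is decided.

```python
from typing import Dict, List


def l2sp_loss(
    model_params: Dict[str, List[float]],   # current parameters, flattened
    ckpt_params: Dict[str, List[float]],    # starting-point parameters, flattened
    loss_weight: float = 0.0001,
) -> float:
    """Hypothetical sketch: pull matched parameters toward their starting point."""
    total = 0.0
    for name, values in model_params.items():
        start = ckpt_params.get(name)
        # Assumption: only parameters present in the checkpoint with the same
        # size are "matched"; everything else is left unregularized.
        if start is None or len(start) != len(values):
            continue
        total += sum((w - w0) ** 2 for w, w0 in zip(values, start))
    return loss_weight * total
```

Because the starting-point values are fixed, the gradient of this term is 2 * loss_weight * (w - w0) for each matched weight. It acts like weight decay centred on the pretrained weights rather than on zero.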

class otx.algorithms.detection.adapters.mmdet.models.losses.OrdinaryFocalLoss(gamma=1.5, **kwargs)

Bases: Module

Focal loss without balancing.
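"Without balancing" refers to dropping the class-balancing factor alpha_t from the standard focal loss. With p_t the predicted probability of the true class and gamma the focusing parameter (the gamma argument, 1.5 by default here), the unbalanced and balanced forms are as follows. Setting gamma = 0 reduces the unbalanced form to ordinary cross-entropy.

```latex
\mathrm{FL}(p_t) = -(1 - p_t)^{\gamma}\,\log p_t
\qquad\text{vs.}\qquad
\mathrm{FL}_{\alpha}(p_t) = -\alpha_t\,(1 - p_t)^{\gamma}\,\log p_t
```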

Initializes internal Module state, shared by both nn.Module and ScriptModule.

forward(input, target, label_weights=None, avg_factor=None, reduction='mean', **kwargs)

Forward function for focal loss.
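A framework-free sketch of an unbalanced focal loss follows. It assumes a softmax over per-sample class logits, integer class targets, optional per-sample `label_weights`, and a "mean" reduction that divides by `avg_factor` when given and otherwise by the number of samples. These choices are assumptions, since this page only documents the signature. The function `ordinary_focal_loss` is hypothetical and works on lists rather than tensors.

```python
import math
from typing import List, Optional, Sequence, Union


def _log_softmax(logits: Sequence[float]) -> List[float]:
    # Subtract the max logit for numerical stability.
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def ordinary_focal_loss(
    logits: Sequence[Sequence[float]],   # [num_samples][num_classes]
    target: Sequence[int],               # true class index per sample
    label_weights: Optional[Sequence[float]] = None,
    avg_factor: Optional[float] = None,
    reduction: str = "mean",
    gamma: float = 1.5,
) -> Union[float, List[float]]:
    """Hypothetical sketch: -(1 - p_t)^gamma * log(p_t), no alpha balancing."""
    losses: List[float] = []
    for i, (row, t) in enumerate(zip(logits, target)):
        log_pt = _log_softmax(row)[t]
        pt = math.exp(log_pt)
        loss = -((1.0 - pt) ** gamma) * log_pt   # focal term times cross-entropy
        if label_weights is not None:
            loss *= label_weights[i]
        losses.append(loss)
    if reduction == "none":
        return losses
    if not losses:
        return 0.0                               # empty batch: nothing to average
    if reduction == "sum":
        return sum(losses)
    denom = avg_factor if avg_factor is not None else len(losses)
    return sum(losses) / denom
```

For two equal logits, p_t = 0.5, so the loss is 0.5^1.5 * ln 2 with the default gamma, and exactly ln 2 with gamma = 0.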