otx.algo.segmentation.segmentors#

Module for base NN segmentation models.

Classes

BaseSegmentationModel(backbone, decode_head)

Base Segmentation Model.

class otx.algo.segmentation.segmentors.BaseSegmentationModel(backbone: Module, decode_head: Module, criterion: Module | None = None)[source]#

Bases: Module

Base Segmentation Model.

Parameters:
  • backbone (nn.Module) – The backbone of the segmentation model.

  • decode_head (nn.Module) – The decode head of the segmentation model.

  • criterion (nn.Module, optional) – The criterion of the model. Defaults to None. If None, use CrossEntropyLoss with ignore_index=255.

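For orientation, a minimal sketch of plugging custom modules into BaseSegmentationModel. TinyBackbone and TinyHead are hypothetical stand-ins written for this example, and the assumption is that the decode head consumes the backbone output directly:

import torch
from torch import nn

from otx.algo.segmentation.segmentors import BaseSegmentationModel


class TinyBackbone(nn.Module):
    """Hypothetical backbone producing a single downsampled feature map."""

    def __init__(self) -> None:
        super().__init__()
        self.stem = nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.stem(x)


class TinyHead(nn.Module):
    """Hypothetical decode head mapping features to per-class logits."""

    def __init__(self, num_classes: int = 4) -> None:
        super().__init__()
        self.classifier = nn.Conv2d(16, num_classes, kernel_size=1)

    def forward(self, feats: torch.Tensor) -> torch.Tensor:
        return self.classifier(feats)


# criterion is omitted, so CrossEntropyLoss with ignore_index=255 is used.
model = BaseSegmentationModel(backbone=TinyBackbone(), decode_head=TinyHead())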

calculate_loss(model_features: Tensor, img_metas: list[ImageInfo], masks: Tensor, interpolate: bool) Tensor[source]#

Calculates the loss of the model.

Parameters:
  • model_features (Tensor) – Output features of the model.

  • img_metas (list[ImageInfo]) – Image meta information for each image in the batch.

  • masks (Tensor) – Ground truth masks for training.

  • interpolate (bool) – Whether to interpolate the model features to the mask resolution before computing the loss.

Returns:

The loss of the model.

Return type:

Tensor
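For reference, a standalone PyTorch sketch of what the default criterion and the interpolate flag amount to: cross-entropy with ignore_index=255 applied to features upsampled to the mask resolution. This is plain PyTorch, not the OTX call itself:

import torch
import torch.nn.functional as F
from torch import nn

criterion = nn.CrossEntropyLoss(ignore_index=255)

features = torch.randn(2, 4, 32, 32)      # (batch, classes, H, W) model features
masks = torch.randint(0, 4, (2, 64, 64))  # ground-truth class index per pixel
masks[0, :8, :8] = 255                    # pixels labelled 255 are ignored by the loss

# interpolate=True corresponds to upsampling the features to the mask size first.
upsampled = F.interpolate(features, size=masks.shape[-2:], mode="bilinear", align_corners=False)
loss = criterion(upsampled, masks)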

extract_features(inputs: Tensor) tuple[Tensor, Tensor][source]#

Extract features from the backbone and decode head.
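Continuing the sketch above, and assuming the first element of the returned tuple comes from the backbone and the second from the decode head:

images = torch.randn(2, 3, 64, 64)
backbone_feats, head_outputs = model.extract_features(images)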

forward(inputs: Tensor, img_metas: list[ImageInfo] | None = None, masks: Tensor | None = None, mode: str = 'tensor') Tensor[source]#

Performs the forward pass of the model.

Parameters:
  • inputs (Tensor) – Input images to the model.

  • img_metas (list[ImageInfo]) – Image meta information. Defaults to None.

  • masks (Tensor) – Ground truth masks for training. Defaults to None.

  • mode (str) – The mode of operation. Defaults to “tensor”.

Returns:

  • If mode is “tensor”, returns the model outputs.

  • If mode is “loss”, returns a dictionary of output losses.

  • If mode is “predict”, returns the predicted outputs.

  • Otherwise, returns the model outputs after interpolation.

Return type:

Tensor, or a dictionary of losses when mode is “loss”.
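Continuing the same sketch, a hedged usage example of the mode argument; the assumption here is that mode=“tensor” needs only the input images, while mode=“loss” additionally requires img_metas and masks built from OTX ImageInfo objects (not constructed here):

# Raw decode-head outputs, no loss computation.
logits = model(inputs=images, mode="tensor")

# Training-style call (sketch only): would also pass image metadata and ground-truth masks.
# losses = model(inputs=images, img_metas=img_metas, masks=gt_masks, mode="loss")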

get_valid_label_mask(img_metas: list[ImageInfo]) list[Tensor][source]#

Get valid label masks for a batch, zeroing out the entries of ignored classes.

Parameters:

img_metas (list[ImageInfo]) – List of image meta information.

Returns:

List of valid label masks.

Return type:

list[Tensor]
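Independent of the ImageInfo type, a conceptual sketch of what a single valid label mask looks like when some classes are marked as ignored for an image (the indices are made up for illustration):

import torch

num_classes = 4
ignored_labels = [2]                     # hypothetical ignored class indices for one image

valid_label_mask = torch.ones(num_classes)
valid_label_mask[ignored_labels] = 0.0   # zero out the ignored classes
# valid_label_mask -> tensor([1., 1., 0., 1.])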