otx.algorithms.visual_prompting.adapters.pytorch_lightning.models.visual_prompters.segment_anything#

SAM module for visual prompting.

Paper: https://arxiv.org/abs/2304.02643
Reference: facebookresearch/segment-anything

Classes

SegmentAnything(config[, state_dict])

SAM predicts object masks from an image and input prompts.

class otx.algorithms.visual_prompting.adapters.pytorch_lightning.models.visual_prompters.segment_anything.SegmentAnything(config: DictConfig, state_dict: OrderedDict | None = None)[source]#

Bases: LightningModule

SAM predicts object masks from an image and input prompts.

Parameters:
  • config (DictConfig) – Config for SAM.

  • state_dict (Optional[OrderedDict], optional) – State dict of SAM. Defaults to None.

calculate_dice_loss(inputs: Tensor, targets: Tensor, num_masks: int) → Tensor[source]#

Compute the Dice loss, similar to generalized IoU for masks.

Parameters:
  • inputs (Tensor) – A tensor representing a mask.

  • targets (Tensor) – A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs (0 for the negative class and 1 for the positive class).

  • num_masks (int) – The number of masks present in the current batch, used for normalization.

Returns:

The Dice loss.

Return type:

Tensor
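
For reference, a minimal sketch of the Dice computation this method performs, assuming the inputs are raw mask logits that are sigmoid-activated and flattened per mask (names are illustrative, not the exact OTX internals):

    from torch import Tensor

    def dice_loss_sketch(inputs: Tensor, targets: Tensor, num_masks: int) -> Tensor:
        inputs = inputs.sigmoid().flatten(1)            # logits -> probabilities, one row per mask
        targets = targets.flatten(1)
        # Dice coefficient: 2|X ∩ Y| / (|X| + |Y|), computed per mask.
        numerator = 2 * (inputs * targets).sum(-1)
        denominator = inputs.sum(-1) + targets.sum(-1)
        loss = 1 - (numerator + 1) / (denominator + 1)  # +1 smooths empty masks
        return loss.sum() / num_masks                   # normalize by masks in the batch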

calculate_iou(inputs: Tensor, targets: Tensor, epsilon: float = 1e-07) → Tensor[source]#

Calculate the intersection over union (IoU) between the predicted mask and the ground truth mask.

Parameters:
  • inputs (Tensor) – A tensor representing a mask.

  • targets (Tensor) – A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs (0 for the negative class and 1 for the positive class).

  • epsilon (float, optional, defaults to 1e-7) – A small value to prevent division by zero.

Returns:

The IoU between the predicted mask and the ground truth mask.

Return type:

Tensor
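
A plausible sketch of the computation, assuming inputs holds post-sigmoid probabilities that are binarized at 0.5 before applying the standard IoU formula (illustrative, not the exact OTX code):

    from torch import Tensor

    def iou_sketch(inputs: Tensor, targets: Tensor, epsilon: float = 1e-7) -> Tensor:
        pred = (inputs >= 0.5).to(targets.dtype)
        intersection = (pred * targets).sum(dim=(-2, -1))
        union = pred.sum(dim=(-2, -1)) + targets.sum(dim=(-2, -1)) - intersection
        # epsilon keeps the ratio finite when both masks are empty.
        return intersection / (union + epsilon)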

calculate_sigmoid_ce_focal_loss(inputs: Tensor, targets: Tensor, num_masks: int, alpha: float = 0.25, gamma: float = 2) → Tensor[source]#

Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.

Parameters:
  • inputs (Tensor) – A float tensor of arbitrary shape.

  • targets (Tensor) – A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs (0 for the negative class and 1 for the positive class).

  • num_masks (int) – The number of masks present in the current batch, used for normalization.

  • alpha (float, optional, defaults to 0.25) – Weighting factor in range (0,1) to balance positive vs negative examples.

  • gamma (float, optional, defaults to 2.0) – Exponent of the modulating factor (1 - p_t) to balance easy vs hard examples.

Returns:

The focal loss.

Return type:

Tensor
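
The referenced loss is FL(p_t) = -alpha_t (1 - p_t)^gamma log(p_t). A minimal sketch, assuming inputs are raw logits and following the common sigmoid focal loss formulation (names are illustrative):

    import torch.nn.functional as F
    from torch import Tensor

    def focal_loss_sketch(inputs: Tensor, targets: Tensor, num_masks: int,
                          alpha: float = 0.25, gamma: float = 2.0) -> Tensor:
        inputs, targets = inputs.flatten(1), targets.flatten(1)
        prob = inputs.sigmoid()
        ce = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
        p_t = prob * targets + (1 - prob) * (1 - targets)  # prob of the true class
        loss = ce * ((1 - p_t) ** gamma)                   # down-weight easy examples
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss                              # balance pos vs neg examples
        return loss.mean(1).sum() / num_masks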

calculate_stability_score(masks: Tensor, mask_threshold: float, threshold_offset: float = 1.0) → Tensor[source]#

Computes the stability score for a batch of masks.

The stability score is the IoU between the binary masks obtained by thresholding the predicted mask logits at high and low values.

Parameters:
  • masks (Tensor) – A batch of predicted masks with shape BxHxW.

  • mask_threshold (float) – The threshold used to binarize the masks.

  • threshold_offset (float, optional) – The offset used to compute the stability score.

Returns:

The stability scores for the batch of masks.

Return type:

stability_scores (Tensor)
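
A sketch of the score as described above: binarize the logits at mask_threshold ± threshold_offset and take the IoU of the two results. Because the strict mask is a subset of the loose one, the IoU reduces to a ratio of areas (illustrative, not the exact OTX code):

    from torch import Tensor

    def stability_score_sketch(masks: Tensor, mask_threshold: float,
                               threshold_offset: float = 1.0) -> Tensor:
        # Pixel counts above the high (strict) and low (loose) thresholds.
        high = (masks > (mask_threshold + threshold_offset)).sum(dim=(-2, -1)).float()
        low = (masks > (mask_threshold - threshold_offset)).sum(dim=(-2, -1)).float()
        return high / low  # IoU of nested masks = area(strict) / area(loose)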

configure_optimizers() → torch.optim.Optimizer[source]#

Configure the optimizer for SAM.

Returns:

Optimizer.

Return type:

torch.optim.Optimizer

forward(image_embeddings: Tensor, point_coords: Tensor, point_labels: Tensor, mask_input: Tensor, has_mask_input: Tensor, orig_size: Tensor)[source]#

Forward method for SAM inference (export/deploy).

Parameters:
  • image_embeddings (Tensor) – The image embedding with a batch index of length 1. If it is a zero tensor, the image embedding will be computed from the image.

  • point_coords (Tensor) – Coordinates of sparse input prompts, corresponding to both point inputs and box inputs. Boxes are encoded using two points, one for the top-left corner and one for the bottom-right corner. Coordinates must already be transformed to long-side 1024. Has a batch index of length 1.

  • point_labels (Tensor) – Labels for the sparse input prompts. 0 is a negative input point, 1 is a positive input point, 2 is a top-left box corner, 3 is a bottom-right box corner, and -1 is a padding point. If there is no box input, a single padding point with label -1 and coordinates (0.0, 0.0) should be concatenated.

  • mask_input (Tensor) – A mask input to the model with shape 1x1x256x256. This must be supplied even if there is no mask input. In this case, it can just be zeros.

  • has_mask_input (Tensor) – An indicator for the mask input. 1 indicates a mask input, 0 indicates no mask input. This input has shape 1x1 to support the OpenVINO input layout.

  • orig_size (Tensor) – The size of the input image in (H, W) format, before any transformation. This input has shape 1x2 to support the OpenVINO input layout.
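
To illustrate the prompt encoding above, here is a hypothetical call wrapping a single box prompt (two corner points with labels 2 and 3) plus the mandatory dummy mask input. The embedding is a placeholder, and model stands for a loaded SegmentAnything instance:

    import torch

    image_embeddings = torch.zeros(1, 256, 64, 64)  # placeholder; normally from the image encoder
    # One box, encoded as its top-left (label 2) and bottom-right (label 3) corners.
    # Coordinates are assumed to be already transformed to long-side 1024.
    point_coords = torch.tensor([[[100.0, 150.0], [400.0, 380.0]]])
    point_labels = torch.tensor([[2.0, 3.0]])
    mask_input = torch.zeros(1, 1, 256, 256)        # required even without a mask prompt
    has_mask_input = torch.zeros(1, 1)              # 0 -> ignore mask_input
    orig_size = torch.tensor([[720, 1280]])         # (H, W) before any transform

    # masks = model(image_embeddings, point_coords, point_labels,
    #               mask_input, has_mask_input, orig_size)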

forward_train(images: Tensor, bboxes: List[Tensor], points: Tuple[Tensor, Tensor] | None = None, masks: Tensor | None = None) → Tuple[List[Tensor], List[Tensor]][source]#

Forward method for SAM training/validation/prediction.

Parameters:
  • images (Tensor) – Images with shape (B, C, H, W).

  • bboxes (List[Tensor]) – Per-image Nx4 tensors of box prompts to the model, in XYXY format.

  • points (Tuple[Tensor, Tensor], optional) – Point coordinates and labels to embed. Point coordinates are BxNx2 arrays of point prompts to the model. Each point is in (X,Y) in pixels. Labels are BxN arrays of labels for the point prompts. 1 indicates a foreground point and 0 indicates a background point.

  • masks (Optional[Tensor], optional) – A low resolution mask input to the model, typically coming from a previous prediction iteration. Has form Bx1xHxW, where for SAM, H=W=256. Masks returned by a previous iteration of the predict method do not need further transformation.

Returns:

  • pred_masks (List[Tensor]) – List of predicted masks with shape (B, 1, H, W).

  • ious (List[Tensor]) – List of IoU predictions with shape (N, 1).

Return type:

Tuple[List[Tensor], List[Tensor]]
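
A hypothetical training-style invocation, following the shapes documented above (all values are placeholders, and model stands for a SegmentAnything instance):

    import torch

    images = torch.randn(2, 3, 1024, 1024)                 # (B, C, H, W)
    bboxes = [torch.tensor([[10.0, 20.0, 200.0, 220.0]]),  # one XYXY box for image 0
              torch.tensor([[50.0, 60.0, 300.0, 310.0]])]  # one XYXY box for image 1

    # pred_masks, ious = model.forward_train(images, bboxes)
    # pred_masks[i] holds the predicted masks for image i; ious[i] the matching IoU predictions.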

freeze_networks() → None[source]#

Freeze networks depending on config.

get_prepadded_size(input_image_size: Tensor, longest_side: int) → Tensor[source]#

Get the pre-padded size, i.e. the image size after its longest side is resized to longest_side, before padding to a square.

load_checkpoint(state_dict: OrderedDict | None = None) → None[source]#

Load checkpoint for SAM.

Parameters:

state_dict (Optional[OrderedDict], optional) – State dict of SAM. Defaults to None.

classmethod postprocess_masks(masks: Tensor, input_size: int, orig_size: Tensor) → Tensor[source]#

Postprocess the predicted masks.

Parameters:
  • masks (Tensor) – A batch of predicted masks with shape Bx1xHxW.

  • input_size (int) – The size of the image input to the model. Used to remove padding.

  • orig_size (Tensor) – The original image size with shape Bx2.

Returns:

The postprocessed masks with shape Bx1xHxW.

Return type:

masks (Tensor)
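
Conceptually, postprocessing upsamples the low-resolution masks to the padded square model resolution, crops the padding away, and resizes to the original image size. A sketch under those assumptions, for a single image with orig_size as a length-2 integer (H, W) tensor (illustrative, not the exact OTX classmethod):

    import torch.nn.functional as F
    from torch import Tensor

    def postprocess_masks_sketch(masks: Tensor, input_size: int, orig_size: Tensor) -> Tensor:
        # Upsample to the padded square input resolution (e.g. 1024x1024).
        masks = F.interpolate(masks, (input_size, input_size), mode="bilinear", align_corners=False)
        # Crop away the padding: recover the pre-padded size by rescaling the
        # longest original side to input_size (cf. get_prepadded_size).
        scale = input_size / orig_size.max()
        h, w = (orig_size.float() * scale + 0.5).floor().long().tolist()
        masks = masks[..., :h, :w]
        # Resize to the original image size.
        return F.interpolate(masks, tuple(orig_size.tolist()), mode="bilinear", align_corners=False)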

predict_step(batch, batch_idx) → Dict[str, Tensor][source]#

Predict step of SAM.

Parameters:
  • batch (Dict) – Batch data.

  • batch_idx (int) – Batch index.

Returns:

Predicted masks, IoU predictions, image paths, and labels.

Return type:

Dict[str, Tensor]

select_masks(masks: Tensor, iou_preds: Tensor, num_points: int) → Tuple[Tensor, Tensor][source]#

Selects the best mask from a batch of masks.

Parameters:
  • masks (Tensor) – A batch of predicted masks with shape BxMxHxW.

  • iou_preds (Tensor) – A batch of predicted IoU scores with shape BxM.

  • num_points (int) – The number of points in the input.

Returns:

  • masks (Tensor) – The selected masks with shape Bx1xHxW.

  • iou_preds (Tensor) – The selected IoU scores with shape Bx1.

Return type:

Tuple[Tensor, Tensor]
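
A minimal sketch of the selection: pick, per batch element, the mask with the highest predicted IoU (the real method may also use num_points to arbitrate between single- and multi-mask outputs; that detail is omitted here):

    import torch
    from torch import Tensor
    from typing import Tuple

    def select_masks_sketch(masks: Tensor, iou_preds: Tensor) -> Tuple[Tensor, Tensor]:
        best = torch.argmax(iou_preds, dim=1)   # (B,) index of the best mask per item
        batch = torch.arange(masks.shape[0])
        return masks[batch, best].unsqueeze(1), iou_preds[batch, best].unsqueeze(1)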

set_metrics() → None[source]#

Set metrics for SAM.

set_models() → None[source]#

Set models for SAM.

training_epoch_end(outputs) → None[source]#

Training epoch end for SAM.

training_step(batch, batch_idx) → Tensor[source]#

Training step for SAM.

Parameters:
  • batch (Dict) – Batch data.

  • batch_idx (int) – Batch index.

Returns:

Loss tensor.

Return type:

loss (Tensor)

validation_epoch_end(outputs) → None[source]#

Validation epoch end for SAM.

validation_step(batch, batch_idx) → MetricCollection[source]#

Validation step of SAM.

Parameters:
  • batch (Dict) – Batch data.

  • batch_idx (int) – Batch index.

Returns:

Validation metrics.

Return type:

val_metrics (MetricCollection)