otx.algorithms.visual_prompting.adapters.pytorch_lightning.models.visual_prompters.zero_shot_segment_anything

SAM module for visual prompting zero-shot learning.

Classes

PromptGetter(image_size[, downsizing])

Prompt getter for zero-shot learning.

ZeroShotSegmentAnything([config, ...])

Zero-shot learning module using Segment Anything.

class otx.algorithms.visual_prompting.adapters.pytorch_lightning.models.visual_prompters.zero_shot_segment_anything.PromptGetter(image_size: int, downsizing: int = 64)

Bases: Module

Prompt getter for zero-shot learning.

Initializes internal Module state, shared by both nn.Module and ScriptModule.

forward(image_embeddings: Tensor, reference_feat: Tensor, original_size: Tensor, threshold: Tensor = tensor([[0.]]), num_bg_points: Tensor = tensor([[1]])) → Tuple[Tensor, Tensor]

Get prompt candidates from the given reference and target features.
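The docstring does not say how candidates are chosen. A common training-free recipe scores every location of the target embedding by cosine similarity to the reference feature, keeps locations above `threshold` as foreground candidates, and takes the `num_bg_points` least similar locations as background. The pure-Python sketch below illustrates that recipe under this assumption only: `cosine` and `select_points` are hypothetical helpers, not the OTX implementation, which works on tensors (including tensor-valued `threshold` and `num_bg_points`). The sketch also ignores `original_size`, which presumably maps point coordinates back to the input resolution.

```python
import math
from typing import List, Tuple

# (x, y, similarity) of one location on the target feature map
Point = Tuple[int, int, float]


def cosine(a: List[float], b: List[float]) -> float:
    """Cosine similarity between two vectors (0.0 if either is all zeros)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b) if norm_a and norm_b else 0.0


def select_points(
    embeddings: List[List[List[float]]],  # [H][W][C] target image embedding
    reference_feat: List[float],          # [C] reference feature for one label
    threshold: float = 0.0,
    num_bg_points: int = 1,
) -> Tuple[List[Point], List[Point]]:
    """Hypothetical helper: return (foreground, background) point candidates."""
    scored = [
        (x, y, cosine(feat, reference_feat))
        for y, row in enumerate(embeddings)
        for x, feat in enumerate(row)
    ]
    # Foreground: locations more similar than the threshold, best first.
    foreground = sorted(
        (p for p in scored if p[2] > threshold), key=lambda p: p[2], reverse=True
    )
    # Background: the num_bg_points least similar locations.
    background = sorted(scored, key=lambda p: p[2])[:num_bg_points]
    return foreground, background


# Usage on a 2x2 feature map with 2 channels.
emb = [[[1.0, 0.0], [0.0, 1.0]],
       [[0.7, 0.7], [-1.0, 0.0]]]
fg, bg = select_points(emb, [1.0, 0.0], threshold=0.5, num_bg_points=1)
```

Here `fg` holds the locations (0, 0) and (0, 1), and `bg` holds (1, 1), the location pointing opposite to the reference feature.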

get_prompt_candidates(image_embeddings: Tensor, reference_feats: Tensor, used_indices: Tensor, original_size: Tensor, threshold: Tensor = tensor([[0.]]), num_bg_points: Tensor = tensor([[1]]), device: device | str = device(type='cpu')) → Tuple[Dict[int, Tensor], Dict[int, Tensor]]

Get prompt candidates.
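The return type, two `Dict[int, Tensor]`, suggests candidates are collected per label index, looping only over the labels listed in `used_indices`. The sketch below shows that bookkeeping under this assumption. `get_prompt_candidates_sketch` and the `select` callable are hypothetical; in OTX the per-label selection is presumably done by `forward`, and plain lists stand in for tensors.

```python
from typing import Callable, Dict, List, Sequence, Tuple

# Hypothetical per-label selector returning (foreground, background) candidates.
Selector = Callable[[List[float]], Tuple[list, list]]


def get_prompt_candidates_sketch(
    reference_feats: Sequence[List[float]],  # one reference feature per label index
    used_indices: Sequence[int],             # label indices that hold learned features
    select: Selector,
) -> Tuple[Dict[int, list], Dict[int, list]]:
    """Collect foreground and background candidates for every used label."""
    fg_by_label: Dict[int, list] = {}
    bg_by_label: Dict[int, list] = {}
    for label in used_indices:
        # Only labels marked as used are queried; unused slots are skipped.
        fg, bg = select(reference_feats[label])
        fg_by_label[label] = fg
        bg_by_label[label] = bg
    return fg_by_label, bg_by_label


# Usage with a toy selector: label 1 is not in used_indices, so it is skipped.
feats = [[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]]
fg_map, bg_map = get_prompt_candidates_sketch(
    feats, [0, 2], lambda f: ([sum(f)], [-sum(f)])
)
```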

set_default_thresholds(default_threshold_reference: float, default_threshold_target: float) → None

Set default thresholds.

class otx.algorithms.visual_prompting.adapters.pytorch_lightning.models.visual_prompters.zero_shot_segment_anything.ZeroShotSegmentAnything(config: DictConfig | None = None, manual_config_update: Dict | None = None, state_dict: OrderedDict | None = None)

Bases: SegmentAnything

Zero-shot learning module using Segment Anything.

configure_optimizers() → None

Skip configure_optimizers, which is unused in zero-shot learning.

expand_reference_info(new_largest_label: int) → None

Expand reference info dimensions if newly given processed prompts have more labels.
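One plausible reading of this method, stated as an assumption: reference features are stored as one row per label index, so when a new prompt introduces a label larger than any seen so far, the store is grown with empty rows instead of being rebuilt. Empty rows would also explain why `infer` takes `used_indices` to tell real features from padding. `ReferenceStoreSketch` is a hypothetical, list-based stand-in, not the OTX data structure.

```python
from typing import List


class ReferenceStoreSketch:
    """Hypothetical per-label reference store; not the OTX data structure."""

    def __init__(self, feat_dim: int) -> None:
        self.feat_dim = feat_dim
        self.reference_feats: List[List[float]] = []  # row i holds label i's feature

    def expand_reference_info(self, new_largest_label: int) -> None:
        """Append zero rows so that `new_largest_label` becomes a valid index."""
        missing = new_largest_label + 1 - len(self.reference_feats)
        # A non-positive `missing` means the store is already large enough.
        for _ in range(max(missing, 0)):
            self.reference_feats.append([0.0] * self.feat_dim)


store = ReferenceStoreSketch(feat_dim=3)
store.reference_feats.append([0.5, 0.5, 0.5])  # label 0 already learned
store.expand_reference_info(new_largest_label=2)  # labels 1 and 2 get zero rows
```

Existing rows are left untouched, and calling the method with a label that already fits is a no-op.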

infer(batch: List[Dict[str, Any]], reference_feats: ndarray | Tensor, used_indices: ndarray | Tensor, is_cascade: bool = True) → List[List[DefaultDict[int, List[Tensor]]]]

Zero-shot inference with reference features.

Get target results using the reference features and the target images’ features.

Parameters:
  • batch (List[Dict[str, Any]]) – List of dictionaries containing images and metas.

  • reference_feats (Union[np.ndarray, Tensor]) – Reference features for target prediction. If given as an np.ndarray, it is converted to a torch tensor.

  • used_indices (Union[np.ndarray, Tensor]) – Indices indicating which reference features are valid. If given as an np.ndarray, it is converted to a torch tensor.

  • is_cascade (bool) – Whether to use cascade inference. Defaults to True.

Returns:

Target results.
The nested lists wrapping the results follow this order:
  1. Target images

  2. Tuple of predicted masks and the points used, obtained by point selection

Return type:

(List[List[DefaultDict[int, List[Tensor]]]])
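To make the nested return type concrete, the sketch below builds a stand-in value for a batch of two target images and reads it back. The layout is an assumption based on the description above: `results[image][0]` holds predicted masks and `results[image][1]` holds the points used, each a `defaultdict` mapping a label index to a list of per-instance results. Plain nested lists stand in for tensors, the `(x, y, score)` point format is hypothetical, and `count_masks` is an illustrative helper.

```python
from collections import defaultdict
from typing import DefaultDict, Dict, List

# Stand-in for infer()'s return value for a batch of two target images.
results: List[List[DefaultDict[int, list]]] = []
for _ in range(2):
    masks: DefaultDict[int, list] = defaultdict(list)
    points: DefaultDict[int, list] = defaultdict(list)
    masks[1].append([[0, 1], [1, 1]])  # one 2x2 binary mask for label 1
    points[1].append([1, 0, 0.9])      # assumed (x, y, score) of the point used
    results.append([masks, points])


def count_masks(batch_results: List[List[DefaultDict[int, list]]]) -> List[Dict[int, int]]:
    """Number of predicted masks per label, for each target image."""
    return [
        {label: len(label_masks) for label, label_masks in image_masks.items()}
        for image_masks, _image_points in batch_results
    ]


print(count_masks(results))  # [{1: 1}, {1: 1}]
```

Because the inner containers are `defaultdict(list)`, indexing a label with no predictions yields an empty list (and inserts that key), so checking membership with `in` is safer when only inspecting results.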

initialize_reference_info() → None

Initialize reference information.

learn(batch: List[Dict[str, Any]], reset_feat: bool = False, is_cascade: bool = False) → None | Tuple[ParameterDict, Tensor]

Get reference features.

Computes reference features from the given images and saves them to PromptGetter. These reference features are later used by infer to get target results. Currently, only a single batch is supported.

Parameters:
  • batch (List[Dict[str, Any]]) – List of dictionaries containing images, prompts, and metas. The batch must contain images, prompts with bboxes, points, annotations, and polygons.

  • reset_feat (bool) – Whether to reset reference_info. In OTX standalone, reference_info is reset in on_train_start; other frameworks must set this to True to reset reference_info. Defaults to False.

  • is_cascade (bool) – Whether to use cascade inference. Defaults to False.

Returns:

reference_info and ref_masks.

Return type:

(Tuple[ParameterDict, Tensor])
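The page does not describe how a reference feature is derived from an image and its prompt mask. Masked average pooling (averaging the image embedding over the masked locations) is a common choice for this kind of training-free personalization, and the sketch below uses it purely as an assumption, together with the documented `reset_feat` behaviour. `masked_average` and `LearnSketch` are hypothetical; the real `learn` returns a `ParameterDict` and reference masks rather than a plain dict, and may combine repeated features differently than this sketch, which simply overwrites them.

```python
from typing import Dict, List

Grid = List[List[List[float]]]  # [H][W][C] image embedding
Mask = List[List[int]]          # [H][W] binary mask at embedding resolution


def masked_average(embedding: Grid, mask: Mask) -> List[float]:
    """Mean feature vector over all locations where the mask is 1."""
    feats = [
        embedding[y][x]
        for y, row in enumerate(mask)
        for x, on in enumerate(row)
        if on
    ]
    if not feats:
        raise ValueError("mask selects no locations")
    return [sum(channel) / len(feats) for channel in zip(*feats)]


class LearnSketch:
    """Hypothetical stand-in for learn()'s bookkeeping; not the OTX code."""

    def __init__(self) -> None:
        self.reference_feats: Dict[int, List[float]] = {}

    def learn(
        self, embedding: Grid, masks: Dict[int, Mask], reset_feat: bool = False
    ) -> Dict[int, List[float]]:
        if reset_feat:
            # Mirrors the documented flag: drop previously learned features.
            self.reference_feats.clear()
        for label, mask in masks.items():
            # This sketch overwrites a label's feature on each call.
            self.reference_feats[label] = masked_average(embedding, mask)
        return self.reference_feats


emb = [[[1.0, 0.0], [3.0, 2.0]],
       [[0.0, 0.0], [5.0, 5.0]]]
learner = LearnSketch()
learner.learn(emb, {1: [[1, 1], [0, 0]]})  # label 1 -> mean of top row
learner.learn(emb, {2: [[0, 0], [0, 1]]})  # label 2 is added alongside label 1
```

Without `reset_feat=True`, features accumulate across calls, which is why other frameworks must request the reset explicitly.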

load_state_dict_pre_hook(state_dict: Dict[str, Any], prefix: str = '', *args, **kwargs) → None

Load reference info manually.

on_predict_start() → None

Called at the beginning of predicting.

on_train_start() → None

Called at the beginning of training after sanity check.

predict_step(batch, batch_idx)

Predict step for infer.

set_default_config() → DictConfig

Set the default config when the module is used standalone.

set_empty_reference_info() → None

Set empty reference information.

set_metrics() → None

Skip set_metrics, which is unused in zero-shot learning.

training_epoch_end(outputs) → None

Called in the training loop at the very end of the epoch.

training_step(batch, batch_idx) → None

Training step for learn.
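Taken together, the hooks map Lightning's loop onto the zero-shot workflow: on_train_start resets reference_info (per the learn docs), training_step calls learn, and predict_step calls infer. In PyTorch Lightning a training_step that returns None skips optimization for that batch, and configure_optimizers may return None so fit runs without an optimizer, which suits a module that only collects reference features. The framework-free sketch below mimics that delegation; `LifecycleSketch` and `fit_then_predict` are hypothetical and do not use the Lightning API, and the stored "features" are placeholders.

```python
from typing import Any, Dict, List


class LifecycleSketch:
    """Framework-free, hypothetical mimic of how the hooks above delegate."""

    def __init__(self) -> None:
        self.reference_info: Dict[int, Any] = {}
        self.calls: List[str] = []  # records delegation order for illustration

    def learn(self, batch: Any) -> None:
        self.calls.append("learn")
        self.reference_info[len(self.reference_info)] = batch  # placeholder feature

    def infer(self, batch: Any) -> List[Any]:
        self.calls.append("infer")
        return [batch]

    def on_train_start(self) -> None:
        # Per the learn() docs, OTX standalone resets reference_info here.
        self.reference_info = {}

    def training_step(self, batch: Any, batch_idx: int) -> None:
        # "Training step for learn": collect features, return no loss.
        self.learn(batch)

    def predict_step(self, batch: Any, batch_idx: int) -> List[Any]:
        # "Predict step for infer".
        return self.infer(batch)


def fit_then_predict(module: LifecycleSketch, train: List[Any], predict: List[Any]) -> List[Any]:
    """Call the hooks in the order a trainer would: fit, then predict."""
    module.on_train_start()
    for idx, batch in enumerate(train):
        module.training_step(batch, idx)
    return [module.predict_step(batch, idx) for idx, batch in enumerate(predict)]
```

Stale reference information from an earlier run is discarded at on_train_start, so prediction only sees features learned in the current fit.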