otx.algorithms.visual_prompting.adapters.pytorch_lightning.models.visual_prompters.zero_shot_segment_anything
SAM module for visual prompting zero-shot learning.
Classes
| PromptGetter | Prompt getter for zero-shot learning. |
| ZeroShotSegmentAnything | Zero-shot learning module using Segment Anything. |
- class otx.algorithms.visual_prompting.adapters.pytorch_lightning.models.visual_prompters.zero_shot_segment_anything.PromptGetter(image_size: int, downsizing: int = 64)
Bases: Module
Prompt getter for zero-shot learning.
Initializes internal Module state, shared by both nn.Module and ScriptModule.
- forward(image_embeddings: Tensor, reference_feat: Tensor, original_size: Tensor, threshold: Tensor = tensor([[0.]]), num_bg_points: Tensor = tensor([[1]])) → Tuple[Tensor, Tensor]
Get prompt candidates from the given reference and target features.
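To illustrate the general idea of turning reference and target features into prompt points, here is a minimal pure-Python sketch of cosine-similarity point selection, similar in spirit to PerSAM-style training-free personalization. It is not OTX's implementation: the helper names `select_prompt_points` and `cosine`, the nested-list grid layout, and the choice of cosine similarity are assumptions for illustration. The real `forward` works on torch tensors and also uses `original_size`.

```python
import math
from typing import List, Sequence, Tuple

# Hypothetical layout: a target embedding is an H x W grid of C-dim vectors.
Grid = Sequence[Sequence[Sequence[float]]]


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity between two vectors (0.0 if either has zero norm)."""
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(y * y for y in b))
    return dot / (na * nb) if na and nb else 0.0


def select_prompt_points(
    target_embedding: Grid,
    reference_feat: Sequence[float],
    threshold: float = 0.0,
    num_bg_points: int = 1,
) -> Tuple[List[Tuple[int, int, float]], List[Tuple[int, int]]]:
    """Hypothetical: return (foreground points with scores, background points) in grid (x, y)."""
    # Score every grid location against the reference feature.
    scored = [
        (x, y, cosine(vec, reference_feat))
        for y, row in enumerate(target_embedding)
        for x, vec in enumerate(row)
    ]
    # Foreground candidates: locations at or above the threshold, best first.
    fg = sorted((p for p in scored if p[2] >= threshold), key=lambda p: p[2], reverse=True)
    # Background candidates: the num_bg_points least similar locations.
    bg = [(x, y) for x, y, _ in sorted(scored, key=lambda p: p[2])[:num_bg_points]]
    return fg, bg
```

For the 2 x 2 grid `[[[1.0, 0.0], [0.0, 1.0]], [[0.7, 0.7], [-1.0, 0.0]]]` with reference `[1.0, 0.0]` and `threshold=0.5`, the sketch yields foreground points at (0, 0) and (0, 1) and the background point (1, 1).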
- get_prompt_candidates(image_embeddings: Tensor, reference_feats: Tensor, used_indices: Tensor, original_size: Tensor, threshold: Tensor = tensor([[0.]]), num_bg_points: Tensor = tensor([[1]]), device: device | str = device(type='cpu')) → Tuple[Dict[int, Tensor], Dict[int, Tensor]]
Get prompt candidates.
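The return type of `get_prompt_candidates` is a pair of dictionaries keyed by label index. As a rough, hypothetical illustration of that shape, the sketch below repeats a similarity-based point selection once per label listed in `used_indices`. The function name `prompt_candidates_per_label`, the position-to-vector mapping, and cosine scoring are assumptions, not OTX's code.

```python
import math
from typing import Dict, List, Sequence, Tuple

Vector = Sequence[float]
Position = Tuple[int, int]  # (x, y) location on the embedding grid


def _cosine(a: Vector, b: Vector) -> float:
    """Cosine similarity; returns 0.0 when either vector has zero norm."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def prompt_candidates_per_label(
    target_vectors: Dict[Position, Vector],
    reference_feats: Sequence[Vector],
    used_indices: Sequence[int],
    threshold: float = 0.0,
    num_bg_points: int = 1,
) -> Tuple[Dict[int, List[Tuple[Position, float]]], Dict[int, List[Position]]]:
    """Hypothetical: select foreground/background candidates once per used label."""
    fg_by_label: Dict[int, List[Tuple[Position, float]]] = {}
    bg_by_label: Dict[int, List[Position]] = {}
    # Only labels listed in used_indices are processed; other reference slots are ignored.
    for label in used_indices:
        ref = reference_feats[label]
        # Score every location against this label's reference feature, best first.
        scored = sorted(
            ((pos, _cosine(vec, ref)) for pos, vec in target_vectors.items()),
            key=lambda item: item[1],
            reverse=True,
        )
        fg_by_label[label] = [(pos, s) for pos, s in scored if s >= threshold]
        # The least similar locations serve as background candidates.
        bg_by_label[label] = [pos for pos, _ in reversed(scored)][:num_bg_points]
    return fg_by_label, bg_by_label
```

Keying the results by label keeps unused reference slots (for example, zero-filled ones) out of the output entirely.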
- class otx.algorithms.visual_prompting.adapters.pytorch_lightning.models.visual_prompters.zero_shot_segment_anything.ZeroShotSegmentAnything(config: DictConfig | None = None, manual_config_update: Dict | None = None, state_dict: OrderedDict | None = None)
Bases: SegmentAnything
Zero-shot learning module using Segment Anything.
- expand_reference_info(new_largest_label: int) → None
Expand the reference info dimensions if newly given processed prompts contain more labels.
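A minimal sketch of what expanding reference info can mean, assuming reference features are stored as one row per label index and that new rows are zero-filled placeholders until reference features are learned for them. The function name `expand_reference_feats`, the list-of-rows storage, and the zero-fill policy are assumptions; the documented method returns None and updates the module's own reference info instead.

```python
from typing import List


def expand_reference_feats(
    reference_feats: List[List[float]], new_largest_label: int, feat_dim: int
) -> List[List[float]]:
    """Hypothetical: grow a per-label table so that index new_largest_label is valid."""
    missing = new_largest_label + 1 - len(reference_feats)
    if missing <= 0:
        # Already large enough; keep the existing table unchanged.
        return reference_feats
    # Append zero-filled rows for the new labels; existing rows are untouched.
    return reference_feats + [[0.0] * feat_dim for _ in range(missing)]
```

Padding rather than reallocating keeps previously learned label indices stable.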
- infer(batch: List[Dict[str, Any]], reference_feats: ndarray | Tensor, used_indices: ndarray | Tensor, is_cascade: bool = True) → List[List[DefaultDict[int, List[Tensor]]]]
Zero-shot inference with reference features.
Get target results by using reference features and target images’ features.
- Parameters:
batch (List[Dict[str, Any]]) – List of dictionaries containing images and metas.
reference_feats (Union[np.ndarray, Tensor]) – Reference features for target prediction. If it is an np.ndarray, it is converted to a torch tensor.
used_indices (Union[np.ndarray, Tensor]) – Indices of the reference features that are valid. If it is an np.ndarray, it is converted to a torch tensor.
is_cascade (bool) – Whether to use cascade inference. Defaults to True.
- Returns:
Target results, wrapped in nested lists in this order:
outer list: one entry per target image;
inner list: the predicted masks and the used points obtained by point selection.
- Return type:
(List[List[DefaultDict[int, List[Tensor]]]])
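The nested return value can be walked as shown below. This hypothetical helper assumes each inner list is exactly `[predicted_masks, used_points]`, both `DefaultDict[int, List[...]]` keyed by label, as the return description suggests; plain strings stand in for the mask and point tensors.

```python
from collections import defaultdict
from typing import Any, DefaultDict, Dict, List

# Stand-in for List[List[DefaultDict[int, List[Tensor]]]].
Results = List[List[DefaultDict[int, List[Any]]]]


def count_masks_per_label(results: Results) -> List[Dict[int, int]]:
    """Hypothetical: number of predicted masks per label, for each target image."""
    counts = []
    # Outer list: target images; inner list: [predicted_masks, used_points].
    for predicted_masks, _used_points in results:
        counts.append({label: len(masks) for label, masks in predicted_masks.items()})
    return counts


masks: DefaultDict[int, List[Any]] = defaultdict(list)
points: DefaultDict[int, List[Any]] = defaultdict(list)
masks[1].extend(["mask_a", "mask_b"])
points[1].extend(["point_a", "point_b"])
print(count_masks_per_label([[masks, points], [defaultdict(list), defaultdict(list)]]))
```

This prints `[{1: 2}, {}]`: the first image has two masks for label 1, and the second image has no predictions.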
- learn(batch: List[Dict[str, Any]], reset_feat: bool = False, is_cascade: bool = False) → None | Tuple[ParameterDict, Tensor]
Get reference features.
Using the given images, get reference features and save them to PromptGetter. These reference features are used by infer to get target results. Currently, only a single batch is supported.
- Parameters:
batch (List[Dict[str, Any]]) – List of dictionaries containing images, prompts, and metas. The batch must contain images and prompts with bboxes, points, annotations, and polygons.
reset_feat (bool) – Whether to reset reference_info. In standalone OTX, reference_info is reset in on_train_start; in other frameworks, set this to True to reset reference_info. Defaults to False.
is_cascade (bool) – Whether to use cascade inference. Defaults to False.
- Returns:
reference_info and ref_masks.
- Return type:
(Tuple[ParameterDict, Tensor])
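One common way to obtain a reference feature from a reference mask is masked average pooling: average the image embedding over the masked locations and L2-normalize the result, as in PerSAM-style methods. The sketch below assumes that technique, a dict-based store, and the hypothetical class name `ReferenceStore`; none of these are taken from OTX's code. It also mimics `reset_feat` by clearing previously learned features first.

```python
import math
from typing import Dict, List, Sequence


class ReferenceStore:
    """Hypothetical per-label reference feature store (not OTX's reference_info)."""

    def __init__(self) -> None:
        self.reference_feats: Dict[int, List[float]] = {}

    def learn(
        self,
        embedding: Sequence[Sequence[Sequence[float]]],  # H x W grid of C-dim vectors
        mask: Sequence[Sequence[int]],  # H x W binary reference mask
        label: int,
        reset_feat: bool = False,
    ) -> None:
        if reset_feat:
            # Drop features learned earlier, mirroring the reset_feat flag.
            self.reference_feats.clear()
        # Collect embedding vectors at locations covered by the mask.
        selected = [
            vec
            for vec_row, mask_row in zip(embedding, mask)
            for vec, m in zip(vec_row, mask_row)
            if m
        ]
        if not selected:
            return  # empty mask: nothing to learn for this label
        dim = len(selected[0])
        mean = [sum(vec[i] for vec in selected) / len(selected) for i in range(dim)]
        # L2-normalize so later dot products behave like cosine similarity.
        norm = math.sqrt(sum(x * x for x in mean))
        self.reference_feats[label] = [x / norm for x in mean] if norm else mean
```

Normalizing at learn time means inference only needs a dot product against normalized target features.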
- load_state_dict_pre_hook(state_dict: Dict[str, Any], prefix: str = '', *args, **kwargs) → None
Load reference info manually.
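A rough sketch of loading reference info manually before regular state-dict loading. The motivation assumed here is that the stored reference info may have a different number of labels than a freshly built module, so its entries are assigned directly rather than copied into fixed-shape buffers. The key layout `<prefix>reference_info.<name>` and the function name `load_reference_info_pre_hook` are assumptions for illustration; plain dicts stand in for the module and its tensors.

```python
from typing import Any, Dict


def load_reference_info_pre_hook(
    module_reference_info: Dict[str, Any],
    state_dict: Dict[str, Any],
    prefix: str = "",
) -> None:
    """Hypothetical: copy reference-info entries from a checkpoint into the module."""
    key_prefix = prefix + "reference_info."
    for key, value in state_dict.items():
        if key.startswith(key_prefix):
            # Assign the stored value as-is, whatever its size.
            module_reference_info[key[len(key_prefix):]] = value
```

Assigning before the normal load means the module already holds entries of matching size when the regular loading step runs.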