otx.algo.visual_prompting#

Module for OTX visual prompting models.

Classes

SAM(backbone_type[, label_info, ...])

OTX visual prompting model class for Segment Anything Model (SAM).

ZeroShotSAM(backbone_type[, label_info, ...])

Zero-Shot Visual Prompting model.

class otx.algo.visual_prompting.SAM(backbone_type: Literal['tiny_vit', 'vit_b'], label_info: LabelInfoTypes = NullLabelInfo(label_names=[], label_ids=[], label_groups=[[]]), input_size: tuple[int, int] = (1024, 1024), optimizer: OptimizerCallable = <function _default_optimizer_callable>, scheduler: LRSchedulerCallable | LRSchedulerListCallable = <function _default_scheduler_callable>, metric: MetricCallable = <function _visual_prompting_metric_callable>, torch_compile: bool = False, freeze_image_encoder: bool = True, freeze_prompt_encoder: bool = True, freeze_mask_decoder: bool = False, use_stability_score: bool = False, return_single_mask: bool = True, return_extra_metrics: bool = False, stability_score_offset: float = 1.0)[source]#

Bases: CommonSettingMixin, OTXVisualPromptingModel

OTX visual prompting model class for Segment Anything Model (SAM).
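
A minimal construction sketch. Both calls below use only arguments shown in the signature above; the second spells out the default freezing flags (frozen image and prompt encoders, trainable mask decoder):

    from otx.algo.visual_prompting import SAM

    # Lightweight TinyViT backbone; everything else falls back to the
    # defaults shown in the signature above.
    model = SAM(backbone_type="tiny_vit")

    # Equivalent to the defaults: only the mask decoder is trained.
    model = SAM(
        backbone_type="vit_b",
        freeze_image_encoder=True,
        freeze_prompt_encoder=True,
        freeze_mask_decoder=False,
    )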

class otx.algo.visual_prompting.ZeroShotSAM(backbone_type: Literal['tiny_vit', 'vit_b'], label_info: LabelInfoTypes = NullLabelInfo(label_names=[], label_ids=[], label_groups=[[]]), optimizer: OptimizerCallable = <function _default_optimizer_callable>, scheduler: LRSchedulerCallable | LRSchedulerListCallable = <function _default_scheduler_callable>, metric: MetricCallable = <function _visual_prompting_metric_callable>, torch_compile: bool = False, reference_info_dir: Path | str = 'reference_infos', infer_reference_info_root: Path | str = '../.latest/train', save_outputs: bool = True, pixel_mean: list[float] | None = [123.675, 116.28, 103.53], pixel_std: list[float] | None = [58.395, 57.12, 57.375], freeze_image_encoder: bool = True, freeze_prompt_encoder: bool = True, freeze_mask_decoder: bool = True, default_threshold_reference: float = 0.3, default_threshold_target: float = 0.65, use_stability_score: bool = False, return_single_mask: bool = False, return_extra_metrics: bool = False, stability_score_offset: float = 1.0)[source]#

Bases: CommonSettingMixin, OTXZeroShotVisualPromptingModel

Zero-Shot Visual Prompting model.
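
A minimal construction sketch, again using only constructor arguments from the signature above (the values shown are the documented defaults):

    from otx.algo.visual_prompting import ZeroShotSAM

    # Zero-shot variant: all three SAM components are frozen by default,
    # and reference info learned from prompted images is cached on disk.
    model = ZeroShotSAM(
        backbone_type="tiny_vit",
        reference_info_dir="reference_infos",
        default_threshold_reference=0.3,
        default_threshold_target=0.65,
    )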

apply_boxes(boxes: BoundingBoxes, ori_shape: tuple[int, ...], target_length: int = 1024) BoundingBoxes[source]#

Preprocess boxes to be used in the model.

apply_coords(coords: Tensor, ori_shape: tuple[int, ...], target_length: int = 1024) Tensor[source]#

Preprocess coordinates to be used in the model.

apply_image(image: Image | np.ndarray, target_length: int = 1024) Image[source]#

Preprocess image to be used in the model.

apply_points(points: Points, ori_shape: tuple[int, ...], target_length: int = 1024) Points[source]#

Preprocess points to be used in the model.

apply_prompts(prompts: list[Type[BoundingBoxes | Points | Mask | Polygon]], ori_shape: tuple[int, ...], target_length: int = 1024) list[Type[BoundingBoxes | Points | Mask | Polygon]][source]#

Preprocess prompts to be used in the model.
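
A hedged usage sketch for the apply_* helpers above. Under SAM's usual longest-side-resize convention they are expected to rescale prompts from the original image frame into the resized model frame; the sample values are illustrative, and model is a constructed ZeroShotSAM instance:

    import torch

    # A point prompt in the original (height, width) = (768, 1024) frame.
    ori_shape = (768, 1024)
    coords = torch.tensor([[[200.0, 300.0]]])

    # Rescale into the frame of the image after its longest side has been
    # resized to target_length (1024 by default).
    scaled = model.apply_coords(coords, ori_shape, target_length=1024)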

forward(inputs: ZeroShotVisualPromptingBatchDataEntity) ZeroShotVisualPromptingBatchPredEntity | OTXBatchLossEntity[source]#

Model forward function.

get_preprocess_shape(oldh: int, oldw: int, target_length: int) tuple[int, int][source]#

Compute the (height, width) an image of size (oldh, oldw) takes after its longest side is resized to target_length.
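
A hedged worked example, assuming SAM's standard longest-side resize (scale = target_length / max(oldh, oldw), rounded to the nearest pixel):

    # A 768x1024 image already has a 1024-pixel longest side, so its
    # shape is unchanged.
    model.get_preprocess_shape(oldh=768, oldw=1024, target_length=1024)
    # -> (768, 1024)

    # A 1536x1024 image is scaled by 1024 / 1536 ≈ 0.667.
    model.get_preprocess_shape(oldh=1536, oldw=1024, target_length=1024)
    # -> (1024, 683)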

infer(inputs: ZeroShotVisualPromptingBatchDataEntity, reference_feats: Tensor | None = None, used_indices: Tensor | None = None, threshold: float = 0.0, num_bg_points: int = 1, is_cascade: bool = True) ZeroShotVisualPromptingBatchPredEntity | OTXBatchLossEntity[source]#

Run inference by calling the model directly.

initialize_reference_info() None[source]#

Initialize reference information.

learn(inputs: ZeroShotVisualPromptingBatchDataEntity, reference_feats: Tensor | None = None, used_indices: Tensor | None = None, reset_feat: bool = False, is_cascade: bool = False) ZeroShotVisualPromptingBatchPredEntity | OTXBatchLossEntity[source]#

Run the learning step by calling the model directly.
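
A hedged two-step sketch of the zero-shot workflow: learn builds reference features from a prompted batch, then infer segments new images against them. The batch variables, and the assumption that infer falls back to the stored reference info when reference_feats/used_indices are left as None, are not documented here:

    # train_batch / test_batch: ZeroShotVisualPromptingBatchDataEntity
    # instances, assumed to come from an OTX dataloader.
    out = model.learn(train_batch, reset_feat=True)

    # threshold and is_cascade shown at their documented defaults;
    # reference_feats / used_indices default to None (assumed to fall
    # back to the reference info stored during learn).
    preds = model.infer(test_batch, threshold=0.0, is_cascade=True)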

load_reference_info(default_root_dir: Path | str, device: str | device = 'cpu', path_to_directly_load: Path | None = None) bool[source]#

Load the latest reference info to use.

Parameters:
  • default_root_dir (Path | str) – Default root directory, used when an invalid infer_reference_info_root is given.

  • device (str | torch.device) – Device to which the reference info will be attached.

  • path_to_directly_load (Path | None) – Reference info path to load directly. Normally obtained after learn, which runs when inference is attempted without reference features in on_test_start or on_predict_start.

Returns:

Whether the checkpoint was loaded successfully.

Return type:

(bool)

preprocess(x: Image) Image[source]#

Normalize pixel values and pad to a square input.
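
A hedged re-implementation sketch of what preprocess() describes, using the constructor defaults for pixel_mean and pixel_std; preprocess_sketch is an illustrative stand-in, not the actual method:

    import torch
    import torch.nn.functional as F

    def preprocess_sketch(x: torch.Tensor, img_size: int = 1024) -> torch.Tensor:
        # Normalize channels with the documented pixel_mean / pixel_std.
        mean = torch.tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
        std = torch.tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)
        x = (x - mean) / std
        # Zero-pad the bottom/right edges out to an img_size square.
        h, w = x.shape[-2:]
        return F.pad(x, (0, img_size - w, 0, img_size - h))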

save_reference_info(default_root_dir: Path | str) None[source]#

Save reference info.
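
A hedged persistence sketch tying the reference-info methods together; the directory path is illustrative:

    # Save the current reference info under a root directory, then
    # restore the latest copy onto the CPU.
    model.save_reference_info("outputs/run_0")
    ok = model.load_reference_info("outputs/run_0", device="cpu")
    if not ok:
        # Fall back to an empty reference state on failure.
        model.initialize_reference_info()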

transforms(entity: ZeroShotVisualPromptingBatchDataEntity) ZeroShotVisualPromptingBatchDataEntity[source]#

Transforms for ZeroShotVisualPromptingBatchDataEntity.