otx.algo.visual_prompting#
Module for OTX visual prompting models.
Classes

- SAM: OTX visual prompting model class for Segment Anything Model (SAM).
- ZeroShotSAM: Zero-Shot Visual Prompting model.
- class otx.algo.visual_prompting.SAM(backbone_type: Literal['tiny_vit', 'vit_b'], label_info: LabelInfoTypes = NullLabelInfo(label_names=[], label_ids=[], label_groups=[[]]), input_size: tuple[int, int] = (1024, 1024), optimizer: OptimizerCallable = <function _default_optimizer_callable>, scheduler: LRSchedulerCallable | LRSchedulerListCallable = <function _default_scheduler_callable>, metric: MetricCallable = <function _visual_prompting_metric_callable>, torch_compile: bool = False, freeze_image_encoder: bool = True, freeze_prompt_encoder: bool = True, freeze_mask_decoder: bool = False, use_stability_score: bool = False, return_single_mask: bool = True, return_extra_metrics: bool = False, stability_score_offset: float = 1.0)[source]#
Bases: CommonSettingMixin, OTXVisualPromptingModel
OTX visual prompting model class for Segment Anything Model (SAM).
- class otx.algo.visual_prompting.ZeroShotSAM(backbone_type: Literal['tiny_vit', 'vit_b'], label_info: LabelInfoTypes = NullLabelInfo(label_names=[], label_ids=[], label_groups=[[]]), optimizer: OptimizerCallable = <function _default_optimizer_callable>, scheduler: LRSchedulerCallable | LRSchedulerListCallable = <function _default_scheduler_callable>, metric: MetricCallable = <function _visual_prompting_metric_callable>, torch_compile: bool = False, reference_info_dir: Path | str = 'reference_infos', infer_reference_info_root: Path | str = '../.latest/train', save_outputs: bool = True, pixel_mean: list[float] | None = [123.675, 116.28, 103.53], pixel_std: list[float] | None = [58.395, 57.12, 57.375], freeze_image_encoder: bool = True, freeze_prompt_encoder: bool = True, freeze_mask_decoder: bool = True, default_threshold_reference: float = 0.3, default_threshold_target: float = 0.65, use_stability_score: bool = False, return_single_mask: bool = False, return_extra_metrics: bool = False, stability_score_offset: float = 1.0)[source]#
Bases: CommonSettingMixin, OTXZeroShotVisualPromptingModel
Zero-Shot Visual Prompting model.
- apply_boxes(boxes: BoundingBoxes, ori_shape: tuple[int, ...], target_length: int = 1024) BoundingBoxes [source]#
Preprocess boxes to be used in the model.
- apply_coords(coords: Tensor, ori_shape: tuple[int, ...], target_length: int = 1024) Tensor [source]#
Preprocess points to be used in the model.
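The exact transform is not spelled out here; as a hedged sketch, assuming the standard SAM longest-side resize (the image is scaled so its longest side equals target_length, with rounding), apply_coords amounts to rescaling each point from the original frame to the resized frame:

```python
def apply_coords_sketch(
    coords: list[tuple[float, float]],
    ori_shape: tuple[int, int],
    target_length: int = 1024,
) -> list[tuple[float, float]]:
    """Rescale (x, y) points from the original image frame to the resized frame.

    Sketch only: assumes SAM-style longest-side resizing. The real
    apply_coords operates on torch.Tensor inputs, not plain tuples.
    """
    old_h, old_w = ori_shape
    scale = target_length / max(old_h, old_w)
    # Rounded resized shape, mirroring get_preprocess_shape.
    new_h, new_w = int(old_h * scale + 0.5), int(old_w * scale + 0.5)
    return [(x * (new_w / old_w), y * (new_h / old_h)) for x, y in coords]


# A point at (320, 240) in a 480x640 image maps to (512, 384) in the
# 768x1024 resized frame.
print(apply_coords_sketch([(320, 240)], (480, 640)))
```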
- apply_image(image: Image | np.ndarray, target_length: int = 1024) Image [source]#
Preprocess image to be used in the model.
- apply_points(points: Points, ori_shape: tuple[int, ...], target_length: int = 1024) Points [source]#
Preprocess points to be used in the model.
- apply_prompts(prompts: list[Type[BoundingBoxes | Points | Mask | Polygon]], ori_shape: tuple[int, ...], target_length: int = 1024) list[Type[BoundingBoxes | Points | Mask | Polygon]] [source]#
Preprocess prompts to be used in the model.
- forward(inputs: ZeroShotVisualPromptingBatchDataEntity) ZeroShotVisualPromptingBatchPredEntity | OTXBatchLossEntity [source]#
Model forward function.
- get_preprocess_shape(oldh: int, oldw: int, target_length: int) tuple[int, int] [source]#
Get preprocess shape.
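The docstring does not state the rule; the SAM reference implementation scales the longest side to target_length and rounds the other side proportionally. A minimal sketch, assuming OTX follows the same rule:

```python
def get_preprocess_shape_sketch(
    oldh: int, oldw: int, target_length: int
) -> tuple[int, int]:
    """Return (new_h, new_w) with the longest side scaled to target_length.

    Sketch of the SAM-style longest-side resize rule; assumed (not confirmed
    here) to match the documented method.
    """
    scale = target_length / max(oldh, oldw)
    # Round half up, as in the SAM reference code.
    return (int(oldh * scale + 0.5), int(oldw * scale + 0.5))


print(get_preprocess_shape_sketch(480, 640, 1024))  # landscape input
```

Note that aspect ratio is preserved, so the shorter side generally ends up below target_length; the encoder input is then padded to the square (1024, 1024).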
- infer(inputs: ZeroShotVisualPromptingBatchDataEntity, reference_feats: Tensor | None = None, used_indices: Tensor | None = None, threshold: float = 0.0, num_bg_points: int = 1, is_cascade: bool = True) ZeroShotVisualPromptingBatchPredEntity | OTXBatchLossEntity [source]#
Run inference by calling the model directly.
- learn(inputs: ZeroShotVisualPromptingBatchDataEntity, reference_feats: Tensor | None = None, used_indices: Tensor | None = None, reset_feat: bool = False, is_cascade: bool = False) ZeroShotVisualPromptingBatchPredEntity | OTXBatchLossEntity [source]#
Run learning by calling the model directly.
- load_reference_info(default_root_dir: Path | str, device: str | device = 'cpu', path_to_directly_load: Path | None = None) bool [source]#
Load the latest reference info for use.
- Parameters:
default_root_dir (Path | str) – Default root directory, used when an invalid infer_reference_info_root is given.
device (str | torch.device) – Device to which the reference info will be attached.
path_to_directly_load (Path | None) – Path to a reference info file to load directly. It is normally obtained after learn, which runs in on_test_start or on_predict_start when infer is attempted without reference features.
- Returns:
Whether the checkpoint was loaded successfully.
- Return type:
bool