Visual Prompting#

class model_api.models.visual_prompting.Prompt(data, label)#

Bases: NamedTuple

Create new instance of Prompt(data, label)

data: ndarray#

Alias for field number 0

label: int | ndarray#

Alias for field number 1
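
A Prompt simply pairs prompt data (a NumPy array holding a point, box, or polygon) with its label. A minimal construction sketch (the coordinate values below are purely illustrative):

    import numpy as np
    from model_api.models.visual_prompting import Prompt

    # A point prompt: XY coordinates with an integer class label
    point_prompt = Prompt(data=np.array([120, 80]), label=0)

    # A box prompt: XYXY coordinates (torchvision format) with its class label
    box_prompt = Prompt(data=np.array([10, 20, 200, 220]), label=1)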

class model_api.models.visual_prompting.SAMLearnableVisualPrompter(encoder_model, decoder_model, reference_features=None, threshold=0.65)#

Bases: object

A wrapper that provides the Zero-Shot Learning (ZSL) Visual Prompting workflow. To obtain segmentation results, one should run learn() first to obtain reference features, or use previously generated ones.

Initializes the ZSL pipeline.

Parameters:
  • encoder_model (SAMImageEncoder) – initialized encoder wrapper

  • decoder_model (SAMDecoder) – initialized decoder wrapper

  • reference_features (VisualPromptingFeatures | None, optional) – Previously generated reference features. Once the features are passed, one can skip learn() method, and start predicting masks right away. Defaults to None.

  • threshold (float, optional) – Threshold for matching against the reference features on infer(). A greater value means stricter matching. Defaults to 0.65.
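
A minimal construction sketch for the ZSL pipeline. It assumes the encoder and decoder wrappers (SAMImageEncoder and SAMDecoder instances) have already been created; how to instantiate them is outside the scope of this section:

    from model_api.models.visual_prompting import SAMLearnableVisualPrompter

    # encoder and decoder are pre-initialized SAMImageEncoder / SAMDecoder
    # wrappers (their construction is not covered here)
    prompter = SAMLearnableVisualPrompter(encoder, decoder, threshold=0.65)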

__call__(image, reference_features=None, apply_masks_refinement=True)#

A wrapper of the SAMLearnableVisualPrompter.infer() method

Return type:

ZSLVisualPromptingResult

has_reference_features()#

Checks if reference features are stored in the object state.

Return type:

bool

infer(image, reference_features=None, apply_masks_refinement=True)#

Obtains masks using already prepared reference features.

Reference features can be obtained with SAMLearnableVisualPrompter.learn() and passed as an argument. If the features are not passed, the instance's internal state is used as the source of features.

Parameters:
  • image (np.ndarray) – HWC-shaped image

  • reference_features (VisualPromptingFeatures | None, optional) – Reference features object obtained during previous learn() calls. If not passed, object internal state is used, which reflects the last learn() call. Defaults to None.

  • apply_masks_refinement (bool, optional) – Flag controlling the additional refinement stage on inference. Once enabled, the decoder is launched 2 extra times to refine the masks obtained with the first decoder call. Defaults to True.

Returns:

Mapping from label to predicted mask. Each mask object contains a list of binary masks and a list of related prompts. Each binary mask corresponds to one prompt point; a class mask can be obtained by applying the OR operation to all masks corresponding to one label.

Return type:

ZSLVisualPromptingResult
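
A usage sketch of infer(), assuming reference features were already stored by an earlier learn() call on the prompter instance constructed above; the image is a placeholder HWC array:

    import numpy as np

    query_image = np.zeros((480, 640, 3), dtype=np.uint8)  # placeholder HWC image
    result = prompter.infer(query_image, apply_masks_refinement=True)
    # result maps each label to a list of binary masks (one per prompt point);
    # applying a logical OR across a label's masks yields the full class mask.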

learn(image, boxes=None, points=None, polygons=None, reset_features=False)#

Executes the learn stage of the SAM ZSL pipeline.

Reference features are updated according to newly arrived prompts. Features corresponding to the same labels are overwritten during subsequent learn() calls.

Parameters:
  • image (np.ndarray) – HWC-shaped image

  • boxes (list[Prompt] | None, optional) – Prompts containing bounding boxes (in XYXY torchvision format) and their labels (ints, one per box). Defaults to None.

  • points (list[Prompt] | None, optional) – Prompts containing points (in XY format) and their labels (ints, one per point). Defaults to None.

  • polygons (list[Prompt] | None, optional) – Prompts containing polygons (a sequence of points in XY format) and their labels (ints, one per polygon). Polygon prompts are used to mask out the source features without involving the decoder. Defaults to None.

  • reset_features (bool, optional) – Forces learning from scratch. Defaults to False.

Returns:

The updated VPT reference features and the reference masks.

The shape of the reference mask is N_labels x H x W, where H and W are the same as in the input image.

Return type:

tuple[VisualPromptingFeatures, np.ndarray]
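
A minimal sketch of the learn stage, using the prompter instance constructed above; the image and prompt coordinates are purely illustrative:

    import numpy as np
    from model_api.models.visual_prompting import Prompt

    reference_image = np.zeros((480, 640, 3), dtype=np.uint8)  # placeholder HWC image
    boxes = [Prompt(np.array([15, 30, 140, 160]), label=0)]
    points = [Prompt(np.array([210, 95]), label=1)]

    features, masks = prompter.learn(reference_image, boxes=boxes, points=points)
    # features: VisualPromptingFeatures, reusable in later infer() calls
    # masks: np.ndarray of shape N_labels x H x W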

reset_reference_info()#

Resets (re-initializes) the reference information stored in the object state.

Return type:

None

property reference_features: VisualPromptingFeatures#

Reference features stored in the internal object state. An exception is raised if the property is accessed when the features are not present.
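
A sketch of reusing stored reference features; has_reference_features() guards the property access, and the saved features can seed a new prompter instance (encoder and decoder are the pre-initialized wrappers from the construction sketch above):

    if prompter.has_reference_features():
        saved_features = prompter.reference_features
        # A fresh prompter can skip learn() entirely by receiving the features
        # through the reference_features constructor argument.
        another_prompter = SAMLearnableVisualPrompter(
            encoder, decoder, reference_features=saved_features
        )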

class model_api.models.visual_prompting.SAMVisualPrompter(encoder_model, decoder_model)#

Bases: object

A wrapper that implements the SAM Visual Prompter.

Segmentation results can be obtained by calling the infer() method with corresponding parameters.

__call__(image, boxes=None, points=None)#

A wrapper of the SAMVisualPrompter.infer() method

Return type:

VisualPromptingResult

infer(image, boxes=None, points=None)#

Obtains segmentation masks using given prompts.

Parameters:
  • image (np.ndarray) – HWC-shaped image

  • boxes (list[Prompt] | None, optional) – Prompts containing bounding boxes (in XYXY torchvision format) and their labels (ints, one per box). Defaults to None.

  • points (list[Prompt] | None, optional) – Prompts containing points (in XY format) and their labels (ints, one per point). Defaults to None.

Returns:

Result object containing the predicted masks and auxiliary information.

Return type:

VisualPromptingResult
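
A minimal usage sketch of the non-learnable prompter, again assuming pre-initialized encoder and decoder wrappers; the image and coordinates are placeholders:

    import numpy as np
    from model_api.models.visual_prompting import Prompt, SAMVisualPrompter

    prompter = SAMVisualPrompter(encoder, decoder)
    image = np.zeros((480, 640, 3), dtype=np.uint8)  # placeholder HWC image

    result = prompter.infer(
        image,
        boxes=[Prompt(np.array([10, 10, 100, 120]), label=0)],
        points=[Prompt(np.array([55, 60]), label=0)],
    )
    # result is a VisualPromptingResult holding the predicted masks and auxiliary data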

class model_api.models.visual_prompting.VisualPromptingFeatures(feature_vectors, used_indices)#

Bases: NamedTuple

Create new instance of VisualPromptingFeatures(feature_vectors, used_indices)

feature_vectors: ndarray#

Alias for field number 0

used_indices: ndarray#

Alias for field number 1
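
Since VisualPromptingFeatures is a NamedTuple of NumPy arrays, it can be persisted and restored with standard NumPy I/O; a sketch assuming both fields are plain numeric arrays (the file name is illustrative):

    import numpy as np
    from model_api.models.visual_prompting import VisualPromptingFeatures

    # features is a VisualPromptingFeatures returned by learn()
    np.savez(
        "reference_features.npz",
        feature_vectors=features.feature_vectors,
        used_indices=features.used_indices,
    )

    loaded = np.load("reference_features.npz")
    restored = VisualPromptingFeatures(loaded["feature_vectors"], loaded["used_indices"])
    # restored can be passed to infer() or to the SAMLearnableVisualPrompter constructor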