# Copyright (C) 2023-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
"""Segment Anything model for the OTX visual prompting."""
from __future__ import annotations
import logging as log
import pickle # nosec B403 used pickle for dumping object
from collections import defaultdict
from copy import deepcopy
from pathlib import Path
from typing import TYPE_CHECKING, Any, ClassVar, Literal
import torch
import torchvision.transforms.v2 as tvt_v2
from torch import Tensor, nn
from torchvision.tv_tensors import BoundingBoxes, Image
from otx.algo.visual_prompting.decoders import SAMMaskDecoder
from otx.algo.visual_prompting.encoders import SAMImageEncoder, SAMPromptEncoder
from otx.algo.visual_prompting.losses.sam_loss import SAMCriterion
from otx.algo.visual_prompting.visual_prompters import SegmentAnything, ZeroShotSegmentAnything
from otx.core.data.entity.base import OTXBatchLossEntity, Points
from otx.core.data.entity.visual_prompting import (
ZeroShotPromptType,
ZeroShotVisualPromptingBatchDataEntity,
ZeroShotVisualPromptingBatchPredEntity,
)
from otx.core.metrics.visual_prompting import VisualPromptingMetricCallable
from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable
from otx.core.model.visual_prompting import OTXVisualPromptingModel, OTXZeroShotVisualPromptingModel
from otx.core.schedulers import LRSchedulerListCallable
from otx.core.types.label import LabelInfoTypes, NullLabelInfo
if TYPE_CHECKING:
import numpy as np
from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable
from otx.core.metrics import MetricCallable
class CommonSettingMixin:
"""Mixin class for common settings in SAM.
Attributes:
model (nn.Module): The model used in SAM.
load_from (ClassVar[dict[str, str]]): A dictionary containing the URLs to load checkpoints from.
"""
model: nn.Module
load_from: ClassVar[dict[str, str]] = {
"tiny_vit": "https://github.com/ChaoningZhang/MobileSAM/raw/master/weights/mobile_sam.pt",
"vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
"vit_l": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
"vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
}
def load_state_dict(
self,
state_dict: dict[str, Any] | None = None,
strict: bool = True,
assign: bool = False,
load_from: str | None = None,
) -> None:
"""Load checkpoint for SAM.
This method loads a pre-trained state dictionary for the SAM model. It can load from
a provided state dictionary or from a URL specified in the `load_from` parameter.
Args:
state_dict (dict[str, Any] | None, optional): The state dictionary to load.
Defaults to None.
strict (bool, optional): Whether to strictly enforce that the keys in state_dict
match the keys returned by this module's state_dict() function. Defaults to True.
            assign (bool, optional): Whether to assign tensors from the state dict to the module
                instead of copying them in place. Defaults to False.
load_from (str | None, optional): URL to load the checkpoint from. If provided,
this will be used instead of the state_dict argument. Defaults to None.
Raises:
            ValueError: If the checkpoint format is not supported by torch.hub.load_state_dict_from_url.
Note:
If loading from a URL, some keys are removed from the loaded state dictionary
and a 'model.' prefix is added to all remaining keys.
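
        Example:
            A minimal sketch (assuming `model` is an already constructed `SAM` or
            `ZeroShotSAM` instance and the `vit_b` checkpoint URL in `load_from` is reachable):

            >>> model.load_state_dict(load_from=model.load_from["vit_b"])  # doctest: +SKIP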
"""
try:
if load_from is not None:
_state_dict: dict[str, Any] = torch.hub.load_state_dict_from_url(str(load_from))
for key in [
"image_encoder.norm_head.weight",
"image_encoder.norm_head.bias",
"image_encoder.head.weight",
"image_encoder.head.bias",
]:
if key in _state_dict:
_state_dict.pop(key)
# add prefix 'model.' to all keys
for key in list(_state_dict.keys()):
_state_dict["model." + key] = _state_dict.pop(key)
state_dict = _state_dict
super().load_state_dict(state_dict, strict, assign) # type: ignore[misc]
except (ValueError, RuntimeError) as e:
log.info(
f"{e}: {load_from} is not desirable format for torch.hub.load_state_dict_from_url. "
f"To manually load {load_from}, try to set it to trainer.checkpoint.",
)
def freeze_networks(
self,
freeze_image_encoder: bool,
freeze_prompt_encoder: bool,
freeze_mask_decoder: bool,
) -> None:
"""Freeze networks depending on config.
Args:
freeze_image_encoder (bool): Whether to freeze the image encoder.
freeze_prompt_encoder (bool): Whether to freeze the prompt encoder.
freeze_mask_decoder (bool): Whether to freeze the mask decoder.
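
        Example:
            A typical SAM fine-tuning setup keeps both encoders frozen and trains only the
            mask decoder (these are the defaults `SAM.__init__` passes in; shown here as an
            illustrative direct call on a constructed model instance):

            >>> model.freeze_networks(
            ...     freeze_image_encoder=True,
            ...     freeze_prompt_encoder=True,
            ...     freeze_mask_decoder=False,
            ... )  # doctest: +SKIP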
"""
for param in self.model.image_encoder.parameters():
param.requires_grad = not freeze_image_encoder
for param in self.model.prompt_encoder.parameters():
param.requires_grad = not freeze_prompt_encoder
for param in self.model.mask_decoder.parameters():
param.requires_grad = not freeze_mask_decoder
@torch.no_grad()
def forward_for_tracing(
self,
image_embeddings: Tensor,
point_coords: Tensor,
point_labels: Tensor,
mask_input: Tensor,
has_mask_input: Tensor,
ori_shape: Tensor,
) -> tuple[Tensor, ...]:
"""Forward method for SAM inference (export/deploy).
Args:
image_embeddings (Tensor): The image embedding with a batch index of length 1.
If it is a zero tensor, the image embedding will be computed from the image.
point_coords (Tensor): Coordinates of sparse input prompts,
corresponding to both point inputs and box inputs.
Boxes are encoded using two points, one for the top-left corner and one for the bottom-right corner.
Coordinates must already be transformed to long-side 1024. Has a batch index of length 1.
point_labels (Tensor): Labels for the sparse input prompts.
0 is a negative input point, 1 is a positive input point,
2 is a top-left box corner, 3 is a bottom-right box corner, and -1 is a padding point.
If there is no box input, a single padding point with label -1 and
coordinates (0.0, 0.0) should be concatenated.
mask_input (Tensor): A mask input to the model with shape 1x1x256x256.
This must be supplied even if there is no mask input. In this case, it can just be zeros.
has_mask_input (Tensor): An indicator for the mask input.
1 indicates a mask input, 0 indicates no mask input.
                This input has shape 1x1 to support the OpenVINO input layout.
ori_shape (Tensor): The size of the input image in (H,W) format, before any transformation.
                This input has shape 1x2 to support the OpenVINO input layout.
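
        Example:
            A sketch of dummy inputs for tracing (assuming the default 1024x1024 input size,
            which yields 256-channel 64x64 image embeddings; the point coordinates, labels,
            and original shape below are illustrative):

            >>> outputs = model.forward_for_tracing(
            ...     image_embeddings=torch.zeros(1, 256, 64, 64),
            ...     point_coords=torch.tensor([[[512.0, 512.0], [0.0, 0.0]]]),
            ...     point_labels=torch.tensor([[1.0, -1.0]]),
            ...     mask_input=torch.zeros(1, 1, 256, 256),
            ...     has_mask_input=torch.zeros(1, 1),
            ...     ori_shape=torch.tensor([[720, 1280]]),
            ... )  # doctest: +SKIP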
"""
return self.model.forward_for_tracing(
image_embeddings=image_embeddings,
point_coords=point_coords,
point_labels=point_labels,
mask_input=mask_input,
has_mask_input=has_mask_input,
ori_shape=ori_shape,
)
class SAM(CommonSettingMixin, OTXVisualPromptingModel): # type: ignore[misc]
"""OTX visual prompting model class for Segment Anything Model (SAM)."""
input_size_multiplier = 16
def __init__(
self,
backbone_type: Literal["tiny_vit", "vit_b"],
label_info: LabelInfoTypes = NullLabelInfo(),
input_size: tuple[int, int] = (1024, 1024),
optimizer: OptimizerCallable = DefaultOptimizerCallable,
scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
metric: MetricCallable = VisualPromptingMetricCallable,
torch_compile: bool = False,
freeze_image_encoder: bool = True,
freeze_prompt_encoder: bool = True,
freeze_mask_decoder: bool = False,
use_stability_score: bool = False,
return_single_mask: bool = True,
return_extra_metrics: bool = False,
stability_score_offset: float = 1.0,
) -> None:
if input_size[0] != input_size[1]:
msg = f"SAM should use square image size, but got {input_size}"
raise ValueError(msg)
self.backbone_type = backbone_type
self.image_size = input_size[0]
self.image_embedding_size = input_size[0] // 16
self.use_stability_score = use_stability_score
self.return_single_mask = return_single_mask
self.return_extra_metrics = return_extra_metrics
self.stability_score_offset = stability_score_offset
super().__init__(
label_info=label_info,
input_size=input_size,
optimizer=optimizer,
scheduler=scheduler,
metric=metric,
torch_compile=torch_compile,
)
self.load_state_dict(load_from=self.load_from[backbone_type])
self.freeze_networks(freeze_image_encoder, freeze_prompt_encoder, freeze_mask_decoder)
def _build_model(self) -> nn.Module:
image_encoder = SAMImageEncoder(backbone_type=self.backbone_type, img_size=self.image_size)
prompt_encoder = SAMPromptEncoder(
image_embedding_size=(self.image_embedding_size, self.image_embedding_size),
input_image_size=(self.image_size, self.image_size),
)
mask_decoder = SAMMaskDecoder()
criterion = SAMCriterion(image_size=self.image_size)
return SegmentAnything(
image_encoder=image_encoder,
prompt_encoder=prompt_encoder,
mask_decoder=mask_decoder,
criterion=criterion,
image_size=self.image_size,
use_stability_score=self.use_stability_score,
return_single_mask=self.return_single_mask,
return_extra_metrics=self.return_extra_metrics,
stability_score_offset=self.stability_score_offset,
)
class ZeroShotSAM(CommonSettingMixin, OTXZeroShotVisualPromptingModel): # type: ignore[misc]
"""Zero-Shot Visual Prompting model."""
def __init__( # noqa: PLR0913
self,
backbone_type: Literal["tiny_vit", "vit_b"],
label_info: LabelInfoTypes = NullLabelInfo(),
optimizer: OptimizerCallable = DefaultOptimizerCallable,
scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
metric: MetricCallable = VisualPromptingMetricCallable,
torch_compile: bool = False,
reference_info_dir: Path | str = "reference_infos",
infer_reference_info_root: Path | str = "../.latest/train",
save_outputs: bool = True,
pixel_mean: list[float] | None = [123.675, 116.28, 103.53], # noqa: B006
pixel_std: list[float] | None = [58.395, 57.12, 57.375], # noqa: B006
freeze_image_encoder: bool = True,
freeze_prompt_encoder: bool = True,
freeze_mask_decoder: bool = True,
default_threshold_reference: float = 0.3,
default_threshold_target: float = 0.65,
use_stability_score: bool = False,
return_single_mask: bool = False,
return_extra_metrics: bool = False,
stability_score_offset: float = 1.0,
) -> None:
self.backbone_type = backbone_type
self.image_size = 1024 # zero-shot visual prompting model uses fixed 1024x1024 input size
self.image_embedding_size = 1024 // 16 # zero-shot visual prompting model uses fixed 1024x1024 input size
self.default_threshold_reference = default_threshold_reference
self.default_threshold_target = default_threshold_target
self.use_stability_score = use_stability_score
self.return_single_mask = return_single_mask
self.return_extra_metrics = return_extra_metrics
self.stability_score_offset = stability_score_offset
super().__init__(
label_info=label_info,
input_size=(1024, 1024), # zero-shot visual prompting model uses fixed 1024x1024 input size
optimizer=optimizer,
scheduler=scheduler,
metric=metric,
torch_compile=torch_compile,
)
# check freeze conditions
if not (freeze_image_encoder and freeze_prompt_encoder and freeze_mask_decoder):
log.warning(
"All of freeze_image_encoder, freeze_prompt_encoder, freeze_mask_decoder "
"must be set to True, changed.",
)
freeze_image_encoder = True
freeze_prompt_encoder = True
freeze_mask_decoder = True
self.load_state_dict(load_from=self.load_from[backbone_type])
self.freeze_networks(freeze_image_encoder, freeze_prompt_encoder, freeze_mask_decoder)
self.save_outputs = save_outputs
self.reference_info_dir: Path = Path(reference_info_dir)
self.infer_reference_info_root: Path = Path(infer_reference_info_root)
self.register_buffer("pixel_mean", Tensor(pixel_mean).view(-1, 1, 1), False)
self.register_buffer("pixel_std", Tensor(pixel_std).view(-1, 1, 1), False)
self.initialize_reference_info()
def _build_model(self) -> nn.Module:
image_encoder = SAMImageEncoder(backbone_type=self.backbone_type, img_size=self.image_size)
prompt_encoder = SAMPromptEncoder(
image_embedding_size=(self.image_embedding_size, self.image_embedding_size),
input_image_size=(self.image_size, self.image_size),
)
mask_decoder = SAMMaskDecoder()
criterion = None
return ZeroShotSegmentAnything(
image_encoder=image_encoder,
prompt_encoder=prompt_encoder,
mask_decoder=mask_decoder,
criterion=criterion,
image_size=self.image_size,
default_threshold_reference=self.default_threshold_reference,
default_threshold_target=self.default_threshold_target,
use_stability_score=self.use_stability_score,
return_single_mask=self.return_single_mask,
return_extra_metrics=self.return_extra_metrics,
stability_score_offset=self.stability_score_offset,
)
def forward( # type: ignore[override]
self,
inputs: ZeroShotVisualPromptingBatchDataEntity, # type: ignore[override]
) -> ZeroShotVisualPromptingBatchPredEntity | OTXBatchLossEntity:
"""Model forward function."""
forward_fn = self.learn if self.training else self.infer
return forward_fn(inputs) # type: ignore[operator]
def learn(
self,
inputs: ZeroShotVisualPromptingBatchDataEntity,
reference_feats: Tensor | None = None,
used_indices: Tensor | None = None,
reset_feat: bool = False,
is_cascade: bool = False,
) -> ZeroShotVisualPromptingBatchPredEntity | OTXBatchLossEntity:
"""Learn to directly connect to the model."""
self.training = True
if reset_feat:
self.initialize_reference_info()
outputs = self.model.learn(
**self._customize_inputs(inputs, reference_feats=reference_feats, used_indices=used_indices),
is_cascade=is_cascade,
)
return self._customize_outputs(outputs, inputs)
def infer(
self,
inputs: ZeroShotVisualPromptingBatchDataEntity,
reference_feats: Tensor | None = None,
used_indices: Tensor | None = None,
threshold: float = 0.0,
num_bg_points: int = 1,
is_cascade: bool = True,
) -> ZeroShotVisualPromptingBatchPredEntity | OTXBatchLossEntity:
"""Infer to directly connect to the model."""
self.training = False
outputs = self.model.infer(
**self._customize_inputs(inputs, reference_feats=reference_feats, used_indices=used_indices),
threshold=threshold,
num_bg_points=num_bg_points,
is_cascade=is_cascade,
)
return self._customize_outputs(outputs, inputs)
def _gather_prompts_with_labels(
self,
inputs: ZeroShotVisualPromptingBatchDataEntity,
) -> list[dict[int, list[ZeroShotPromptType]]]:
"""Gather prompts according to labels."""
total_processed_prompts: list[dict[int, list[ZeroShotPromptType]]] = []
for prompts, labels in zip(inputs.prompts, inputs.labels):
processed_prompts = defaultdict(list)
for prompt, label in zip(prompts, labels):
processed_prompts[int(label)].append(prompt)
sorted_processed_prompts = dict(sorted(processed_prompts.items(), key=lambda x: x))
total_processed_prompts.append(sorted_processed_prompts)
return total_processed_prompts
def apply_image(self, image: Image | np.ndarray, target_length: int = 1024) -> Image:
"""Preprocess image to be used in the model."""
h, w = image.shape[-2:]
target_size = self.get_preprocess_shape(h, w, target_length)
return tvt_v2.functional.resize(tvt_v2.functional.to_image(image), target_size, antialias=True)
def apply_coords(self, coords: Tensor, ori_shape: tuple[int, ...], target_length: int = 1024) -> Tensor:
"""Preprocess points to be used in the model."""
old_h, old_w = ori_shape
new_h, new_w = self.get_preprocess_shape(ori_shape[0], ori_shape[1], target_length)
coords = deepcopy(coords).to(torch.float32)
coords[..., 0] = coords[..., 0] * (new_w / old_w)
coords[..., 1] = coords[..., 1] * (new_h / old_h)
return coords
def apply_points(self, points: Points, ori_shape: tuple[int, ...], target_length: int = 1024) -> Points:
"""Preprocess points to be used in the model."""
return Points(self.apply_coords(points, ori_shape, target_length), canvas_size=(target_length, target_length))
def apply_boxes(self, boxes: BoundingBoxes, ori_shape: tuple[int, ...], target_length: int = 1024) -> BoundingBoxes:
"""Preprocess boxes to be used in the model."""
return BoundingBoxes(
self.apply_coords(boxes.reshape(-1, 2, 2), ori_shape, target_length).reshape(-1, 4),
format=boxes.format,
canvas_size=(target_length, target_length),
)
def apply_prompts(
self,
prompts: list[ZeroShotPromptType],
ori_shape: tuple[int, ...],
target_length: int = 1024,
) -> list[ZeroShotPromptType]:
"""Preprocess prompts to be used in the model."""
transformed_prompts: list[ZeroShotPromptType] = []
for prompt in prompts:
if isinstance(prompt, Points):
transformed_prompts.append(self.apply_points(prompt, ori_shape, target_length))
elif isinstance(prompt, BoundingBoxes):
transformed_prompts.append(self.apply_boxes(prompt, ori_shape, target_length))
else:
transformed_prompts.append(prompt)
return transformed_prompts
def get_preprocess_shape(self, oldh: int, oldw: int, target_length: int) -> tuple[int, int]:
"""Get preprocess shape."""
scale = target_length * 1.0 / max(oldh, oldw)
newh, neww = oldh * scale, oldw * scale
neww = int(neww + 0.5)
newh = int(newh + 0.5)
return (newh, neww)
def preprocess(self, x: Image) -> Image:
"""Normalize pixel values and pad to a square input."""
# Normalize colors
x = (x - self.pixel_mean) / self.pixel_std
# Pad
x = self.model.pad_to_square(x)
return Image(x)
def initialize_reference_info(self) -> None:
"""Initialize reference information."""
self.register_buffer("reference_feats", torch.zeros(0, 1, self.model.prompt_encoder.embed_dim), False)
self.register_buffer("used_indices", torch.tensor([], dtype=torch.int64), False)
def save_reference_info(self, default_root_dir: Path | str) -> None:
"""Save reference info."""
reference_info = {
"reference_feats": self.reference_feats,
"used_indices": self.used_indices,
}
# save reference info
self.saved_reference_info_path: Path = Path(default_root_dir) / self.reference_info_dir / "reference_info.pt"
self.saved_reference_info_path.parent.mkdir(parents=True, exist_ok=True)
# TODO (sungchul): ticket no. 139210
torch.save(reference_info, self.saved_reference_info_path)
        with self.saved_reference_info_path.with_suffix(".pickle").open("wb") as f:
            pickle.dump({k: v.numpy() for k, v in reference_info.items()}, f)
log.info(f"Saved reference info at {self.saved_reference_info_path}.")
def load_reference_info(
self,
default_root_dir: Path | str,
device: str | torch.device = "cpu",
path_to_directly_load: Path | None = None,
) -> bool:
"""Load latest reference info to be used.
Args:
            default_root_dir (Path | str): Default root directory, used when an invalid
                `infer_reference_info_root` is given.
            device (str | torch.device): Device to which the loaded reference info will be moved.
            path_to_directly_load (Path | None): Path to a reference info file to load directly.
                It is normally produced by `learn`, which runs in `on_test_start` or
                `on_predict_start` when `infer` is requested without reference features.
        Returns:
            (bool): Whether the reference info was loaded successfully.
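
        Example:
            A sketch with an illustrative root directory; with the default settings this looks
            for `<default_root_dir>/../.latest/train/reference_infos/reference_info.pt`:

            >>> ok = model.load_reference_info("./otx-workspace")  # doctest: +SKIP
            >>> # `ok` is True only if a reference_info.pt file was found and loaded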
"""
if path_to_directly_load is not None:
            # if `path_to_directly_load` is given, load it unconditionally
reference_info = torch.load(path_to_directly_load)
retval = True
log.info(f"reference info saved at {path_to_directly_load} was successfully loaded.")
else:
if str(self.infer_reference_info_root) == "../.latest/train":
# for default setting
path_reference_info = (
Path(default_root_dir)
/ self.infer_reference_info_root
/ self.reference_info_dir
/ "reference_info.pt"
)
else:
# for user input
path_reference_info = self.infer_reference_info_root / self.reference_info_dir / "reference_info.pt"
if path_reference_info.is_file():
reference_info = torch.load(path_reference_info)
retval = True
log.info(f"reference info saved at {path_reference_info} was successfully loaded.")
else:
reference_info = {}
retval = False
self.register_buffer(
"reference_feats",
reference_info.get("reference_feats", torch.zeros(0, 1, self.model.prompt_encoder.embed_dim)).to(device),
False,
)
self.register_buffer(
"used_indices",
reference_info.get("used_indices", torch.tensor([], dtype=torch.int64)).to(device),
False,
)
return retval