# Copyright (C) 2023-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
"""Segment Anything model for the OTX visual prompting."""
from __future__ import annotations
import logging as log
from typing import TYPE_CHECKING, Any, ClassVar, Literal
import torch
from torch import Tensor, nn
from otx.algo.visual_prompting.decoders import SAMMaskDecoder
from otx.algo.visual_prompting.encoders import SAMImageEncoder, SAMPromptEncoder
from otx.algo.visual_prompting.losses.sam_loss import SAMCriterion
from otx.algo.visual_prompting.visual_prompters import SegmentAnything
from otx.core.metrics.visual_prompting import VisualPromptingMetricCallable
from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable
from otx.core.model.visual_prompting import OTXVisualPromptingModel
from otx.core.schedulers import LRSchedulerListCallable
from otx.core.types.label import LabelInfoTypes, NullLabelInfo

if TYPE_CHECKING:
from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable
from otx.core.metrics import MetricCallable


class CommonSettingMixin:
    """Mixin class for common settings in SAM.

    Attributes:
        model (nn.Module): The model used in SAM.
        load_from (ClassVar[dict[str, str]]): A dictionary containing the URLs to load checkpoints from.
    """

model: nn.Module
load_from: ClassVar[dict[str, str]] = {
"tiny_vit": "https://github.com/ChaoningZhang/MobileSAM/raw/master/weights/mobile_sam.pt",
"vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
"vit_l": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
"vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
}

    def load_state_dict(
self,
state_dict: dict[str, Any] | None = None,
strict: bool = True,
assign: bool = False,
load_from: str | None = None,
) -> None:
"""Load checkpoint for SAM.
This method loads a pre-trained state dictionary for the SAM model. It can load from
a provided state dictionary or from a URL specified in the `load_from` parameter.
Args:
state_dict (dict[str, Any] | None, optional): The state dictionary to load.
Defaults to None.
strict (bool, optional): Whether to strictly enforce that the keys in state_dict
match the keys returned by this module's state_dict() function. Defaults to True.
            assign (bool, optional): Whether to assign items in the state dictionary to their
                corresponding keys in this module instead of copying them in place.
                Defaults to False.
load_from (str | None, optional): URL to load the checkpoint from. If provided,
this will be used instead of the state_dict argument. Defaults to None.
        Raises:
            ValueError: If the checkpoint format is not supported by
                torch.hub.load_state_dict_from_url. Note that this error is caught
                and logged rather than propagated to the caller.

        Note:
If loading from a URL, some keys are removed from the loaded state dictionary
and a 'model.' prefix is added to all remaining keys.
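
        Example:
            >>> # Illustrative only: load the registered ``vit_b`` SAM checkpoint on a
            >>> # ``SAM`` (or other CommonSettingMixin-based) instance; requires network access.
            >>> model.load_state_dict(load_from=CommonSettingMixin.load_from["vit_b"])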
"""
try:
if load_from is not None:
_state_dict: dict[str, Any] = torch.hub.load_state_dict_from_url(str(load_from))
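                # Drop classification-head weights (kept in some upstream TinyViT/MobileSAM
                # checkpoints) that the SAM image encoder used here does not define.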
for key in [
"image_encoder.norm_head.weight",
"image_encoder.norm_head.bias",
"image_encoder.head.weight",
"image_encoder.head.bias",
]:
if key in _state_dict:
_state_dict.pop(key)
# add prefix 'model.' to all keys
for key in list(_state_dict.keys()):
_state_dict["model." + key] = _state_dict.pop(key)
state_dict = _state_dict
super().load_state_dict(state_dict, strict, assign) # type: ignore[misc]
except (ValueError, RuntimeError) as e:
log.info(
f"{e}: {load_from} is not desirable format for torch.hub.load_state_dict_from_url. "
f"To manually load {load_from}, try to set it to trainer.checkpoint.",
)

    def freeze_networks(
self,
freeze_image_encoder: bool,
freeze_prompt_encoder: bool,
freeze_mask_decoder: bool,
) -> None:
"""Freeze networks depending on config.
Args:
freeze_image_encoder (bool): Whether to freeze the image encoder.
freeze_prompt_encoder (bool): Whether to freeze the prompt encoder.
freeze_mask_decoder (bool): Whether to freeze the mask decoder.
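
        Example:
            >>> # Illustrative only (matches the SAM class defaults): freeze both encoders
            >>> # and keep the mask decoder trainable; ``model`` is a ``SAM`` instance.
            >>> model.freeze_networks(
            ...     freeze_image_encoder=True,
            ...     freeze_prompt_encoder=True,
            ...     freeze_mask_decoder=False,
            ... )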
"""
for param in self.model.image_encoder.parameters():
param.requires_grad = not freeze_image_encoder
for param in self.model.prompt_encoder.parameters():
param.requires_grad = not freeze_prompt_encoder
for param in self.model.mask_decoder.parameters():
param.requires_grad = not freeze_mask_decoder

    @torch.no_grad()
def forward_for_tracing(
self,
image_embeddings: Tensor,
point_coords: Tensor,
point_labels: Tensor,
mask_input: Tensor,
has_mask_input: Tensor,
ori_shape: Tensor,
) -> tuple[Tensor, ...]:
"""Forward method for SAM inference (export/deploy).
Args:
image_embeddings (Tensor): The image embedding with a batch index of length 1.
If it is a zero tensor, the image embedding will be computed from the image.
point_coords (Tensor): Coordinates of sparse input prompts,
corresponding to both point inputs and box inputs.
Boxes are encoded using two points, one for the top-left corner and one for the bottom-right corner.
Coordinates must already be transformed to long-side 1024. Has a batch index of length 1.
point_labels (Tensor): Labels for the sparse input prompts.
0 is a negative input point, 1 is a positive input point,
2 is a top-left box corner, 3 is a bottom-right box corner, and -1 is a padding point.
If there is no box input, a single padding point with label -1 and
coordinates (0.0, 0.0) should be concatenated.
mask_input (Tensor): A mask input to the model with shape 1x1x256x256.
This must be supplied even if there is no mask input. In this case, it can just be zeros.
            has_mask_input (Tensor): An indicator for the mask input.
                1 indicates a mask input, 0 indicates no mask input.
                This input has shape 1x1 to match the expected OpenVINO input layout.
            ori_shape (Tensor): The size of the input image in (H, W) format, before any transformation.
                This input has shape 1x2 to match the expected OpenVINO input layout.
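
        Example:
            >>> # Illustrative dummy inputs for export/tracing; shapes follow the descriptions
            >>> # above and assume the default 1024x1024 input size (64x64 embedding grid).
            >>> # ``model`` is a ``SAM`` instance; values are placeholders only.
            >>> image_embeddings = torch.zeros(1, 256, 64, 64)
            >>> point_coords = torch.tensor([[[320.0, 240.0], [0.0, 0.0]]])  # one point + padding point
            >>> point_labels = torch.tensor([[1.0, -1.0]])  # positive point, then padding label
            >>> mask_input = torch.zeros(1, 1, 256, 256)
            >>> has_mask_input = torch.tensor([[0.0]])
            >>> ori_shape = torch.tensor([[480, 640]])
            >>> outputs = model.forward_for_tracing(
            ...     image_embeddings, point_coords, point_labels, mask_input, has_mask_input, ori_shape
            ... )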
"""
return self.model.forward_for_tracing(
image_embeddings=image_embeddings,
point_coords=point_coords,
point_labels=point_labels,
mask_input=mask_input,
has_mask_input=has_mask_input,
ori_shape=ori_shape,
)


class SAM(CommonSettingMixin, OTXVisualPromptingModel):  # type: ignore[misc]
    """OTX visual prompting model class for Segment Anything Model (SAM)."""

    input_size_multiplier = 16

def __init__(
self,
backbone_type: Literal["tiny_vit", "vit_b"],
label_info: LabelInfoTypes = NullLabelInfo(),
input_size: tuple[int, int] = (1024, 1024),
optimizer: OptimizerCallable = DefaultOptimizerCallable,
scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
metric: MetricCallable = VisualPromptingMetricCallable,
torch_compile: bool = False,
freeze_image_encoder: bool = True,
freeze_prompt_encoder: bool = True,
freeze_mask_decoder: bool = False,
use_stability_score: bool = False,
return_single_mask: bool = True,
return_extra_metrics: bool = False,
stability_score_offset: float = 1.0,
) -> None:
if input_size[0] != input_size[1]:
msg = f"SAM should use square image size, but got {input_size}"
raise ValueError(msg)
self.backbone_type = backbone_type
self.image_size = input_size[0]
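        # The SAM image encoder downsamples the input by a factor of 16,
        # e.g. a 1024x1024 image yields a 64x64 embedding grid.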
self.image_embedding_size = input_size[0] // 16
self.use_stability_score = use_stability_score
self.return_single_mask = return_single_mask
self.return_extra_metrics = return_extra_metrics
self.stability_score_offset = stability_score_offset
super().__init__(
label_info=label_info,
input_size=input_size,
optimizer=optimizer,
scheduler=scheduler,
metric=metric,
torch_compile=torch_compile,
)
self.load_state_dict(load_from=self.load_from[backbone_type])
self.freeze_networks(freeze_image_encoder, freeze_prompt_encoder, freeze_mask_decoder)

    def _build_model(self) -> nn.Module:
image_encoder = SAMImageEncoder(backbone_type=self.backbone_type, img_size=self.image_size)
prompt_encoder = SAMPromptEncoder(
image_embedding_size=(self.image_embedding_size, self.image_embedding_size),
input_image_size=(self.image_size, self.image_size),
)
mask_decoder = SAMMaskDecoder()
criterion = SAMCriterion(image_size=self.image_size)
return SegmentAnything(
image_encoder=image_encoder,
prompt_encoder=prompt_encoder,
mask_decoder=mask_decoder,
criterion=criterion,
image_size=self.image_size,
use_stability_score=self.use_stability_score,
return_single_mask=self.return_single_mask,
return_extra_metrics=self.return_extra_metrics,
stability_score_offset=self.stability_score_offset,
)
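

# Usage sketch (illustrative only, kept as comments so it is not executed on import):
# instantiating ``SAM`` downloads the checkpoint registered in ``load_from`` for the
# chosen backbone and, by default, freezes the image and prompt encoders while leaving
# the mask decoder trainable. Training itself is assumed to be driven by the OTX engine.
#
#     model = SAM(backbone_type="tiny_vit")
#     # model is then passed to the OTX training entry point (e.g. otx.engine.Engine).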