Source code for datumaro.plugins.sam_transforms.interpreters.sam_decoder_for_bbox

# Copyright (C) 2023 Intel Corporation
#
# SPDX-License-Identifier: MIT

from typing import List, Tuple

import numpy as np

from datumaro.components.abstracts import IModelInterpreter
from datumaro.components.abstracts.model_interpreter import LauncherInputType, ModelPred, PrepInfo
from datumaro.components.annotation import Annotation, Bbox, FeatureVector, Mask
from datumaro.components.dataset_base import DatasetItem
from datumaro.components.media import Image


class SAMDecoderForBboxInterpreter(IModelInterpreter):
    h_model = 1024
    w_model = 1024
    onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
    onnx_has_mask_input = np.zeros(1, dtype=np.float32)
    mask_threshold = 0.0

    def preprocess(self, inp: DatasetItem) -> Tuple[LauncherInputType, PrepInfo]:
        img = inp.media_as(Image).data
        img_embed = inp.annotations[-1]
        assert isinstance(
            img_embed, FeatureVector
        ), "annotations should have the image embedding vector as FeatureVector."

        h_img, w_img = img.shape[:2]
        # Scale factor mapping original image coordinates to the model input size.
        scale = min(self.h_model / h_img, self.w_model / w_img)

        bboxes = [ann for ann in inp.annotations if isinstance(ann, Bbox)]

        # Encode each box as a pair of corner points: (x1, y1) and (x2, y2).
        onnx_coord = scale * np.array(
            [(bbox.x, bbox.y, bbox.x + bbox.w, bbox.y + bbox.h) for bbox in bboxes],
            dtype=np.float32,
        ).reshape(-1, 2, 2)
        # Point labels 2 and 3 mark the box's top-left and bottom-right corners
        # in SAM's prompt encoding.
        onnx_label = np.array([2, 3] * len(onnx_coord)).reshape(-1, 2).astype(np.float32)

        decoder_inputs = {
            "image_embeddings": img_embed.vector[None, :, :, :],
            "point_coords": onnx_coord,
            "point_labels": onnx_label,
            "mask_input": self.onnx_mask_input,
            "has_mask_input": self.onnx_has_mask_input,
            "orig_im_size": np.array(img.shape[:2], dtype=np.float32),
        }

        return decoder_inputs, bboxes
    def postprocess(self, pred: ModelPred, info: PrepInfo) -> List[Annotation]:
        """Postprocess the outputs of the SAM decoder to generate masks
        associated with bounding boxes.

        Parameters:
            pred: List of dictionaries containing model predictions. Each
                dictionary should have a "masks" key holding the predicted
                mask, whose shape is (1, H, W).
            info: List of associated bounding boxes obtained from the
                preprocessing step.

        Returns:
            List of `Mask`s associated with the input bounding boxes. Each
            `Mask` contains the generated mask along with metadata from the
            corresponding `Bbox`: `id`, `group`, `object_id`, `label`,
            `z_order`, and `attributes`.
        """
        masks = [np.squeeze(p["masks"], axis=0) > self.mask_threshold for p in pred]
        bboxes: List[Bbox] = info  # Return from preprocess()

        return [
            Mask(
                image=mask,
                id=bbox.id,
                group=bbox.group,
                object_id=bbox.object_id,
                label=bbox.label,
                z_order=bbox.z_order,
                attributes=bbox.attributes,
            )
            for bbox, mask in zip(bboxes, masks)
        ]
    def get_categories(self):
        return None
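
A minimal usage sketch of the round trip through this interpreter. The (256, 64, 64) embedding shape, the `Image.from_numpy` and `FeatureVector(vector=...)` constructors, and the item layout are assumptions about this Datumaro version, and the ONNX decoder call is replaced by random logits so the snippet runs standalone; in practice the `decoder_inputs` dictionary would be fed to the SAM decoder session.

import numpy as np

from datumaro.components.annotation import Bbox, FeatureVector
from datumaro.components.dataset_base import DatasetItem
from datumaro.components.media import Image
from datumaro.plugins.sam_transforms.interpreters.sam_decoder_for_bbox import (
    SAMDecoderForBboxInterpreter,
)

interpreter = SAMDecoderForBboxInterpreter()

# One 480x640 image with a single box prompt. The image embedding must be
# the last annotation, since preprocess() reads inp.annotations[-1].
h_img, w_img = 480, 640
item = DatasetItem(
    id="sample",
    media=Image.from_numpy(np.zeros((h_img, w_img, 3), dtype=np.uint8)),
    annotations=[
        Bbox(100, 50, 200, 150, label=0),
        # Assumed SAM image embedding shape: (256, 64, 64).
        FeatureVector(vector=np.zeros((256, 64, 64), dtype=np.float32)),
    ],
)

decoder_inputs, bboxes = interpreter.preprocess(item)

# Stand-in for the ONNX decoder: one prediction per box, each holding a
# "masks" array of shape (1, H, W) in original-image coordinates.
pred = [{"masks": np.random.randn(1, h_img, w_img).astype(np.float32)} for _ in bboxes]

masks = interpreter.postprocess(pred, bboxes)
print(len(masks), masks[0].label)  # 1 0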