Source code for datumaro.components.algorithms.rise

# Copyright (C) 2019-2024 Intel Corporation
#
# SPDX-License-Identifier: MIT

# pylint: disable=unused-variable

import cv2
import numpy as np

from datumaro.components.dataset import Dataset
from datumaro.components.dataset_base import DatasetItem
from datumaro.components.media import Image
from datumaro.util import take_by

__all__ = ["RISE"]


class RISE:
    """
    Implements RISE: Randomized Input Sampling for Explanation of Black-box Models.
    See the paper for details: https://arxiv.org/pdf/1806.07421.pdf
    """

    def __init__(
        self,
        model,
        num_masks: int = 100,
        mask_size: int = 7,
        prob: float = 0.5,
        batch_size: int = 1,
    ):
        assert 0 <= prob <= 1
        self.model = model
        self.num_masks = num_masks
        self.mask_size = mask_size
        self.prob = prob
        self.batch_size = batch_size
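    # A high-level summary of the sampling loop implemented by the methods below:
    #   1. generate_masks() draws `num_masks` random binary grids of size
    #      `mask_size` x `mask_size`, upsamples them bilinearly, and crops them at a
    #      random offset to obtain smooth occlusion masks at the model input size.
    #   2. generate_masked_dataset() multiplies the input image by each mask and wraps
    #      the results into a Dataset so they can be fed to the model in batches.
    #   3. apply() runs the model on the masked images and accumulates, for every
    #      class, the masks weighted by that class's confidence score; the normalized
    #      sums are the per-class saliency maps.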
    def normalize_saliency(self, saliency):
        # Min-max normalize each per-class saliency map independently into [0, 1]
        normalized_saliency = np.empty_like(saliency)
        for idx, sal in enumerate(saliency):
            normalized_saliency[idx, ...] = (sal - np.min(sal)) / (np.max(sal) - np.min(sal))
        return normalized_saliency
    def generate_masks(self, image_size):
        # Build low-resolution random binary grids, then upsample and crop them at a
        # random shift to obtain smooth, randomly placed occlusion masks
        cell_size = np.ceil(np.array(image_size) / self.mask_size).astype(np.int32)
        up_size = tuple((self.mask_size + 1) * cs for cs in cell_size)

        grid = np.random.rand(self.num_masks, self.mask_size, self.mask_size) < self.prob
        grid = grid.astype("float32")

        masks = np.empty((self.num_masks, *image_size))
        for i in range(self.num_masks):
            # Random shifts
            x = np.random.randint(0, cell_size[0])
            y = np.random.randint(0, cell_size[1])
            # Linear upsampling and cropping
            masks[i, ...] = cv2.resize(grid[i], up_size, interpolation=cv2.INTER_LINEAR)[
                x : x + image_size[0], y : y + image_size[1]
            ]

        return masks
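    # For illustration (the numbers are an example, not fixed by the code above): with
    # image_size=(224, 224) and the default mask_size=7, cell_size is ceil(224 / 7) = 32,
    # so each 7x7 grid is upsampled to (7 + 1) * 32 = 256 pixels per side and a randomly
    # shifted 224x224 window is cropped from it. The one extra cell of margin is what
    # allows the random shift without running off the upsampled grid.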
    def generate_masked_dataset(self, image, image_size, masks):
        input_image = cv2.resize(image, image_size, interpolation=cv2.INTER_LINEAR)

        items = []
        for id, mask in enumerate(masks):
            # Element-wise multiplication keeps only the regions the mask lets through
            masked_image = np.expand_dims(mask, axis=-1) * input_image
            items.append(
                DatasetItem(
                    id=id,
                    media=Image.from_numpy(masked_image),
                )
            )
        return Dataset.from_iterable(items)
    def apply(self, image, progressive=False):
        assert len(image.shape) in [2, 3], "Expected an input image in (H, W, C) format"
        if len(image.shape) == 3:
            assert image.shape[2] in [3, 4], "Expected BGR or BGRA input"
        image = image[:, :, :3].astype(np.float32)

        model = self.model
        image_size = model.inputs[0].shape
        logit_size = model.outputs[0].shape

        batch_size = image_size[0]
        if image_size[1] in [1, 3]:  # NCHW layout
            image_size = (image_size[2], image_size[3])
        elif image_size[3] in [1, 3]:  # NHWC layout
            image_size = (image_size[1], image_size[2])

        masks = self.generate_masks(image_size=image_size)
        masked_dataset = self.generate_masked_dataset(image, image_size, masks)

        # Accumulate per-class saliency: each mask contributes in proportion to the
        # confidence the model assigns to the class on the corresponding masked image
        saliency = np.zeros((logit_size[1], *image_size), dtype=np.float32)
        for batch_id, batch in enumerate(take_by(masked_dataset, batch_size)):
            outputs = model.launch(batch)

            for sample_id in range(len(batch)):
                mask = masks[batch_size * batch_id + sample_id]
                for class_idx in range(logit_size[1]):
                    score = outputs[sample_id][class_idx].attributes["score"]
                    saliency[class_idx, ...] += score * mask

                # [TODO] wonjuleee: support DRISE for detection model explainability
                # if isinstance(self.target, Label):
                #     logits = outputs[sample_id][0].vector
                #     max_score = logits[self.target.label]
                # elif isinstance(self.target, Bbox):
                #     preds = outputs[sample_id][0]
                #     max_score = 0
                #     for box in preds:
                #         if box[0] == self.target.label:
                #             confidence, box = box[1], box[2]
                #             score = iou(self.target.get_bbox, box) * confidence
                #             if score > max_score:
                #                 max_score = score
                #     saliency += max_score * mask

            if progressive:
                yield self.normalize_saliency(saliency)

        yield self.normalize_saliency(saliency)
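A minimal usage sketch, assuming `model` is a Datumaro launcher that exposes `inputs`, `outputs`, and a `launch()` method as used in `apply()` above. The launcher construction (`my_launcher`) and the image path below are illustrative placeholders, not part of this module:

import cv2

from datumaro.components.algorithms.rise import RISE

# `my_launcher` stands for any launcher with .inputs, .outputs and .launch();
# its construction (e.g. an OpenVINO classification launcher) is omitted here.
rise = RISE(my_launcher, num_masks=500, mask_size=7, prob=0.5)

image = cv2.imread("sample.jpg")  # BGR image, as expected by apply()

# apply() is a generator; with progressive=False it yields a single array of
# shape (num_classes, H, W) with values normalized to [0, 1]
saliency = next(iter(rise.apply(image)))

heatmap = (saliency[0] * 255).astype("uint8")  # saliency map for class 0
cv2.imwrite("saliency_class0.png", cv2.applyColorMap(heatmap, cv2.COLORMAP_JET))

Passing progressive=True instead yields an intermediate (normalized) saliency estimate after every processed batch, which is useful for visualizing convergence as more masks are sampled.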