Source code for openvino_xai.methods.white_box.recipro_cam

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0


from abc import abstractmethod
from typing import Callable, List, Tuple

import numpy as np
import openvino.runtime as ov
from openvino.runtime import opset10 as opset

from openvino_xai.common.utils import IdentityPreprocessFN
from openvino_xai.inserter.model_parser import IRParserCls, ModelType
from openvino_xai.methods.white_box.base import WhiteBoxMethod
from openvino_xai.utils.torch import torch


class FeatureMapPerturbationBase(WhiteBoxMethod):
    """
    Base class for FeatureMapPerturbation-based methods.

    :param model: OpenVINO model.
    :type model: ov.Model
    :param preprocess_fn: Preprocessing function, identity function by default
        (assume input images are already preprocessed by user).
    :type preprocess_fn: Callable[[np.ndarray], np.ndarray]
    :param target_layer: Target layer (node) name after which the XAI branch will be inserted.
    :type target_layer: str
    :param embed_scaling: Whether to scale output or not.
    :type embed_scaling: bool
    :param device_name: Device type name.
    :type device_name: str
    """

    def __init__(
        self,
        model: ov.Model,
        preprocess_fn: Callable[[np.ndarray], np.ndarray] = IdentityPreprocessFN(),
        target_layer: str | None = None,
        embed_scaling: bool = True,
        device_name: str = "CPU",
    ):
        super().__init__(model=model, preprocess_fn=preprocess_fn, embed_scaling=embed_scaling, device_name=device_name)
        self.per_class = True
        self._target_layer = target_layer

    def generate_xai_branch(self) -> ov.Node:
        """Implements FeatureMapPerturbation-based XAI method."""
        model_clone = self._model_ori.clone()
        self._propagate_dynamic_batch_dimension(model_clone)

        saliency_maps = self._get_saliency_map(model_clone)

        if self.embed_scaling:
            saliency_maps = self._scale_saliency_maps(saliency_maps, self.per_class)
        return saliency_maps

    @abstractmethod
    def _get_saliency_map(self, model_clone: ov.Model):
        raise NotImplementedError
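``generate_xai_branch`` is a template method: it clones the original model, propagates a dynamic batch dimension, delegates construction of the saliency-map node to the subclass via ``_get_saliency_map``, and optionally scales the result. A hypothetical minimal subclass, sketched only to illustrate that contract (the class name and its body are illustrative and not part of the library):

class IdentitySaliency(FeatureMapPerturbationBase):  # hypothetical example, for illustration only
    """Toy subclass: exposes the (softmax) scores as a 1x1 per-class "saliency map"."""

    def _get_saliency_map(self, model_clone: ov.Model) -> ov.Node:
        # Assumes a single input image; the real methods below build h*w mosaic inputs instead.
        logit_node = IRParserCls.get_logit_node(model_clone, search_softmax=True)
        _, num_classes = logit_node.get_output_partial_shape(0)
        return opset.reshape(logit_node.output(0), (1, num_classes.get_length(), 1, 1), False)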
class ReciproCAM(FeatureMapPerturbationBase):
    """
    Implements Recipro-CAM for CNN models.

    :param model: OpenVINO model.
    :type model: ov.Model
    :param preprocess_fn: Preprocessing function, identity function by default
        (assume input images are already preprocessed by user).
    :type preprocess_fn: Callable[[np.ndarray], np.ndarray]
    :param target_layer: Target layer (node) name after which the XAI branch will be inserted.
    :type target_layer: str
    :param embed_scaling: Whether to scale output or not.
    :type embed_scaling: bool
    :param device_name: Device type name.
    :type device_name: str
    :param prepare_model: Loading (compiling) the model prior to inference.
    :type prepare_model: bool
    """

    def __new__(
        cls,
        model: ov.Model | torch.nn.Module | None = None,
        *args,
        **kwargs,
    ):
        if isinstance(model, torch.nn.Module):
            from .torch import TorchReciproCAM

            return TorchReciproCAM(model, *args, **kwargs)
        return super().__new__(cls)

    def __init__(
        self,
        model: ov.Model,
        preprocess_fn: Callable[[np.ndarray], np.ndarray] = IdentityPreprocessFN(),
        target_layer: str | None = None,
        embed_scaling: bool = True,
        device_name: str = "CPU",
        prepare_model: bool = True,
    ):
        super().__init__(
            model=model,
            preprocess_fn=preprocess_fn,
            target_layer=target_layer,
            embed_scaling=embed_scaling,
            device_name=device_name,
        )
        self.model_type = ModelType.CNN

        if prepare_model:
            self.prepare_model()

    def _get_saliency_map(self, model_clone: ov.Model) -> ov.Node:
        target_node_ori = IRParserCls.get_target_node(self._model_ori, self.model_type, self._target_layer)
        target_node_name = self._target_layer or target_node_ori.get_friendly_name()
        post_target_node_clone = IRParserCls.get_post_target_node(model_clone, self.model_type, target_node_name)

        logit_node = IRParserCls.get_logit_node(self._model_ori, search_softmax=True)
        logit_node_clone_model = IRParserCls.get_logit_node(model_clone, search_softmax=True)

        if not logit_node_clone_model.output(0).partial_shape[0].is_dynamic:
            raise ValueError(
                "Batch shape of the output should be dynamic, but it is static. "
                "Make sure that the dynamic inputs can propagate through the model graph."
            )

        _, c, h, w = target_node_ori.get_output_partial_shape(0)
        c, h, w = c.get_length(), h.get_length(), w.get_length()
        if not self._is_valid_layout(c, h, w):
            raise ValueError(f"ReciproCAM supports only NCHW layout, but got NHWC, with shape: [N, {c}, {h}, {w}]")

        feature_map_repeated = opset.tile(target_node_ori.output(0), (h * w, 1, 1, 1))
        mosaic_feature_map_mask = np.zeros((h * w, c, h, w), dtype=np.float32)
        tmp = np.arange(h * w)
        spacial_order = np.reshape(tmp, (h, w))
        for i in range(h):
            for j in range(w):
                k = spacial_order[i, j]
                mosaic_feature_map_mask[k, :, i, j] = np.ones((c))
        mosaic_feature_map_mask = opset.constant(mosaic_feature_map_mask)
        mosaic_feature_map = opset.multiply(feature_map_repeated, mosaic_feature_map_mask)

        for node in post_target_node_clone:
            node.input(0).replace_source_output(mosaic_feature_map.output(0))

        mosaic_prediction = logit_node_clone_model

        tmp = opset.transpose(mosaic_prediction.output(0), (1, 0))
        _, num_classes = logit_node.get_output_partial_shape(0)
        saliency_maps = opset.reshape(tmp, (1, num_classes.get_length(), h, w), False)
        return saliency_maps

    @staticmethod
    def _is_valid_layout(c: int, h: int, w: int):
        return h < c and w < c
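The XAI branch built above can be read as three steps: tile the target feature map h * w times, multiply it by a one-hot spatial mask so that the k-th mosaic feature map keeps all c channels at exactly one spatial position and zeroes the rest, then transpose and reshape the resulting h * w mosaic logits into a (1, num_classes, h, w) saliency map. A standalone NumPy sketch of the masking step (not part of the module; the toy sizes are assumptions):

import numpy as np

c, h, w = 3, 2, 2  # assumed toy channel count and spatial size
feature_map = np.random.rand(1, c, h, w).astype(np.float32)

# Same construction as the loop above: one mosaic copy per spatial cell.
mask = np.zeros((h * w, c, h, w), dtype=np.float32)
for k, (i, j) in enumerate(np.ndindex(h, w)):  # k = i * w + j, matching spacial_order
    mask[k, :, i, j] = 1.0

mosaic = np.tile(feature_map, (h * w, 1, 1, 1)) * mask
print(mosaic.shape)  # (h*w, c, h, w): each copy keeps exactly one spatial location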
class ViTReciproCAM(FeatureMapPerturbationBase):
    """
    Implements ViTRecipro-CAM for transformer models.

    :param model: OpenVINO model.
    :type model: ov.Model
    :param preprocess_fn: Preprocessing function, identity function by default
        (assume input images are already preprocessed by user).
    :type preprocess_fn: Callable[[np.ndarray], np.ndarray]
    :param target_layer: Target layer (node) name after which the XAI branch will be inserted.
    :type target_layer: str
    :param embed_scaling: Whether to scale output or not.
    :type embed_scaling: bool
    :param device_name: Device type name.
    :type device_name: str
    :param use_gaussian: Whether to use Gaussian for mask generation or not.
    :type use_gaussian: bool
    :param cls_token: Whether to use cls token for mosaic prediction or not.
    :type cls_token: bool
    :param final_norm: Whether the model has normalization after the last transformer block.
    :type final_norm: bool
    :param k: Count of the transformer block (from head) before which XAI branch will be inserted, 1-indexed.
    :type k: int
    :param prepare_model: Loading (compiling) the model prior to inference.
    :type prepare_model: bool
    """

    def __new__(
        cls,
        model: ov.Model | torch.nn.Module | None = None,
        *args,
        **kwargs,
    ):
        if isinstance(model, torch.nn.Module):
            from .torch import TorchViTReciproCAM

            return TorchViTReciproCAM(model, *args, **kwargs)
        return super().__new__(cls)

    def __init__(
        self,
        model: ov.Model,
        preprocess_fn: Callable[[np.ndarray], np.ndarray] = IdentityPreprocessFN(),
        target_layer: str | None = None,
        embed_scaling: bool = True,
        device_name: str = "CPU",
        use_gaussian: bool = True,
        cls_token: bool = True,
        final_norm: bool = True,
        k: int = 1,
        prepare_model: bool = True,
    ):
        super().__init__(
            model=model,
            preprocess_fn=preprocess_fn,
            target_layer=target_layer,
            embed_scaling=embed_scaling,
            device_name=device_name,
        )
        self.model_type = ModelType.TRANSFORMER
        self._use_gaussian = use_gaussian
        self._cls_token = cls_token

        # Count of target "Add" node (between the blocks), from the output, 1-indexed
        self._k = k * 2 + int(final_norm)

        if prepare_model:
            self.prepare_model()

    def _get_saliency_map(self, model_clone: ov.Model) -> ov.Node:
        #      Add        -> add node before the target transformer blocks
        #    ↓      ↓
        #  skip   block    -> skip connection to the next block and target block itself
        #    ↓      ↓
        #      Add        -> add node after the target transformer blocks

        # Get target Add node in-between the transformer blocks
        target_node_ori = IRParserCls.get_target_node(self._model_ori, self.model_type, self._target_layer, self._k)
        target_node_name = self._target_layer or target_node_ori.get_friendly_name()

        # Get post-add nodes and check them
        post_target_node_clone = IRParserCls.get_post_target_node(model_clone, self.model_type, target_node_name)
        self._post_add_node_check(post_target_node_clone)

        # Get logit nodes. Check them and retrieve info
        logit_node = IRParserCls.get_logit_node(self._model_ori, search_softmax=False)
        logit_node_clone = IRParserCls.get_logit_node(model_clone, search_softmax=False)
        if not logit_node_clone.output(0).partial_shape[0].is_dynamic:
            raise ValueError(
                "Batch shape of the output should be dynamic, but it is static. "
                "Make sure that the dynamic inputs can propagate through the model graph."
            )

        _, num_classes = logit_node.get_output_partial_shape(0)
        dim, h, w, num_aux_tokens = self._get_internal_size(target_node_ori)

        # Depth first search till the end of the LayerNorm (traverse of the block branch)
        post_target_node_ori = IRParserCls.get_post_target_node(self._model_ori, self.model_type, target_node_name)
        norm_node_ori = self._get_non_add_node_from_two_nodes(post_target_node_ori)
        while norm_node_ori.get_type_name() != "Add":
            if len(norm_node_ori.outputs()) > 1:
                raise ValueError("Number of normalization outputs > 1!")
            inputs = norm_node_ori.output(0).get_target_inputs()
            if len(inputs) > 1:
                raise ValueError("Number of normalization inputs > 1!")
            norm_node_ori = next(iter(inputs)).get_node()

        # Mosaic feature map after the LayerNorm
        post_target_node_clone_norm = IRParserCls.get_post_target_node(
            model_clone, self.model_type, norm_node_ori.get_friendly_name()
        )
        mosaic_feature_map_norm = self._get_mosaic_feature_map(norm_node_ori, dim, h, w, num_aux_tokens)
        for node in post_target_node_clone_norm:
            node.input(0).replace_source_output(mosaic_feature_map_norm.output(0))

        # Mosaic feature map after the Add node
        mosaic_feature_map = self._get_mosaic_feature_map(target_node_ori, dim, h, w, num_aux_tokens)
        add_node_clone = self._get_add_node_from_two_nodes(post_target_node_clone)
        add_node_clone.input(0).replace_source_output(mosaic_feature_map.output(0))

        # Transform mosaic predictions into the saliency map
        mosaic_prediction = logit_node_clone
        tmp = opset.transpose(mosaic_prediction.output(0), (1, 0))
        saliency_maps = opset.reshape(tmp, (1, num_classes.get_length(), h, w), False)
        return saliency_maps

    def _get_internal_size(self, target_node: ov.Node) -> Tuple[int, int, int, int]:
        _, token_number, dim = target_node.get_output_partial_shape(0)
        if token_number.is_dynamic or dim.is_dynamic:
            first_conv_node = IRParserCls.get_first_conv_node(self._model_ori)
            _, dim, h, w = first_conv_node.get_output_partial_shape(0)
            dim, h, w = dim.get_length(), h.get_length(), w.get_length()

            first_concat_node = IRParserCls.get_first_concat_node(self._model_ori)
            num_aux_tokens = len(first_concat_node.inputs()) - 1
        else:
            token_number, dim = token_number.get_length(), dim.get_length()
            h = w = int((token_number - 1) ** 0.5)
            num_aux_tokens = token_number - (h * w)
        return dim, h, w, num_aux_tokens

    def _get_add_node_from_two_nodes(self, node_list: List[ov.Node]) -> ov.Node:
        self._post_add_node_check(node_list)

        node1, node2 = node_list
        if node1.get_type_name() == "Add":
            return node1
        return node2

    def _get_non_add_node_from_two_nodes(self, node_list: List[ov.Node]) -> ov.Node:
        self._post_add_node_check(node_list)

        node1, node2 = node_list
        if node1.get_type_name() != "Add":
            return node1
        return node2

    @staticmethod
    def _post_add_node_check(node_list: List[ov.Node]):
        if len(node_list) != 2:
            raise ValueError(
                f"Only two outputs of the between block Add node supported, " f"but got {len(node_list)}."
            )
        node1, node2 = node_list
        if not (node1.get_type_name() == "Add") != (node2.get_type_name() == "Add"):
            raise ValueError(
                f"One (and only one) of the nodes has to be Add type. "
                f"But got {node1.get_type_name()} and {node2.get_type_name()}."
            )

    def _get_mosaic_feature_map(self, target_node_ori: ov.Node, dim: int, h: int, w: int, num_aux_tokens: int):
        if self._use_gaussian:
            if self._cls_token:
                cls_token = opset.slice(
                    target_node_ori, np.array([0]), np.array([num_aux_tokens]), np.array([1]), np.array([1])
                )
            else:
                cls_token = opset.constant(np.zeros((1, 1, dim)), dtype=np.float32)
            cls_token = opset.tile(cls_token.output(0), (h * w, 1, 1))

            target_node_ori_wo_cls_token = opset.slice(
                target_node_ori, np.array([1]), np.array([h * w + 1]), np.array([1]), np.array([1])
            )
            feature_map_spacial = opset.reshape(target_node_ori_wo_cls_token, (1, h, w, dim), False)
            feature_map_spacial_repeated = opset.tile(feature_map_spacial.output(0), (h * w, 1, 1, 1))

            tmp = np.arange(h * w)
            spacial_order = np.reshape(tmp, (h, w))
            gaussian = np.array(
                [[1 / 16.0, 1 / 8.0, 1 / 16.0], [1 / 8.0, 1 / 4.0, 1 / 8.0], [1 / 16.0, 1 / 8.0, 1 / 16.0]]
            )
            mosaic_feature_map_mask_padded = np.zeros((h * w, h + 2, w + 2), dtype=np.float32)
            for i in range(h):
                for j in range(w):
                    k = spacial_order[i, j]
                    i_pad = i + 1
                    j_pad = j + 1
                    mosaic_feature_map_mask_padded[k, i_pad - 1 : i_pad + 2, j_pad - 1 : j_pad + 2] = gaussian
            mosaic_feature_map_mask = mosaic_feature_map_mask_padded[:, 1:-1, 1:-1]
            mosaic_feature_map_mask = np.expand_dims(mosaic_feature_map_mask, 3)
            mosaic_feature_map_mask = opset.constant(mosaic_feature_map_mask)
            mosaic_feature_map_mask = opset.tile(mosaic_feature_map_mask.output(0), (1, 1, 1, dim))

            mosaic_fm_wo_cls_token = opset.multiply(feature_map_spacial_repeated, mosaic_feature_map_mask)
            mosaic_fm_wo_cls_token = opset.reshape(mosaic_fm_wo_cls_token, (h * w, h * w, dim), False)
            mosaic_feature_map = opset.concat([cls_token, mosaic_fm_wo_cls_token], 1)
        else:
            mosaic_feature_map_mask_wo_cls_token = np.zeros((h * w, h * w), dtype=np.float32)
            for i in range(h * w):
                mosaic_feature_map_mask_wo_cls_token[i, i] = 1
            if self._cls_token:
                cls_token_mask = np.ones((h * w, 1), dtype=np.float32)
            else:
                cls_token_mask = np.zeros((h * w, 1), dtype=np.float32)
            mosaic_feature_map_mask = np.hstack((cls_token_mask, mosaic_feature_map_mask_wo_cls_token))

            mosaic_feature_map_mask = np.expand_dims(mosaic_feature_map_mask, 2)
            mosaic_feature_map_mask = opset.constant(mosaic_feature_map_mask)
            mosaic_feature_map_mask = opset.tile(mosaic_feature_map_mask.output(0), (1, 1, dim))  # e.g. 784x785x768

            feature_map_repeated = opset.tile(target_node_ori.output(0), (h * w, 1, 1))  # e.g. 784x785x768
            mosaic_feature_map = opset.multiply(feature_map_repeated, mosaic_feature_map_mask)
        return mosaic_feature_map
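With use_gaussian=True, _get_mosaic_feature_map above builds h * w mosaic token sequences: the k-th sequence keeps a 3x3 Gaussian-weighted neighbourhood around spatial token k (plus the cls token, if used) and zeroes all other tokens. A standalone NumPy sketch of that mask construction (not part of the module; the toy grid size is an assumption):

import numpy as np

h, w = 4, 4  # assumed toy token grid (h * w spatial tokens)
gaussian = np.array(
    [[1 / 16.0, 1 / 8.0, 1 / 16.0], [1 / 8.0, 1 / 4.0, 1 / 8.0], [1 / 16.0, 1 / 8.0, 1 / 16.0]],
    dtype=np.float32,
)

# Write the kernel into a padded buffer so border cells receive a cropped kernel.
mask_padded = np.zeros((h * w, h + 2, w + 2), dtype=np.float32)
for k, (i, j) in enumerate(np.ndindex(h, w)):  # k = i * w + j, matching spacial_order
    mask_padded[k, i : i + 3, j : j + 3] = gaussian
mask = mask_padded[:, 1:-1, 1:-1]  # (h*w, h, w)

# Flatten the grid to tokens; the trailing axis broadcasts over the embedding dim.
mask = mask.reshape(h * w, h * w, 1)
print(mask.shape)  # (16, 16, 1)

The actual branch then multiplies a mask of this form with the tiled token embeddings and concatenates the cls token back in front before feeding the mosaic sequences to the remaining transformer blocks.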