Source code for otx.api.utils.dataset_utils

"""Dataset utils."""

# Copyright (C) 2021 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.

from typing import List, Optional, Tuple, Union

import numpy as np

from otx.api.entities.annotation import AnnotationSceneEntity
from otx.api.entities.dataset_item import DatasetItemEntity
from otx.api.entities.datasets import DatasetEntity
from otx.api.entities.label import LabelEntity
from otx.api.entities.model import ModelEntity
from otx.api.entities.result_media import ResultMediaEntity
from otx.api.entities.resultset import ResultSetEntity
from otx.api.entities.scored_label import ScoredLabel
from otx.api.entities.shapes.rectangle import Rectangle
from otx.api.utils.vis_utils import get_actmap


def get_fully_annotated_idx(dataset: DatasetEntity) -> List[int]:
    """Find the indices of the fully annotated items in a dataset.

    A dataset item is fully annotated if local annotations are available, or if the item has the `normal` label.

    Args:
        dataset (DatasetEntity): Dataset that may contain both partially and fully annotated items.

    Returns:
        List[int]: List of indices of the fully annotated dataset items.
    """
    local_idx = []
    for idx, gt_item in enumerate(dataset):
        local_annotations = [
            annotation for annotation in gt_item.get_annotations() if not Rectangle.is_full_box(annotation.shape)
        ]
        if not any(label.is_anomalous for label in gt_item.get_shapes_labels()) or len(local_annotations) > 0:
            local_idx.append(idx)
    return local_idx
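
# Usage sketch (illustrative only, not part of the module): `dataset` is a
# hypothetical DatasetEntity mixing image-level and box/pixel-level anomaly
# annotations; the returned indices can be used to index back into the dataset.
#
#   fully_annotated = get_fully_annotated_idx(dataset)
#   annotated_items = [dataset[i] for i in fully_annotated]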


def get_local_subset(
    dataset: DatasetEntity,
    fully_annotated_idx: Optional[List[int]] = None,
    include_normal: bool = True,
) -> DatasetEntity:
    """Extract a subset that contains only those dataset items that have local annotations.

    Args:
        dataset (DatasetEntity): Dataset from which we want to extract the locally annotated subset.
        fully_annotated_idx (Optional[List[int]]): The indices of the fully annotated dataset items. If not provided,
            the function will compute the indices before creating the subset.
        include_normal (bool): When true, global normal annotations will be included in the local dataset.

    Returns:
        DatasetEntity: Output dataset with only local annotations.
    """
    local_items = []
    if fully_annotated_idx is None:
        fully_annotated_idx = get_fully_annotated_idx(dataset)
    for idx in fully_annotated_idx:
        item = dataset[idx]

        local_annotations = [
            annotation for annotation in item.get_annotations() if not Rectangle.is_full_box(annotation.shape)
        ]
        # annotations with the normal label are considered local
        if include_normal:
            local_annotations.extend(
                [
                    annotation
                    for annotation in item.get_annotations()
                    if not any(label.label.is_anomalous for label in annotation.get_labels())
                ]
            )
        local_items.append(
            DatasetItemEntity(
                media=item.media,
                annotation_scene=AnnotationSceneEntity(
                    local_annotations,
                    kind=item.annotation_scene.kind,
                ),
                metadata=item.get_metadata(),
                subset=item.subset,
                roi=item.roi,
                ignored_labels=item.ignored_labels,
            )
        )
    return DatasetEntity(local_items, purpose=dataset.purpose)
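
# Usage sketch (illustrative, `dataset` is a hypothetical DatasetEntity): passing
# precomputed indices avoids re-scanning the dataset, and `include_normal=False`
# drops the full-box normal annotations from the returned subset.
#
#   idx = get_fully_annotated_idx(dataset)
#   local_only = get_local_subset(dataset, fully_annotated_idx=idx, include_normal=False)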


def get_global_subset(dataset: DatasetEntity) -> DatasetEntity:
    """Extract a subset that contains only the global annotations.

    Args:
        dataset (DatasetEntity): Dataset from which we want to extract the globally annotated subset.

    Returns:
        DatasetEntity: Output dataset with only global annotations.
    """
    global_items = []
    for item in dataset:
        global_annotations = [
            annotation for annotation in item.get_annotations() if Rectangle.is_full_box(annotation.shape)
        ]
        global_items.append(
            DatasetItemEntity(
                media=item.media,
                annotation_scene=AnnotationSceneEntity(global_annotations, kind=item.annotation_scene.kind),
                metadata=item.get_metadata(),
                subset=item.subset,
                roi=item.roi,
                ignored_labels=item.ignored_labels,
            )
        )
    return DatasetEntity(global_items, purpose=dataset.purpose)


def split_local_global_dataset(
    dataset: DatasetEntity,
) -> Tuple[DatasetEntity, DatasetEntity]:
    """Split a dataset into the globally and locally annotated subsets.

    Args:
        dataset (DatasetEntity): Input dataset.

    Returns:
        Tuple[DatasetEntity, DatasetEntity]: Tuple of the globally and locally annotated subsets.
    """
    global_dataset = get_global_subset(dataset)
    local_dataset = get_local_subset(dataset)
    return global_dataset, local_dataset
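
# Usage sketch (illustrative, `dataset` is a hypothetical DatasetEntity): the two
# subsets are typically evaluated separately, e.g. image-level metrics on
# `global_dataset` and box/pixel-level metrics on `local_dataset`.
#
#   global_dataset, local_dataset = split_local_global_dataset(dataset)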


def split_local_global_resultset(
    resultset: ResultSetEntity,
) -> Tuple[ResultSetEntity, ResultSetEntity]:
    """Split a resultset into the globally and locally annotated resultsets.

    Args:
        resultset (ResultSetEntity): Input resultset.

    Returns:
        Tuple[ResultSetEntity, ResultSetEntity]: Tuple of the globally and locally annotated resultsets.
    """
    global_gt_dataset = get_global_subset(resultset.ground_truth_dataset)
    local_gt_dataset = get_local_subset(resultset.ground_truth_dataset, include_normal=False)
    local_idx = get_fully_annotated_idx(resultset.ground_truth_dataset)
    global_pred_dataset = get_global_subset(resultset.prediction_dataset)
    local_pred_dataset = get_local_subset(resultset.prediction_dataset, local_idx, include_normal=False)

    global_resultset = ResultSetEntity(
        model=resultset.model,
        ground_truth_dataset=global_gt_dataset,
        prediction_dataset=global_pred_dataset,
        purpose=resultset.purpose,
    )
    local_resultset = ResultSetEntity(
        model=resultset.model,
        ground_truth_dataset=local_gt_dataset,
        prediction_dataset=local_pred_dataset,
        purpose=resultset.purpose,
    )
    return global_resultset, local_resultset
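
# Usage sketch (illustrative, `resultset` is a hypothetical ResultSetEntity): note
# that the prediction subset is selected with the indices of the fully annotated
# ground-truth items, so ground truth and predictions stay aligned item-by-item.
#
#   global_rs, local_rs = split_local_global_resultset(resultset)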


def contains_anomalous_images(dataset: DatasetEntity) -> bool:
    """Check if a dataset contains any items with the anomalous label.

    Args:
        dataset (DatasetEntity): Dataset to check for anomalous items.

    Returns:
        bool: True if the dataset contains anomalous items, False otherwise.
    """
    for item in dataset:
        labels = item.get_shapes_labels()
        if any(label.is_anomalous for label in labels):
            return True
    return False
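
# Usage sketch (illustrative): a typical guard before computing local metrics on a
# hypothetical validation split `val_dataset`.
#
#   if contains_anomalous_images(val_dataset):
#       ...  # compute anomaly-localization metrics
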
# pylint: disable=too-many-locals
def add_saliency_maps_to_dataset_item(
    dataset_item: DatasetItemEntity,
    saliency_map: Union[List[Optional[np.ndarray]], np.ndarray],
    model: Optional[ModelEntity],
    labels: List[LabelEntity],
    predicted_scored_labels: Optional[List[ScoredLabel]] = None,
    explain_predicted_classes: bool = True,
    process_saliency_maps: bool = False,
):
    """Add saliency maps (2D array for a class-agnostic saliency map, 3D array or list
    of 2D arrays for class-wise saliency maps) to a single dataset item."""
    if isinstance(saliency_map, list):
        class_wise_saliency_map = True
    elif isinstance(saliency_map, np.ndarray):
        if saliency_map.ndim == 2:
            class_wise_saliency_map = False
        elif saliency_map.ndim == 3:
            class_wise_saliency_map = True
        else:
            raise ValueError(f"Saliency map has to be a 2- or 3-dimensional array, but got {saliency_map.ndim} dims.")
    else:
        raise TypeError("Check saliency_map, it has to be a list or np.ndarray.")

    if class_wise_saliency_map:
        # Multiple saliency maps per image (class-wise saliency map), supports e.g. ReciproCAM
        if explain_predicted_classes:
            # Explain only the predicted classes
            if predicted_scored_labels is None:
                raise ValueError("To explain only predictions, a list of predicted scored labels has to be provided.")

            explain_targets = set()
            for scored_label in predicted_scored_labels:
                if scored_label.label is not None:  # Check for an empty label
                    explain_targets.add(scored_label.label)
        else:
            # Explain all classes
            explain_targets = set(labels)

        for class_id, class_saliency_map in enumerate(saliency_map):
            label = labels[class_id]
            if class_saliency_map is not None and label in explain_targets:
                if process_saliency_maps:
                    class_saliency_map = get_actmap(class_saliency_map, (dataset_item.width, dataset_item.height))
                saliency_media = ResultMediaEntity(
                    name=label.name,
                    type="saliency_map",
                    annotation_scene=dataset_item.annotation_scene,
                    numpy=class_saliency_map,
                    roi=dataset_item.roi,
                    label=label,
                )
                dataset_item.append_metadata_item(saliency_media, model=model)
    else:
        # Single saliency map per image, supports e.g. ActivationMap
        if process_saliency_maps:
            saliency_map = get_actmap(saliency_map, (dataset_item.width, dataset_item.height))
        saliency_media = ResultMediaEntity(
            name="Saliency Map",
            type="saliency_map",
            annotation_scene=dataset_item.annotation_scene,
            numpy=saliency_map,
            roi=dataset_item.roi,
        )
        dataset_item.append_metadata_item(saliency_media, model=model)
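
# Usage sketch (illustrative): attaching per-class ReciproCAM-style maps to an item.
# `item`, `model`, `labels`, and `predictions` are hypothetical objects; the saliency
# array must hold one 2D map per label, i.e. shape (len(labels), H, W).
#
#   saliency = np.random.rand(len(labels), 7, 7).astype(np.float32)  # dummy maps
#   add_saliency_maps_to_dataset_item(
#       dataset_item=item,
#       saliency_map=saliency,
#       model=model,
#       labels=labels,
#       predicted_scored_labels=predictions,
#       explain_predicted_classes=True,
#       process_saliency_maps=True,  # resize to the item's resolution via get_actmap
#   )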


def non_linear_normalization(saliency_map: np.ndarray) -> np.ndarray:
    """Use non-linear normalization y=x**1.5 for 2D saliency maps."""
    min_soft_score = np.min(saliency_map)
    # shift the distribution to be non-negative to perform non-linear normalization y=x**1.5
    saliency_map = (saliency_map - min_soft_score) ** 1.5

    max_soft_score = np.max(saliency_map)
    saliency_map = 255.0 / (max_soft_score + 1e-12) * saliency_map

    return np.uint8(np.floor(saliency_map))
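
# Worked sketch (illustrative): the map is shifted to start at 0, raised to the
# power 1.5 to suppress low scores relative to high ones, then rescaled to [0, 255].
#
#   >>> non_linear_normalization(np.array([[0.0, 1.0], [2.0, 4.0]]))
#   array([[  0,  31],
#          [ 90, 254]], dtype=uint8)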