Source code for datumaro.components.algorithms.hash_key_inference.explorer

# Copyright (C) 2023 Intel Corporation
#
# SPDX-License-Identifier: MIT

from typing import List, Optional, Sequence, Union

import numpy as np

from datumaro.components.algorithms.hash_key_inference.base import HashInference
from datumaro.components.algorithms.hash_key_inference.hashkey_util import (
    calculate_hamming,
    select_uninferenced_dataset,
)
from datumaro.components.annotation import HashKey
from datumaro.components.dataset import Dataset
from datumaro.components.dataset_base import DatasetItem
from datumaro.components.errors import DatumaroError, MediaTypeError


class Explorer(HashInference):
    def __init__(
        self,
        *datasets: Sequence[Dataset],
        topk: int = 10,
    ) -> None:
        """
        Explorer for Datumaro data items

        Parameters
        ----------
        datasets:
            Datumaro datasets to explore for similar data items.
        topk:
            Number of images to retrieve.
        """
        self._model = None
        self._text_model = None
        self._topk = topk

        database_keys = []
        item_list = []

        datasets_to_infer = [select_uninferenced_dataset(dataset) for dataset in datasets]
        datasets = self._compute_hash_key(datasets, datasets_to_infer)

        for dataset in datasets:
            for item in dataset:
                for annotation in item.annotations:
                    if isinstance(annotation, HashKey):
                        try:
                            hash_key = annotation.hash_key
                            hash_key = np.unpackbits(hash_key, axis=-1)
                            database_keys.append(hash_key)
                            item_list.append(item)
                        except Exception:
                            continue

        if all(i is None for i in database_keys):
            # media.data is None case
            raise ValueError("Database should have hash_key")

        self._database_keys = np.stack(database_keys, axis=0)
        self._item_list = item_list
    def explore_topk(
        self,
        query: Union[DatasetItem, str, List[Union[DatasetItem, str]]],
        topk: Optional[int] = None,
    ):
        """
        Explore the topk most similar results, ranked by Hamming distance, for the query
        (a `DatasetItem`, a text string, or a list of them).
        """
        if not topk:
            topk = self._topk
        database_keys = self._database_keys

        if isinstance(query, list):
            topk_for_query = topk if len(query) == 1 else (topk // len(query)) * 2

            query_hash_key_list = []
            result_list = []
            logits_list = []
            for query_ in query:
                if isinstance(query_, DatasetItem):
                    query_key = self._get_hash_key_from_item_query(query_)
                    query_hash_key_list.append(query_key)
                elif isinstance(query_, str):
                    query_key = self._get_hash_key_from_text_query(query_)
                    query_hash_key_list.append(query_key)
                else:
                    raise MediaTypeError(
                        "Unexpected media type of query '%s'. "
                        "Expected 'DatasetItem' or 'string', actual '%s'" % (query_, type(query_))
                    )

            for query_key in query_hash_key_list:
                unpacked_key = np.unpackbits(query_key.hash_key, axis=-1)
                logits = calculate_hamming(unpacked_key, database_keys)
                ind = np.argsort(logits)

                item_list = np.array(self._item_list)[ind]
                result_list.append(item_list[:topk_for_query].tolist())
                logits_list.append(logits[ind][:topk_for_query].tolist())

            result_list = np.stack(result_list, axis=0)
            logits_list = np.stack(logits_list, axis=0)

            flattened_indices = np.argsort(logits_list.ravel())
            sorted_list = result_list.ravel()[flattened_indices]

            return sorted_list[:topk]

        if isinstance(query, DatasetItem):
            query_key = self._get_hash_key_from_item_query(query)
        elif isinstance(query, str):
            query_key = self._get_hash_key_from_text_query(query)
        else:
            raise MediaTypeError(
                "Unexpected media type of query '%s'. "
                "Expected 'DatasetItem' or 'string', actual '%s'" % (query, type(query))
            )

        if not isinstance(query_key, HashKey):
            # media.data is None case
            raise ValueError("Query should have hash_key")

        unpacked_key = np.unpackbits(query_key.hash_key, axis=-1)
        logits = calculate_hamming(unpacked_key, database_keys)
        ind = np.argsort(logits)

        item_list = np.array(self._item_list)[ind]
        result = item_list[:topk].tolist()

        return result
    def _get_hash_key_from_item_query(self, query: DatasetItem) -> HashKey:
        """Get the hash key from the `DatasetItem`.
        If it does not exist, launch the model inference to obtain it.
        """
        query_keys_in_item = [
            annotation for annotation in query.annotations if isinstance(annotation, HashKey)
        ]
        if len(query_keys_in_item) > 1:
            raise DatumaroError(
                f"There is more than one HashKey ({query_keys_in_item}) "
                f"in the query item ({query}). It is ambiguous!"
            )
        if len(query_keys_in_item) == 1:
            return query_keys_in_item[0]

        return self._model.infer_item(query)

    def _get_hash_key_from_text_query(self, query: str) -> HashKey:
        return self.text_model.infer_text(query)
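
Usage sketch (not part of the module above): a minimal example of driving the Explorer, under assumed inputs. The dataset path, format string, and text query below are placeholders; any Datumaro-importable dataset works, since the constructor infers missing hash keys.

# Hypothetical usage example (illustrative only; path and format are placeholders).
from datumaro.components.algorithms.hash_key_inference.explorer import Explorer
from datumaro.components.dataset import Dataset

dataset = Dataset.import_from("path/to/dataset", "datumaro")
explorer = Explorer(dataset, topk=10)

# Text query: a hash key is inferred from the string, then the topk items
# with the smallest Hamming distance over unpacked hash keys are returned.
text_results = explorer.explore_topk("a photo of a red car", topk=5)

# DatasetItem query: reuses the item's HashKey annotation if present,
# otherwise runs model inference to obtain one.
query_item = next(iter(dataset))
item_results = explorer.explore_topk(query_item)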