Source code for datumaro.components.algorithms.hash_key_inference.prune

# Copyright (C) 2023 Intel Corporation
#
# SPDX-License-Identifier: MIT

import logging as log
import math
import random
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple

import numpy as np

from datumaro.components.algorithms.hash_key_inference.base import HashInference
from datumaro.components.algorithms.hash_key_inference.hashkey_util import (
    calculate_hamming,
    format_templates,
    select_uninferenced_dataset,
    templates,
)
from datumaro.components.annotation import HashKey, Label, LabelCategories
from datumaro.components.dataset import Dataset
from datumaro.components.dataset_base import DatasetItem

if TYPE_CHECKING:
    import datumaro.plugins.ndr as ndr
else:
    from datumaro.util.import_util import lazy_import

    ndr = lazy_import("datumaro.plugins.ndr")


def match_num_item_for_cluster(ratio, dataset_len, cluster_num_item_list):
    """Distribute the total number of items to select across clusters,
    proportionally to each cluster's size."""
    total_num_selected_item = math.ceil(dataset_len * ratio)

    cluster_weights = np.array(cluster_num_item_list) / sum(cluster_num_item_list)
    norm_cluster_num_item_list = (cluster_weights * total_num_selected_item).astype(int)
    remaining_items = total_num_selected_item - sum(norm_cluster_num_item_list)

    if remaining_items > 0:
        # Give the leftover slots to clusters whose quota rounded down to zero,
        # largest weight first.
        zero_cluster_indexes = np.where(norm_cluster_num_item_list == 0)[0]
        add_clust_dist = np.sort(cluster_weights[zero_cluster_indexes])[::-1][:remaining_items]
        for dist in set(add_clust_dist):
            indices = np.where(cluster_weights == dist)[0]
            for index in indices:
                norm_cluster_num_item_list[index] += 1
    elif remaining_items < 0:
        # Too many slots were assigned: trim the clusters with the smallest
        # gap between weight and assigned quota.
        diff_num_item_list = np.argsort(cluster_weights - norm_cluster_num_item_list)
        for diff_idx in diff_num_item_list[: abs(remaining_items)]:
            norm_cluster_num_item_list[diff_idx] -= 1

    return norm_cluster_num_item_list.tolist()
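
# Illustrative example (not part of the module): with ratio=0.5 on a 10-item
# dataset whose clusters have sizes [6, 3, 1], the smallest cluster first
# rounds down to a quota of zero and then receives the leftover slot, so the
# quotas sum to ceil(10 * 0.5) == 5.
#
#   >>> match_num_item_for_cluster(0.5, 10, [6, 3, 1])
#   [3, 1, 1]
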

class PruneBase(ABC):
    @abstractmethod
    def base(
        self,
        ratio: float,
        num_centers: Optional[int],
        labels: Optional[List[int]],
        database_keys: Optional[np.ndarray],
        item_list: List[DatasetItem],
        source: Optional[Dataset],
    ) -> Tuple[List[DatasetItem], Optional[Dict]]:
        """Executes the pruning method.

        Parameters:
            ratio: Fraction of the dataset to keep after pruning.
            num_centers: Number of centers for clustering.
            labels: Label of one annotation for each dataset item.
            database_keys: Batch of hash keys in NumPy format.
            item_list: List of the dataset's items.
            source: Whole dataset.
        Returns:
            A tuple of the selected items and, when the method computes them,
            a list of (cluster, item id, subset, distance) tuples.
        """
        raise NotImplementedError
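
# Illustrative sketch (hypothetical, not part of the module): a custom
# strategy only has to implement ``base`` with the signature above and return
# the selected items plus optional distance tuples.
class _TakeFirst(PruneBase):
    def base(self, ratio, num_centers, labels, database_keys, item_list, source):
        # Keep the first ceil(len * ratio) items, ignoring hash keys entirely.
        num_selected = math.ceil(len(item_list) * ratio)
        return item_list[:num_selected], None
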

class RandomSelect(PruneBase):
    """
    Select items randomly from the dataset.
    """

    def base(self, ratio, num_centers, labels, database_keys, item_list, source):
        random.seed(0)
        dataset_len = len(item_list)
        num_selected_item = math.ceil(dataset_len * ratio)
        random_indices = random.sample(range(dataset_len), num_selected_item)
        selected_items = [item_list[idx] for idx in random_indices]
        return selected_items, None

class Centroid(PruneBase):
    """
    Select items through clustering, using as many centers as the desired
    number of items and keeping the item closest to each center.
    """

    def base(self, ratio, num_centers, labels, database_keys, item_list, source):
        from sklearn.cluster import KMeans

        num_selected_centers = math.ceil(len(item_list) * ratio)
        kmeans = KMeans(n_clusters=num_selected_centers, random_state=0)
        clusters = kmeans.fit_predict(database_keys)
        cluster_centers = kmeans.cluster_centers_
        cluster_ids = np.unique(clusters)

        selected_items = []
        dist_tuples = []
        for cluster_id in cluster_ids:
            cluster_center = cluster_centers[cluster_id]
            cluster_items_idx = np.where(clusters == cluster_id)[0]
            num_selected_items = 1
            cluster_items = database_keys[cluster_items_idx]
            dist = calculate_hamming(cluster_center, cluster_items)
            ind = np.argsort(dist)
            item_idx_list = cluster_items_idx[ind]
            for i, idx in enumerate(item_idx_list[:num_selected_items]):
                selected_items.append(item_list[idx])
                dist_tuples.append(
                    (cluster_id, item_list[idx].id, item_list[idx].subset, dist[ind][i])
                )
        return selected_items, dist_tuples
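
# Illustrative sketch (standalone, made-up data): the core of Centroid is
# "keep the item nearest each cluster center". With binary hash keys, the
# Hamming distance is simply the number of differing bits.
#
#   >>> import numpy as np
#   >>> keys = np.array([[0, 1, 1, 0], [0, 1, 0, 0], [1, 1, 1, 1]])
#   >>> center = np.array([0, 1, 1, 1])
#   >>> dist = (keys != center).sum(axis=1)  # Hamming distance per item
#   >>> int(np.argmin(dist))                 # index of the representative
#   0
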

class ClusteredRandom(PruneBase):
    """
    Select items through clustering and choose randomly within each cluster.
    """

    def base(self, ratio, num_centers, labels, database_keys, item_list, source):
        from sklearn.cluster import KMeans

        kmeans = KMeans(n_clusters=num_centers, random_state=0)
        clusters = kmeans.fit_predict(database_keys)
        cluster_ids, cluster_num_item_list = np.unique(clusters, return_counts=True)

        norm_cluster_num_item_list = match_num_item_for_cluster(
            ratio, len(database_keys), cluster_num_item_list
        )

        selected_items = []
        random.seed(0)
        for i, cluster_id in enumerate(cluster_ids):
            cluster_items_idx = np.where(clusters == cluster_id)[0]
            num_selected_items = norm_cluster_num_item_list[i]
            random.shuffle(cluster_items_idx)
            selected_items.extend(item_list[idx] for idx in cluster_items_idx[:num_selected_items])
        return selected_items, None
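
# Illustrative sketch (standalone): each cluster's quota comes from
# match_num_item_for_cluster, and the quota is filled by shuffling the
# cluster's member indices and taking the first few, so picks are uniform
# within a cluster while cluster proportions are preserved.
#
#   >>> import random
#   >>> cluster_items_idx = list(range(6))  # indices of one cluster's members
#   >>> random.seed(0)
#   >>> random.shuffle(cluster_items_idx)
#   >>> len(cluster_items_idx[:2])          # take this cluster's quota of 2
#   2
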

class QueryClust(PruneBase):
    """
    Select items through clustering, with initial cluster centers seeded by
    one item per label.
    """

    def base(self, ratio, num_centers, labels, database_keys, item_list, source):
        from sklearn.cluster import KMeans

        # Pick one representative item per label to seed the cluster centers.
        # Label indices start at 0, so the full range must be covered.
        center_dict = {i: None for i in range(num_centers)}
        for item in item_list:
            for anno in item.annotations:
                if isinstance(anno, Label):
                    label_ = anno.label
                    if center_dict.get(label_) is None:
                        center_dict[label_] = item
            if all(center_dict.values()):
                break

        item_id_list = [item.id.split("/")[-1] for item in item_list]
        centroids = [
            database_keys[item_id_list.index(item.id)] for item in center_dict.values() if item
        ]
        kmeans = KMeans(
            n_clusters=num_centers, n_init=1, init=np.stack(centroids, axis=0), random_state=0
        )

        clusters = kmeans.fit_predict(database_keys)
        cluster_centers = kmeans.cluster_centers_
        cluster_ids, cluster_num_item_list = np.unique(clusters, return_counts=True)

        norm_cluster_num_item_list = match_num_item_for_cluster(
            ratio, len(database_keys), cluster_num_item_list
        )

        selected_items = []
        dist_tuples = []
        for i, cluster_id in enumerate(cluster_ids):
            cluster_center = cluster_centers[cluster_id]
            cluster_items_idx = np.where(clusters == cluster_id)[0]
            num_selected_item = norm_cluster_num_item_list[i]

            cluster_items = database_keys[cluster_items_idx]
            dist = calculate_hamming(cluster_center, cluster_items)
            ind = np.argsort(dist)
            item_idx_list = cluster_items_idx[ind]
            # Use a separate loop variable so the outer cluster index ``i``
            # is not shadowed.
            for j, idx in enumerate(item_idx_list[:num_selected_item]):
                selected_items.append(item_list[idx])
                dist_tuples.append(
                    (cluster_id, item_list[idx].id, item_list[idx].subset, dist[ind][j])
                )
        return selected_items, dist_tuples
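
# Illustrative sketch (standalone, made-up data): QueryClust's key move is
# seeding KMeans with one hash key per label instead of random initialization,
# so each cluster starts anchored to a label.
#
#   >>> import numpy as np
#   >>> from sklearn.cluster import KMeans
#   >>> keys = np.array([[0.0, 0.0], [0.0, 1.0], [10.0, 10.0], [10.0, 11.0]])
#   >>> init = np.stack([keys[0], keys[2]], axis=0)  # one seed per "label"
#   >>> km = KMeans(n_clusters=2, n_init=1, init=init, random_state=0)
#   >>> cl = km.fit_predict(keys)
#   >>> bool((cl[0] == cl[1]) and (cl[2] == cl[3]) and (cl[0] != cl[2]))
#   True
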

class Entropy(PruneBase):
    """
    Select items through clustering and choose them based on label entropy
    in each cluster.
    """

    def base(self, ratio, num_centers, labels, database_keys, item_list, source):
        from sklearn.cluster import KMeans

        kmeans = KMeans(n_clusters=num_centers, random_state=0)
        clusters = kmeans.fit_predict(database_keys)
        cluster_ids, cluster_num_item_list = np.unique(clusters, return_counts=True)

        norm_cluster_num_item_list = match_num_item_for_cluster(
            ratio, len(database_keys), cluster_num_item_list
        )

        selected_item_indexes = []
        for cluster_id, num_selected_item in zip(cluster_ids, norm_cluster_num_item_list):
            cluster_items_idx = np.where(clusters == cluster_id)[0]

            cluster_classes = np.array(labels)[cluster_items_idx]
            _, inv, cnts = np.unique(cluster_classes, return_inverse=True, return_counts=True)
            weights = 1 / cnts
            probs = weights[inv]
            probs /= probs.sum()

            choices = np.random.choice(len(inv), size=num_selected_item, p=probs, replace=False)
            selected_item_indexes.extend(cluster_items_idx[choices])

        selected_items = np.array(item_list)[selected_item_indexes].tolist()
        return selected_items, None
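
# Illustrative sketch (standalone): inside a cluster, each item is weighted by
# the inverse frequency of its label, so rare labels are as likely to be drawn
# as common ones.
#
#   >>> import numpy as np
#   >>> cluster_classes = np.array([0, 0, 0, 1])  # three of label 0, one of label 1
#   >>> _, inv, cnts = np.unique(cluster_classes, return_inverse=True, return_counts=True)
#   >>> probs = (1 / cnts)[inv]
#   >>> probs /= probs.sum()
#   >>> [round(p, 3) for p in probs.tolist()]
#   [0.167, 0.167, 0.167, 0.5]
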

class NDRSelect(PruneBase):
    """
    Select items based on NDR (near-duplicate removal) within each subset.
    """

    def base(self, ratio, num_centers, labels, database_keys, item_list, source):
        subset_lists = list(source.subsets().keys())

        selected_items = []
        for subset_ in subset_lists:
            subset_len = len(source.get_subset(subset_))
            num_selected_subset_item = math.ceil(subset_len * (1 - ratio))
            ndr_result = ndr.NDR(source, working_subset=subset_, num_cut=num_selected_subset_item)
            selected_items.extend(ndr_result.get_subset(subset_))

        return selected_items, None

class Prune(HashInference):
    def __init__(
        self,
        dataset: Dataset,
        cluster_method: str = "random",
        hash_type: str = "img",
    ) -> None:
        """
        Prune makes a representative and manageable subset.
        """
        self._dataset = dataset
        self._cluster_method = cluster_method
        self._hash_type = hash_type

        self._model = None
        self._text_model = None
        self._num_centers = None
        self._database_keys = None
        self._item_list = []
        self._labels = []

        self._prepare_data()

    def _prepare_data(self):
        if self._hash_type == "txt":
            category_dict = self._prompting()

        if self._cluster_method == "random":
            self._item_list = list(self._dataset)
            return

        datasets_to_infer = select_uninferenced_dataset(self._dataset)
        datasets = self._compute_hash_key([self._dataset], [datasets_to_infer])[0]

        for category in datasets.categories().values():
            if isinstance(category, LabelCategories):
                self._num_centers = len(category._indices.keys())

        for item in datasets:
            for annotation in item.annotations:
                if isinstance(annotation, Label):
                    self._labels.append(annotation.label)
                if isinstance(annotation, HashKey):
                    hash_key = annotation.hash_key
                    if self._hash_type == "txt":
                        inputs = category_dict.get(str(item.annotations[0].label))
                        if isinstance(inputs, List):
                            inputs = " ".join(inputs)
                        hash_key_txt = self.text_model.infer_text(inputs).hash_key
                        hash_key = np.concatenate([hash_key, hash_key_txt])
                    hash_key = np.unpackbits(hash_key, axis=-1)
                    if self._database_keys is None:
                        self._database_keys = hash_key.reshape(1, -1)
                    else:
                        self._database_keys = np.concatenate(
                            (self._database_keys, hash_key.reshape(1, -1)), axis=0
                        )
            self._item_list.append(item)

    def _prompting(self):
        category_dict = {}
        detected_format = self._dataset.format
        template = format_templates.get(detected_format, templates)
        for label in list(self._dataset.categories().values())[0]._indices.keys():
            category_dict[label] = [temp.format(label) for temp in template]
        return category_dict
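
    # Illustrative sketch (standalone): hash keys arrive packed, one bit per
    # feature; np.unpackbits expands each byte into 8 binary features before
    # the keys are stacked into the database.
    #
    #   >>> import numpy as np
    #   >>> hash_key = np.array([0b10100000], dtype=np.uint8)
    #   >>> np.unpackbits(hash_key, axis=-1).tolist()
    #   [1, 0, 1, 0, 0, 0, 0, 0]
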
    def get_pruned(self, ratio: float = 0.5) -> Dataset:
        method = {
            "random": RandomSelect,
            "cluster_random": ClusteredRandom,
            "centroid": Centroid,
            "query_clust": QueryClust,
            "entropy": Entropy,
            "ndr": NDRSelect,
        }

        prune_method = method[self._cluster_method]()
        selected_items, dist_tuples = prune_method.base(
            ratio=ratio,
            num_centers=self._num_centers,
            labels=self._labels,
            database_keys=self._database_keys,
            item_list=self._item_list,
            source=self._dataset,
        )

        result_dataset = Dataset(
            media_type=self._dataset.media_type(), ann_types=self._dataset.ann_types()
        )
        result_dataset._source_path = self._dataset._source_path
        result_dataset.define_categories(self._dataset.categories())
        for item in selected_items:
            result_dataset.put(item)

        if dist_tuples:
            for center, id_, subset_, d in dist_tuples:
                log.info(f"item {id_} of subset {subset_} has distance {d} for cluster {center}")

        log.info(f"Pruned dataset with {ratio} from {len(self._dataset)} to {len(result_dataset)}")
        return result_dataset
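
# Usage sketch (hypothetical path and format; adjust to your data). Importing
# and exporting use the standard Dataset API; everything else follows the
# class above.
#
#   >>> from datumaro.components.dataset import Dataset
#   >>> dataset = Dataset.import_from("./data", "imagenet")     # doctest: +SKIP
#   >>> prune = Prune(dataset, cluster_method="query_clust")    # doctest: +SKIP
#   >>> result = prune.get_pruned(ratio=0.3)                    # doctest: +SKIP
#   >>> result.export("./pruned", "datumaro", save_media=True)  # doctest: +SKIP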