# Source code for otx.core.utils.utils

# Copyright (c) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
"""Utility functions."""

from __future__ import annotations

import importlib
from collections import defaultdict
from multiprocessing import cpu_count
from typing import TYPE_CHECKING, Any

import torch
from datumaro.components.annotation import AnnotationType, LabelCategories

from otx.utils.device import is_xpu_available

if TYPE_CHECKING:
    from datumaro import Dataset as DmDataset


def is_ckpt_from_otx_v1(ckpt: dict) -> bool:
    """Tell whether a loaded checkpoint was produced by OTX v1.

    A checkpoint is treated as an OTX v1 checkpoint when it has a ``"model"``
    key and its ``"VERSION"`` entry equals ``1``.

    Args:
        ckpt (dict): The loaded checkpoint dictionary.

    Returns:
        bool: True if the checkpoint comes from OTX v1, False otherwise
        (including when the ``"VERSION"`` entry is absent).
    """
    # Use ``get`` so a checkpoint that has "model" but no "VERSION" entry is
    # reported as "not v1" instead of raising KeyError.
    return "model" in ckpt and ckpt.get("VERSION") == 1
def is_ckpt_for_finetuning(ckpt: dict) -> bool:
    """Tell whether a loaded checkpoint is meant to be used for finetuning.

    The decision is made solely on the presence of a ``"state_dict"`` key.

    Args:
        ckpt (dict): The loaded checkpoint dictionary.

    Returns:
        bool: True if ``ckpt`` contains a ``"state_dict"`` key.
    """
    has_state_dict = "state_dict" in ckpt
    return has_state_dict
def get_adaptive_num_workers(num_dataloader: int = 1) -> int | None:
    """Pick a ``num_workers`` value based on the CPU and accelerator counts.

    The CPU count is split evenly across ``num_dataloader * num_devices`` and
    the result is capped at 8. Devices are counted through ``torch.xpu`` when
    ``is_xpu_available()`` is true, and through ``torch.cuda`` otherwise.

    Args:
        num_dataloader (int): Number of dataloaders that share the CPUs.
            Defaults to 1.

    Returns:
        int | None: The suggested number of workers, or None when the device
        count is 0.
    """
    if is_xpu_available():
        accelerator_count = torch.xpu.device_count()
    else:
        accelerator_count = torch.cuda.device_count()

    if accelerator_count == 0:
        return None

    workers_per_loader = cpu_count() // (num_dataloader * accelerator_count)
    # max available num_workers is 8
    return min(workers_per_loader, 8)
def get_idx_list_per_classes(dm_dataset: DmDataset, use_string_label: bool = False) -> dict[int | str, list[int]]:
    """Group dataset item indices by the labels of their annotations.

    Every annotation of every item contributes the item's enumeration index to
    the list kept for that annotation's label. Each list is then de-duplicated
    while keeping first-seen order.

    Args:
        dm_dataset (DmDataset): Dataset to scan; it is enumerated once and its
            ``categories()`` are queried for the label categories.
        use_string_label (bool): If True, key the result by the label name
            looked up in the label categories; otherwise key it by
            ``ann.label`` as-is. Defaults to False.

    Returns:
        dict[int | str, list[int]]: Mapping from label (name or ``ann.label``)
        to the unique item indices carrying that label. The returned object is
        a ``defaultdict(list)``.
    """
    label_categories = dm_dataset.categories().get(AnnotationType.label, LabelCategories())
    idx_per_class: dict[int | str, list[int]] = defaultdict(list)

    for position, dataset_item in enumerate(dm_dataset):
        for annotation in dataset_item.annotations:
            # The label name is only resolved when string keys are requested.
            class_key = label_categories.items[annotation.label].name if use_string_label else annotation.label
            idx_per_class[class_key].append(position)

    # Remove duplicates in label stats idx: O(n)
    for class_key, indices in idx_per_class.items():
        idx_per_class[class_key] = list(dict.fromkeys(indices))

    return idx_per_class
def import_object_from_module(obj_path: str) -> Any:  # noqa: ANN401
    """Import and return the object named by a dotted import path.

    The path is split at its last dot: everything before it is imported as a
    module, and the part after it is fetched from that module as an attribute.

    Args:
        obj_path (str): Dotted path such as ``"package.module.ObjectName"``.

    Returns:
        Any: The attribute found on the imported module.

    Raises:
        ValueError: If ``obj_path`` contains no dot, so it cannot be split
            into a module part and an object part.
    """
    module_name, separator, obj_name = obj_path.rpartition(".")
    if not separator:
        # Without this check the failure would be an opaque tuple-unpacking
        # error; report the offending path instead.
        msg = f"obj_path must be a dotted path like 'package.module.Object', got {obj_path!r}"
        raise ValueError(msg)

    module = importlib.import_module(module_name)
    return getattr(module, obj_name)
def remove_state_dict_prefix(state_dict: dict[str, Any], prefix: str) -> dict[str, Any]:
    """Return a copy of ``state_dict`` with a leading ``prefix`` stripped from its keys.

    Only a prefix at the very start of a key is removed; later occurrences of
    the same text inside the key are left untouched. Keys that do not start
    with ``prefix`` are copied unchanged. The input mapping is not modified.

    Args:
        state_dict (dict[str, Any]): Mapping whose keys may start with ``prefix``.
        prefix (str): Text to strip from the beginning of each key.

    Returns:
        dict[str, Any]: New dictionary with the same values under the
        stripped keys.
    """
    prefix_len = len(prefix)
    # ``str.replace`` would also delete occurrences in the middle of a key
    # (e.g. "model.backbone.model.w" -> "backbone.w"), so slice off only a
    # leading match.
    return {
        (key[prefix_len:] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }