Source code for otx.algorithms.common.utils.utils

"""Collections of Utils for common OTX algorithms."""

# Copyright (C) 2022-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import importlib
import inspect
import os
import random
import sys
from collections import defaultdict
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Tuple, Union

import numpy as np
import onnx
import torch
import yaml
from addict import Dict as adict

from otx.utils.logger import get_logger
from otx.utils.utils import add_suffix_to_filename

logger = get_logger()


HPU_AVAILABLE = None
try:
    import habana_frameworks.torch as htorch
except ImportError:
    HPU_AVAILABLE = False
    htorch = None

XPU_AVAILABLE = None
try:
    import intel_extension_for_pytorch as ipex
except ImportError:
    XPU_AVAILABLE = False
    ipex = None


[docs] class UncopiableDefaultDict(defaultdict): """Defauldict type object to avoid deepcopy.""" def __deepcopy__(self, memo): """Deepcopy.""" return self
[docs] def load_template(path): """Loading model template function.""" with open(path, encoding="UTF-8") as f: template = yaml.safe_load(f) return template
[docs] def get_task_class(path: str): """Return Task classes.""" module_name, class_name = path.rsplit(".", 1) module = importlib.import_module(module_name) return getattr(module, class_name)
[docs] def get_arg_spec( # noqa: C901 # pylint: disable=too-many-branches fn: Callable, # pylint: disable=invalid-name depth: Optional[int] = None, ) -> Tuple[str, ...]: """Get argument spec of function.""" args = set() cls_obj = None if inspect.ismethod(fn): fn_name = fn.__name__ cls_obj = fn.__self__ if not inspect.isclass(cls_obj): cls_obj = cls_obj.__class__ else: fn_name = fn.__name__ names = fn.__qualname__.split(".") if len(names) > 1 and names[-1] == fn_name: cls_obj = globals()[".".join(names[:-1])] if cls_obj: for obj in cls_obj.mro(): # type: ignore fn_obj = cls_obj.__dict__.get(fn_name, None) if fn_obj is not None: if isinstance(fn_obj, staticmethod): cls_obj = None break if cls_obj is None: # function, staticmethod spec = inspect.getfullargspec(fn) args.update(spec.args) else: # method, classmethod for i, obj in enumerate(cls_obj.mro()): # type: ignore if depth is not None and i == depth: break method = getattr(obj, fn_name, None) if method is None: break spec = inspect.getfullargspec(method) args.update(spec.args[1:]) if spec.varkw is None and spec.varargs is None: break return tuple(args)
[docs] def set_random_seed(seed, logger=None, deterministic=False): """Set random seed. Args: seed (int): Seed to be used. logger (logging.Logger): logger for logging seed info deterministic (bool): Whether to set the deterministic option for CUDNN backend, i.e., set `torch.backends.cudnn.deterministic` to True and `torch.backends.cudnn.benchmark` to False. Default: False. """ import torch random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) if is_xpu_available(): torch.xpu.manual_seed_all(seed) os.environ["PYTHONHASHSEED"] = str(seed) if logger: logger.info(f"Training seed was set to {seed} w/ deterministic={deterministic}.") if deterministic: torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False
[docs] def get_default_async_reqs_num() -> int: """Returns a default number of infer request for OV models.""" reqs_num = os.cpu_count() if reqs_num is not None: reqs_num = max(1, int(reqs_num / 2)) return reqs_num else: return 1
[docs] def read_py_config(filename: str) -> adict: """Reads py config to a dict.""" filename = str(Path(filename).resolve()) if not Path(filename).is_file: raise RuntimeError("config not found") assert filename.endswith(".py") module_name = Path(filename).stem if "." in module_name: raise ValueError("Dots are not allowed in config file path.") config_dir = Path(filename).parent sys.path.insert(0, str(config_dir)) mod = importlib.import_module(module_name) sys.path.pop(0) cfg_dict = adict( { name: value for name, value in mod.__dict__.items() if not name.startswith("__") and not inspect.isclass(value) and not inspect.ismodule(value) } ) return cfg_dict
def embed_onnx_model_data(onnx_file: str, extra_model_data: Dict[Tuple[str, str], Any]) -> None: """Embeds model api config to onnx file.""" model = onnx.load(onnx_file) for item in extra_model_data: meta = model.metadata_props.add() attr_path = " ".join(map(str, item)) meta.key = attr_path.strip() meta.value = str(extra_model_data[item]) onnx.save(model, onnx_file)
[docs] def is_xpu_available() -> bool: """Checks if XPU device is available.""" global XPU_AVAILABLE # noqa: PLW0603 if XPU_AVAILABLE is None: XPU_AVAILABLE = hasattr(torch, "xpu") and torch.xpu.is_available() return XPU_AVAILABLE
[docs] def is_hpu_available() -> bool: """Check if HPU device is available.""" global HPU_AVAILABLE # noqa: PLW0603 if HPU_AVAILABLE is None: HPU_AVAILABLE = htorch.hpu.is_available() return HPU_AVAILABLE
[docs] def cast_bf16_to_fp32(tensor: torch.Tensor) -> torch.Tensor: """Cast bf16 tensor to fp32 before processed by numpy. numpy doesn't support bfloat16, it is required to convert bfloat16 tensor to float32. """ if tensor.dtype == torch.bfloat16: tensor = tensor.to(torch.float32) return tensor
[docs] def get_cfg_based_on_device(cfg_file_path: Union[str, Path]) -> str: """Find a config file according to device.""" if is_xpu_available(): cfg_for_device = add_suffix_to_filename(cfg_file_path, "_xpu") if cfg_for_device.exists(): logger.info( f"XPU is detected. XPU config file will be used : {Path(cfg_file_path).name} -> {cfg_for_device.name}" ) cfg_file_path = cfg_for_device return str(cfg_file_path)