# Source code for otx.core.config.hpo

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
"""Config objects for HPO."""

from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path  # noqa: TCH003
from typing import Any, Callable, Literal

import torch

from otx.utils.utils import is_xpu_available

# Default number of parallel HPO workers: one per visible accelerator.
# CUDA devices are probed first, then XPU devices; when neither backend is
# available a single worker is used. This is evaluated once, at import time.
num_workers = (
    torch.cuda.device_count()
    if torch.cuda.is_available()
    else (torch.xpu.device_count() if is_xpu_available() else 1)
)


@dataclass
class HpoConfig:
    """DTO for HPO configuration.

    Field descriptions beyond type and default are inferred from the field names and
    should be confirmed against the HPO runner that consumes this object.

    Attributes:
        search_space (dict[str, dict[str, Any]] | str | Path | None): Hyper-parameter
            search space, given either as a dict or as a str/Path (presumably pointing
            to a file that describes it). Defaults to None.
        save_path (str | None): Location where HPO outputs are stored. Defaults to None.
        mode (Literal["max", "min"]): Whether the target metric is maximized ("max") or
            minimized ("min"). Defaults to "max".
        num_trials (int | None): Number of trials to run (presumably an upper bound).
            Defaults to None.
        num_workers (int): Number of trials to run in parallel. Defaults to the number
            of CUDA devices, else XPU devices, detected at import time, or 1 when
            neither backend is available.
        expected_time_ratio (int | float | None): Time budget for HPO, presumably
            expressed as a multiple of a single training run's duration. Defaults to 4.
        maximum_resource (int | float | None): Maximum resource (presumably epochs) a
            single trial may consume. Defaults to None.
        prior_hyper_parameters (dict | list[dict] | None): Hyper-parameter set(s) to be
            tried (presumably first) during the search. Defaults to None.
        acceptable_additional_time_ratio (float | int): Extra time allowed on top of the
            expected HPO time, as a ratio. Defaults to 1.0.
        minimum_resource (int | float | None): Minimum resource (presumably epochs) a
            trial runs before it can be stopped. Defaults to None.
        reduction_factor (int): Reduction factor for the (presumably successive-halving)
            scheduler. Defaults to 3.
        asynchronous_bracket (bool): Whether brackets are run asynchronously.
            Defaults to True.
        asynchronous_sha (bool): Whether successive halving runs asynchronously.
            Defaults to True only when the module-level device count exceeds 1.
        metric_name (str | None): Name of the metric to optimize. Defaults to None.
        adapt_bs_search_space_max_val (Literal["None", "Safe", "Full"]): Strategy for
            adapting the upper bound of the batch-size search space; the string "None"
            disables it. Defaults to "None".
        progress_update_callback (Callable[[int | float], None] | None): callback to
            update progress. If it's given, it's called with progress every second.
        callbacks_to_exclude (list[str] | str | None): List of name of callbacks to
            exclude during HPO.
    """

    search_space: dict[str, dict[str, Any]] | str | Path | None = None
    save_path: str | None = None
    mode: Literal["max", "min"] = "max"
    num_trials: int | None = None
    # Right-hand side resolves to the module-level device count (the class
    # namespace has no ``num_workers`` yet at this point).
    num_workers: int = num_workers
    expected_time_ratio: int | float | None = 4
    maximum_resource: int | float | None = None
    prior_hyper_parameters: dict | list[dict] | None = None
    acceptable_additional_time_ratio: float | int = 1.0
    minimum_resource: int | float | None = None
    reduction_factor: int = 3
    asynchronous_bracket: bool = True
    # NOTE: evaluated once when the class is defined, from the module-level device
    # count. Passing a different ``num_workers`` to the constructor does NOT change
    # this default; set ``asynchronous_sha`` explicitly if that matters.
    asynchronous_sha: bool = num_workers > 1
    metric_name: str | None = None
    adapt_bs_search_space_max_val: Literal["None", "Safe", "Full"] = "None"
    progress_update_callback: Callable[[int | float], None] | None = None
    callbacks_to_exclude: list[str] | str | None = None