# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
"""HPO runner and resource manager class."""
from __future__ import annotations
import logging
import multiprocessing
import os
import queue
import signal
import time
from copy import deepcopy
from dataclasses import dataclass
from functools import partial
from pickle import PicklingError # nosec B403 used pickle for internal state dump/load
from typing import TYPE_CHECKING, Callable, Literal, NoReturn
from otx.core.types.device import DeviceType
from otx.hpo.hpo_base import HpoBase, Trial, TrialStatus
from otx.hpo.resource_manager import get_resource_manager
from otx.utils import append_main_proc_signal_handler
from otx.utils.utils import find_unpickleable_obj
if TYPE_CHECKING:
from collections.abc import Hashable
from signal import Signals
# Module-level logger, named after this module; used by HpoLoop and the trial-side helpers below.
logger = logging.getLogger(__name__)
@dataclass
class RunningTrial:
    """Bookkeeping record for one trial that has been handed to a child process.

    Attributes:
        process: Child process in which the trial is trained.
        trial: The HPO trial being executed by ``process``.
        queue: Per-trial queue on which the HPO loop puts the trial status
            so the child process can read it back.
    """

    process: multiprocessing.Process
    trial: Trial
    queue: multiprocessing.Queue
class HpoLoop:
    """HPO loop manager to run trials.

    Every trial is trained in its own process created from a ``spawn``
    multiprocessing context. Trials send their scores through one shared report
    queue; the loop forwards each score to the HPO algorithm and sends the
    resulting trial status back through a per-trial queue.

    Args:
        hpo_algo (HpoBase): HPO algorithms.
        train_func (Callable): Function to train a model.
        resource_type (Literal[DeviceType.cpu, DeviceType.gpu, DeviceType.xpu], optional):
            Which type of resource to use. It can be changed depending on environment.
            Defaults to DeviceType.gpu.
        num_parallel_trial (int | None, optional): How many trials to run in parallel.
            It's used for CPUResourceManager. Defaults to None.
        num_devices_per_trial (int, optional): Number of devices used for a single trial. Defaults to 1.
    """

    def __init__(
        self,
        hpo_algo: HpoBase,
        train_func: Callable,
        resource_type: Literal[DeviceType.cpu, DeviceType.gpu, DeviceType.xpu] = DeviceType.gpu,
        num_parallel_trial: int | None = None,
        num_devices_per_trial: int = 1,
    ) -> None:
        self._hpo_algo = hpo_algo
        self._train_func = train_func
        # Trials currently handed to a child process, keyed by the uid from ``_get_uid``.
        self._running_trials: dict[int, RunningTrial] = {}
        self._mp = multiprocessing.get_context("spawn")
        # Single queue shared by all trials to report scores to this loop.
        self._report_queue = self._mp.Queue()
        self._uid_index = 0
        self._resource_manager = get_resource_manager(
            resource_type,
            num_parallel_trial,
            num_devices_per_trial,
        )
        self._main_pid = os.getpid()

        # Make sure child processes are terminated when the main process is interrupted.
        append_main_proc_signal_handler(signal.SIGINT, self._terminate_signal_handler)
        append_main_proc_signal_handler(signal.SIGTERM, self._terminate_signal_handler)

    def run(self) -> None:
        """Run a HPO loop.

        Repeatedly starts new trials while resources are available, reaps finished
        processes and relays reports until the HPO algorithm is done. Afterwards the
        remaining reports are consumed and all child processes are joined.

        Raises:
            Exception: Any exception raised inside the loop is re-raised after all
                running trial processes have been terminated.
        """
        logger.info("HPO loop starts.")
        try:
            while not self._hpo_algo.is_done():
                if self._resource_manager.have_available_resource():
                    trial = self._hpo_algo.get_next_sample()
                    if trial is not None:
                        self._start_trial_process(trial)

                self._remove_finished_process()
                self._get_reports()

                time.sleep(1)
        except Exception:
            self._terminate_all_running_processes()
            # Bare raise keeps the original traceback intact.
            raise
        logger.info("HPO loop is done.")
        self._get_reports()
        self._join_all_processes()

    def _start_trial_process(self, trial: Trial) -> None:
        """Spawn a child process that trains ``trial``.

        The environment variables returned by the resource manager are applied to
        ``os.environ`` only while the child is being started; the previous
        environment is restored afterwards, even when starting the process fails.

        Args:
            trial (Trial): Trial to run.

        Raises:
            RuntimeError: If the process can't be spawned because an object can't be pickled.
        """
        logger.info(f"{trial.id} trial is now running.")
        logger.debug(f"{trial.id} hyper paramter => {trial.configuration}")

        trial.status = TrialStatus.RUNNING
        uid = self._get_uid()

        # Snapshot of the parent environment as a plain dict, restored in ``finally``.
        origin_env = os.environ.copy()
        env = self._resource_manager.reserve_resource(uid)
        try:
            if env is not None:
                os.environ.update(env)

            trial_queue = self._mp.Queue()
            process = self._mp.Process(
                target=_run_train,
                args=(
                    self._train_func,
                    trial.get_train_configuration(),
                    partial(
                        _report_score,
                        recv_queue=trial_queue,
                        send_queue=self._report_queue,
                        uid=uid,
                        trial_id=trial.id,
                    ),
                ),
            )
            self._running_trials[uid] = RunningTrial(process, trial, trial_queue)  # type: ignore[arg-type]
            try:
                process.start()
            except PicklingError as e:
                self._raise_pickle_error(e)
            except TypeError as e:
                # Some unpicklable objects surface as TypeError instead of PicklingError.
                if str(e).startswith("cannot pickle"):
                    self._raise_pickle_error(e)
                raise
        finally:
            # Never leak the trial-specific variables into the parent process.
            os.environ.clear()
            os.environ.update(origin_env)

    def _raise_pickle_error(self, exp: Exception) -> NoReturn:
        """Raise a RuntimeError listing the objects of the train function that can't be pickled.

        Args:
            exp (Exception): Original pickling exception, chained as the cause.

        Raises:
            RuntimeError: Always.
        """
        unpickleable_objs = find_unpickleable_obj(self._train_func, "self._train_func")
        msg = "cannot spawn process due to objects which can't be pickled.\nfollowing objects can't be pickled.\n"
        msg += "".join(f"{obj}\n" for obj in unpickleable_objs)
        raise RuntimeError(msg) from exp

    def _remove_finished_process(self) -> None:
        """Reap trial processes that are no longer alive and release their resources.

        Raises:
            RuntimeError: If a finished process has a non-zero exit code. All other
                running processes are terminated before raising.
        """
        trial_to_remove = []
        for uid, trial in self._running_trials.items():
            if not trial.process.is_alive():
                if trial.process.exitcode != 0:
                    self._terminate_all_running_processes()
                    msg = "One of HPO trials exit abnormally."
                    raise RuntimeError(msg)
                trial.queue.close()
                trial.process.join()
                trial_to_remove.append(uid)

        # Removal is deferred so the dict isn't mutated while it is being iterated.
        for uid in trial_to_remove:
            self._running_trials[uid].trial.status = TrialStatus.STOP
            self._resource_manager.release_resource(uid)
            del self._running_trials[uid]

    def _get_reports(self) -> None:
        """Consume all pending reports, pass them to the HPO algorithm and answer the trials."""
        while True:
            # ``empty()`` is not reliable for multiprocessing queues, so drain until Empty is raised.
            try:
                report = self._report_queue.get_nowait()
            except queue.Empty:
                break

            trial_status = self._hpo_algo.report_score(
                report["score"],
                report["progress"],
                report["trial_id"],
                report["done"],
            )
            # Only trials that are still tracked have a queue to answer on.
            if report["uid"] in self._running_trials:
                self._running_trials[report["uid"]].queue.put_nowait(trial_status)

        self._hpo_algo.save_results()

    def _join_all_processes(self) -> None:
        """Close every trial queue, wait for all trial processes and forget them."""
        for val in self._running_trials.values():
            val.queue.close()

        for val in self._running_trials.values():
            val.process.join()

        self._running_trials = {}

    def _get_uid(self) -> int:
        """Return a new unique id for a trial process (monotonically increasing counter)."""
        uid = self._uid_index
        self._uid_index += 1
        return uid

    def _terminate_all_running_processes(self) -> None:
        """Close the queues of all tracked trials and terminate the processes still alive."""
        for trial in self._running_trials.values():
            trial.queue.close()
            process = trial.process
            if process.is_alive():
                logger.info(f"Kill child process {process.pid}")
                process.terminate()

    def _terminate_signal_handler(self, signum: Signals, frame_) -> None:  # noqa: ANN001
        """Signal handler which terminates all running trial processes.

        Args:
            signum (Signals): Number of the received signal.
            frame_: Current stack frame passed by the signal machinery. Unused.
        """
        self._terminate_all_running_processes()

        # Resolve the name from the signal module so any signal number is handled.
        try:
            signal_name = signal.Signals(signum).name
        except ValueError:
            signal_name = f"signal {signum}"
        logger.warning(f"{signal_name} is sent. process exited.")
def _run_train(train_func: Callable, hp_config: dict, report_func: Callable) -> None:
# set multi process method as default
multiprocessing.set_start_method(None, True)
train_func(hp_config, report_func)
def _report_score(
    score: int | float,
    progress: int | float,
    recv_queue: multiprocessing.Queue,
    send_queue: multiprocessing.Queue,
    uid: Hashable,
    trial_id: Hashable,
    done: bool = False,
) -> TrialStatus:
    """Send a score report to the HPO loop and return the status it answers with.

    Args:
        score (int | float): Score to report.
        progress (int | float): Progress to report.
        recv_queue (multiprocessing.Queue): Queue from which the trial status is read.
        send_queue (multiprocessing.Queue): Queue on which the report is put.
        uid (Hashable): Unique id of the process, sent along with the report.
        trial_id (Hashable): Id of the trial, sent along with the report.
        done (bool, optional): Flag sent along with the report. Defaults to False.

    Returns:
        TrialStatus: ``STOP`` if putting the report raises ValueError, ``RUNNING`` if
        no answer arrives within 3 seconds, otherwise the most recent status read
        from ``recv_queue``.
    """
    logger.debug(
        f"score : {score}, progress : {progress}, uid : {uid}, trial_id : {trial_id}, "
        f"pid : {os.getpid()}, done : {done}",
    )

    report = {
        "score": score,
        "progress": progress,
        "uid": uid,
        "trial_id": trial_id,
        "pid": os.getpid(),
        "done": done,
    }
    try:
        send_queue.put_nowait(report)
    except ValueError:
        return TrialStatus.STOP

    try:
        trial_status = recv_queue.get(timeout=3)
    except queue.Empty:
        return TrialStatus.RUNNING
    else:
        # Several answers may have piled up; keep only the latest one.
        while True:
            if recv_queue.empty():
                break
            trial_status = recv_queue.get_nowait()

    logger.debug(f"trial_status : {trial_status}")
    return trial_status
def run_hpo_loop(
    hpo_algo: HpoBase,
    train_func: Callable,
    resource_type: Literal[DeviceType.cpu, DeviceType.gpu, DeviceType.xpu] = DeviceType.gpu,
    num_parallel_trial: int | None = None,
    num_devices_per_trial: int = 1,
) -> None:
    """Run the HPO loop.

    Builds an ``HpoLoop`` from the given arguments and runs it to completion.

    Args:
        hpo_algo (HpoBase): HPO algorithms.
        train_func (Callable): Function to train a model.
        resource_type (DeviceType.cpu | DeviceType.gpu | DeviceType.xpu, optional):
            Which type of resource to use. It can be changed depending on environment.
            Defaults to DeviceType.gpu.
        num_parallel_trial (int | None, optional): How many trials to run in parallel.
            It's used for CPUResourceManager. Defaults to None.
        num_devices_per_trial (int, optional): How many GPUs are used for a single trial.
            It's used for GPUResourceManager. Defaults to 1.
    """
    hpo_loop = HpoLoop(hpo_algo, train_func, resource_type, num_parallel_trial, num_devices_per_trial)
    hpo_loop.run()