Source code for otx.algorithms.common.adapters.mmcv.hooks.cancel_hook
"""Cancel hooks."""
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
import os
from typing import Callable
from mmcv.runner import BaseRunner, EpochBasedRunner
from mmcv.runner.hooks import HOOKS, Hook
from otx.utils.logger import get_logger
logger = get_logger()
# pylint: disable=too-many-instance-attributes, protected-access, too-many-arguments, unused-argument
[docs]
@HOOKS.register_module()
class CancelTrainingHook(Hook):
"""CancelTrainingHook for Training Stopping."""
def __init__(self, interval: int = 5):
"""Periodically check whether whether a stop signal is sent to the runner during model training.
Every 'check_interval' iterations, the work_dir for the runner is checked to see if a file '.stop_training'
is present. If it is, training is stopped.
:param interval: Period for checking for stop signal, given in iterations.
"""
self.interval = interval
@staticmethod
def _check_for_stop_signal(runner: BaseRunner):
"""Log _check_for_stop_signal for CancelTrainingHook."""
work_dir = runner.work_dir
stop_filepath = os.path.join(work_dir, ".stop_training")
if os.path.exists(stop_filepath):
if isinstance(runner, EpochBasedRunner):
epoch = runner.epoch
runner._max_epochs = epoch # Force runner to stop by pretending it has reached it's max_epoch
runner.should_stop = True # Set this flag to true to stop the current training epoch
os.remove(stop_filepath)
[docs]
def after_train_iter(self, runner: BaseRunner):
"""Log after_train_iter for CancelTrainingHook."""
if not self.every_n_iters(runner, self.interval):
return
self._check_for_stop_signal(runner)
@HOOKS.register_module()
class CancelInterfaceHook(Hook):
"""Cancel interface. If called, running job will be terminated."""
def __init__(self, init_callback: Callable, interval=5):
self.on_init_callback = init_callback
self.runner = None
self.interval = interval
def cancel(self):
"""Cancel."""
logger.info("CancelInterfaceHook.cancel() is called.")
if self.runner is None:
logger.warning("runner is not configured yet. ignored this request.")
return
if self.runner.should_stop:
logger.warning("cancel already requested.")
return
if isinstance(self.runner, EpochBasedRunner):
epoch = self.runner.epoch
self.runner._max_epochs = epoch # Force runner to stop by pretending it has reached it's max_epoch
self.runner.should_stop = True # Set this flag to true to stop the current training epoch
logger.info("requested stopping to the runner")
def before_run(self, runner):
"""Before run."""
self.runner = runner
self.on_init_callback(self)
def after_run(self, runner):
"""After run."""
self.runner = None