Source code for otx.algo.callbacks.adaptive_early_stopping

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""Callback for early stopping with warmup possibility."""

from __future__ import annotations

from typing import TYPE_CHECKING

from lightning.pytorch.callbacks.early_stopping import EarlyStopping

if TYPE_CHECKING:
    import lightning.pytorch as pl


class EarlyStoppingWithWarmup(EarlyStopping):
    """EarlyStoppingWithWarmup callback."""

    def __init__(
        self,
        monitor: str,
        min_delta: float = 0.0,
        patience: int = 10,
        verbose: bool = False,
        mode: str = "min",
        strict: bool = True,
        check_finite: bool = True,
        stopping_threshold: float | None = None,
        divergence_threshold: float | None = None,
        check_on_train_epoch_end: bool | None = None,
        log_rank_zero_only: bool = False,
        warmup_iters: int = 100,
        warmup_epochs: int = 3,
    ):
        """EarlyStoppingWithWarmup callback.

        Args:
            monitor (str): The metric to monitor.
            min_delta (float, optional): Minimum change in the monitored quantity
                to qualify as an improvement. Defaults to 0.0.
            patience (int, optional): Number of checks with no improvement after which
                training will be stopped. Defaults to 10.
            verbose (bool, optional): If True, prints messages to stdout. Defaults to False.
            mode (str, optional): One of {"min", "max"}. In "min" mode, training stops when the
                monitored quantity has stopped decreasing; in "max" mode, when it has stopped
                increasing. Defaults to "min".
            strict (bool, optional): If True, crashes training if the monitored quantity is not
                found in the validation metrics. Defaults to True.
            check_finite (bool, optional): If True, stops training when the monitored quantity
                becomes NaN or infinite. Defaults to True.
            stopping_threshold (float | None, optional): Stop training immediately once the
                monitored quantity reaches this threshold. Defaults to None.
            divergence_threshold (float | None, optional): Stop training as soon as the monitored
                quantity becomes worse than this threshold. Defaults to None.
            check_on_train_epoch_end (bool | None, optional): If True, runs the early stopping
                check at the end of the training epoch instead of the validation epoch.
                Defaults to None.
            log_rank_zero_only (bool, optional): If True, logs the status of the callback only
                on the rank 0 process. Defaults to False.
            warmup_iters (int, optional): Number of warmup iterations. Defaults to 100.
            warmup_epochs (int, optional): Number of warmup epochs. Defaults to 3.
        """
        super().__init__(
            monitor=monitor,
            min_delta=min_delta,
            patience=patience,
            verbose=verbose,
            mode=mode,
            strict=strict,
            check_finite=check_finite,
            stopping_threshold=stopping_threshold,
            divergence_threshold=divergence_threshold,
            check_on_train_epoch_end=check_on_train_epoch_end,
            log_rank_zero_only=log_rank_zero_only,
        )
        # Two thresholds keep the warmup behavior invariant to dataset size:
        # warmup_iters covers very small datasets, warmup_epochs covers larger ones.
        self.warmup_iters = warmup_iters
        self.warmup_epochs = warmup_epochs

    def _should_skip_check(self, trainer: pl.Trainer) -> bool:
        # Skip early-stopping checks until the warmup threshold (in optimizer steps) is reached.
        warmup_threshold = max(self.warmup_epochs * trainer.num_training_batches, self.warmup_iters)
        return super()._should_skip_check(trainer) or trainer.global_step < warmup_threshold
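
A minimal usage sketch for attaching the callback to a Lightning Trainer. The monitored metric name, patience, and trainer arguments below are illustrative assumptions, not values mandated by this module:

import lightning.pytorch as pl

from otx.algo.callbacks.adaptive_early_stopping import EarlyStoppingWithWarmup

# Hypothetical configuration: "val/loss" must match a metric logged by the LightningModule.
early_stopping = EarlyStoppingWithWarmup(
    monitor="val/loss",
    mode="min",
    patience=5,
    warmup_iters=100,   # no early-stopping checks before 100 optimizer steps...
    warmup_epochs=3,    # ...or before 3 epochs' worth of batches, whichever is larger
)

trainer = pl.Trainer(max_epochs=50, callbacks=[early_stopping])
# trainer.fit(model, datamodule=datamodule)  # model and datamodule supplied by the user

With 20 training batches per epoch, checks are skipped until global_step reaches max(3 * 20, 100) = 100; with 200 batches per epoch, they are skipped until step 600.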