# Source code for otx.core.schedulers.warmup_schedulers
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
"""Warm-up schedulers for the OTX2.0."""
from __future__ import annotations
from typing import TYPE_CHECKING, Literal
from torch.optim.lr_scheduler import LambdaLR, LRScheduler
from otx.core.schedulers.callable import SchedulerCallableSupportHPO
if TYPE_CHECKING:
from lightning.pytorch.cli import LRSchedulerCallable, ReduceLROnPlateau
from torch.optim.optimizer import Optimizer
[docs]
class LinearWarmupScheduler(LambdaLR):
    """Scheduler that scales the learning rate up linearly, then stops stepping.

    The multiplicative factor applied to the base learning rate at a given
    step is ``min((step + 1) / num_warmup_steps, 1.0)``. Once the internal
    step counter exceeds ``num_warmup_steps``, further calls to ``step`` are
    ignored so the scheduler no longer touches the optimizer.

    Args:
        optimizer (Optimizer): Optimizer to apply the scheduler to.
        num_warmup_steps (int): Length of the warm-up period; the learning rate is
            linearly increased over this many steps. Must be greater than zero.
        interval (Literal["step", "epoch"]): If "epoch", the warm-up period is counted in epochs.
            Otherwise, it is counted in iteration steps. The value is only stored on the
            instance as ``interval``; this class does not read it itself.

    Raises:
        ValueError: If ``num_warmup_steps`` is not greater than zero.
    """

    def __init__(
        self,
        optimizer: Optimizer,
        num_warmup_steps: int = 1000,
        interval: Literal["step", "epoch"] = "step",
    ):
        is_valid = num_warmup_steps > 0
        if not is_valid:
            raise ValueError(f"num_warmup_steps should be > 0, got {num_warmup_steps}")

        # Both attributes must exist before the parent constructor runs,
        # because it already calls `step()` (and thus `activated`) once.
        self.num_warmup_steps = num_warmup_steps
        self.interval = interval

        def warmup_factor(step: int) -> float:
            # Linear ramp that saturates at 1.0; reads the attribute at call time.
            return min((step + 1.0) / self.num_warmup_steps, 1.0)

        super().__init__(optimizer, warmup_factor)

    def step(self, epoch: int | None = None) -> None:
        """Advance the schedule only while warm-up is still in progress."""
        if not self.activated:
            return
        super().step(epoch)

    @property
    def activated(self) -> bool:
        """Whether the step count has not yet exceeded ``num_warmup_steps``."""
        return self.num_warmup_steps >= self._step_count
[docs]
class LinearWarmupSchedulerCallable:
    """Factory that builds the main LR scheduler and, optionally, a `LinearWarmupScheduler`.

    Args:
        main_scheduler_callable: Callable that creates the LR scheduler to be mainly used.
            It is wrapped with `SchedulerCallableSupportHPO.from_callable`.
        num_warmup_steps: Length of the warm-up period during which the learning rate is
            linearly increased. If it is less than or equal to zero, no
            `LinearWarmupScheduler` is created.
        warmup_interval: If "epoch", the warm-up period is counted in epochs.
            Otherwise, it is counted in iteration steps.
        monitor: If given, overrides the main scheduler's `monitor` attribute
            (only when the main scheduler has such an attribute).
    """

    def __init__(
        self,
        main_scheduler_callable: LRSchedulerCallable,
        num_warmup_steps: int = 0,
        warmup_interval: Literal["step", "epoch"] = "step",
        monitor: str | None = None,
    ):
        self.main_scheduler_callable = SchedulerCallableSupportHPO.from_callable(main_scheduler_callable)
        self.num_warmup_steps = num_warmup_steps
        self.warmup_interval = warmup_interval
        self.monitor = monitor

    def __call__(self, optimizer: Optimizer) -> list[LRScheduler | ReduceLROnPlateau]:
        """Build the schedulers for ``optimizer``.

        Returns:
            A list whose first element is the main scheduler. When
            ``num_warmup_steps`` is positive, a `LinearWarmupScheduler`
            bound to the same optimizer follows it.
        """
        primary = self.main_scheduler_callable(optimizer)

        # Only override `monitor` when a value was given and the scheduler exposes it.
        if self.monitor:
            if hasattr(primary, "monitor"):
                primary.monitor = self.monitor

        if not self.num_warmup_steps > 0:
            return [primary]

        warmup = LinearWarmupScheduler(
            optimizer=optimizer,
            num_warmup_steps=self.num_warmup_steps,
            interval=self.warmup_interval,
        )
        return [primary, warmup]