# Source code for otx.core.schedulers.callable

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
"""Scheduler callable to support hyper-parameter optimization (HPO) algorithm."""

from __future__ import annotations

import importlib
import inspect
from typing import TYPE_CHECKING, Any

from lightning.pytorch.cli import ReduceLROnPlateau
from torch import nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from torch.optim.lr_scheduler import ReduceLROnPlateau as TorchReduceLROnPlateau

if TYPE_CHECKING:
    from lightning.pytorch.cli import LRSchedulerCallable


class SchedulerCallableSupportHPO:
    """LR scheduler callable that supports the OTX hyper-parameter optimization (HPO) algorithm.

    It makes the scheduler callable picklable and exposes its keyword arguments as
    instance attributes. It is used for HPO and adaptive batch size.

    Args:
        scheduler_cls: `LRScheduler` (or `ReduceLROnPlateau`) class type, or its string
            class import path. See examples for details.
        scheduler_kwargs: Keyword arguments used for the initialization of the given
            `scheduler_cls`. Each key is also set as an attribute of this instance.

    Raises:
        TypeError: If `scheduler_cls` is neither a string nor a subclass of `LRScheduler`
            or torch's `ReduceLROnPlateau`.
        ValueError: If `scheduler_cls` is a string that has no module path component.

    Examples:
        This is an example to create `MobileNetV3ForMulticlassCls` with a `StepLR` lr scheduler and
        custom configurations.

        ```python
        from torch.optim.lr_scheduler import StepLR
        from otx.algo.classification.mobilenet_v3_large import MobileNetV3ForMulticlassCls

        model = MobileNetV3ForMulticlassCls(
            num_classes=3,
            scheduler=SchedulerCallableSupportHPO(
                scheduler_cls=StepLR,
                scheduler_kwargs={
                    "step_size": 10,
                    "gamma": 0.5,
                },
            ),
        )
        ```

        It can be created from the string class import path such as

        ```python
        from otx.algo.classification.mobilenet_v3_large import MobileNetV3ForMulticlassCls

        model = MobileNetV3ForMulticlassCls(
            num_classes=3,
            scheduler=SchedulerCallableSupportHPO(
                scheduler_cls="torch.optim.lr_scheduler.StepLR",
                scheduler_kwargs={
                    "step_size": 10,
                    "gamma": 0.5,
                },
            ),
        )
        ```
    """

    def __init__(
        self,
        scheduler_cls: type[LRScheduler] | type[TorchReduceLROnPlateau] | str,
        scheduler_kwargs: dict[str, int | float | bool | str],
    ):
        if isinstance(scheduler_cls, str):
            module_path, _, class_name = scheduler_cls.rpartition(".")
            if not module_path:
                msg = (
                    "scheduler_cls import path must be fully qualified "
                    f"(e.g. 'torch.optim.lr_scheduler.StepLR'), got {scheduler_cls!r}."
                )
                raise ValueError(msg)
            module = importlib.import_module(module_path)
            self.scheduler_init: type[LRScheduler] | type[TorchReduceLROnPlateau] = getattr(module, class_name)
            self.scheduler_path = scheduler_cls
        elif inspect.isclass(scheduler_cls) and issubclass(scheduler_cls, (LRScheduler, TorchReduceLROnPlateau)):
            # NOTE: torch's ReduceLROnPlateau is checked explicitly because it is not an
            # LRScheduler subclass in older torch releases; Lightning's ReduceLROnPlateau
            # derives from it, so it is accepted as well.
            self.scheduler_init = scheduler_cls
            self.scheduler_path = f"{scheduler_cls.__module__}.{scheduler_cls.__qualname__}"
        else:
            msg = (
                "scheduler_cls must be an LRScheduler/ReduceLROnPlateau subclass "
                f"or a class import path string, got {scheduler_cls!r}."
            )
            raise TypeError(msg)

        self.scheduler_kwargs = scheduler_kwargs
        # Make each kwarg readable as an attribute (e.g. `self.gamma`).
        # NOTE: these attributes are copies; `__call__` and pickling use `self.scheduler_kwargs`.
        self.__dict__.update(scheduler_kwargs)

    def __call__(self, optimizer: Optimizer) -> LRScheduler | TorchReduceLROnPlateau:
        """Create the LR scheduler instance for the given optimizer."""
        return self.scheduler_init(optimizer, **self.scheduler_kwargs)

    def __reduce__(self) -> str | tuple[Any, ...]:
        """Support pickling by rebuilding the instance from `scheduler_path` and `scheduler_kwargs`."""
        return self.__class__, (self.scheduler_path, self.scheduler_kwargs)

    @classmethod
    def from_callable(cls, func: LRSchedulerCallable) -> SchedulerCallableSupportHPO:
        """Create this class instance from an existing LR scheduler callable.

        `func` is invoked once on a throw-away optimizer that holds a single dummy
        parameter. The resulting scheduler's `state_dict()` entries whose names match
        its constructor parameters are recovered as `scheduler_kwargs`.

        Args:
            func: Callable that builds an LR scheduler from an optimizer.

        Returns:
            An instance of `cls` wrapping the scheduler class produced by `func` and the
            recovered keyword arguments.
        """
        dummy_params = [nn.Parameter()]
        optimizer = Optimizer(dummy_params, {"lr": 1.0})
        scheduler = func(optimizer)

        allow_names = set(inspect.signature(scheduler.__class__).parameters)
        if isinstance(scheduler, ReduceLROnPlateau):
            # NOTE: Lightning's ReduceLROnPlateau.__init__() signature exposes "monitor" but not
            # the torch arguments such as "patience", so borrow torch's parameter names as well.
            allow_names.update(inspect.signature(TorchReduceLROnPlateau).parameters)

        block_names = {"optimizer", "last_epoch"}
        state = scheduler.state_dict()
        scheduler_kwargs = {
            key: value for key, value in state.items() if key in allow_names and key not in block_names
        }

        # torch's ReduceLROnPlateau stores its `min_lr` argument as the per-param-group list
        # `min_lrs`, which never matches the constructor signature. The dummy optimizer has a
        # single param group, so recover a scalar `min_lr` when the list has one distinct value.
        min_lrs = state.get("min_lrs")
        if (
            "min_lr" in allow_names
            and "min_lr" not in scheduler_kwargs
            and isinstance(min_lrs, (list, tuple))
            and len(set(min_lrs)) == 1
        ):
            scheduler_kwargs["min_lr"] = min_lrs[0]

        return cls(
            scheduler_cls=scheduler.__class__,
            scheduler_kwargs=scheduler_kwargs,
        )