# Source code for otx.core.optimizer.callable

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
"""Optimizer callable to support hyper-parameter optimization (HPO) algorithm."""

from __future__ import annotations

import importlib
from typing import TYPE_CHECKING, Any

from torch import nn
from torch.optim.optimizer import Optimizer

if TYPE_CHECKING:
    from lightning.pytorch.cli import OptimizerCallable
    from torch.optim.optimizer import params_t


class OptimizerCallableSupportHPO:
    """Optimizer callable supports OTX hyper-parameter optimization (HPO) algorithm.

    It makes OptimizerCallable picklable and accessible to parameters.
    It is used for HPO and adaptive batch size.

    Args:
        optimizer_cls: Optimizer class type or string class import path. See examples for details.
        optimizer_kwargs: Keyword arguments used for the initialization of the given `optimizer_cls`.
            The mapping is copied on construction, and each entry is also mirrored as an instance
            attribute (e.g. ``self.lr``). ``__call__`` and pickling read ``self.optimizer_kwargs``,
            so re-assigning a mirrored attribute afterwards does not affect the created optimizer.

    Raises:
        ValueError: If `optimizer_cls` is a string without a module part (e.g. ``"SGD"``).
        TypeError: If `optimizer_cls` is neither a string nor a subclass of `torch.optim.Optimizer`.

    Examples:
        This is an example to create `MobileNetV3ForMulticlassCls` with a `SGD` optimizer and custom configurations.

        ```python
        from torch.optim import SGD
        from otx.algo.classification.mobilenet_v3_large import MobileNetV3ForMulticlassCls

        model = MobileNetV3ForMulticlassCls(
            num_classes=3,
            optimizer=OptimizerCallableSupportHPO(
                optimizer_cls=SGD,
                optimizer_kwargs={
                    "lr": 0.1,
                    "momentum": 0.9,
                    "weight_decay": 1e-4,
                },
            ),
        )
        ```

        It can be created from the string class import path such as

        ```python
        from otx.algo.classification.mobilenet_v3_large import MobileNetV3ForMulticlassCls

        model = MobileNetV3ForMulticlassCls(
            num_classes=3,
            optimizer=OptimizerCallableSupportHPO(
                optimizer_cls="torch.optim.SGD",
                optimizer_kwargs={
                    "lr": 0.1,
                    "momentum": 0.9,
                    "weight_decay": 1e-4,
                },
            ),
        )
        ```
    """

    def __init__(
        self,
        optimizer_cls: type[Optimizer] | str,
        optimizer_kwargs: dict[str, Any],
    ):
        if isinstance(optimizer_cls, str):
            # "pkg.module.ClassName" -> ("pkg.module", ".", "ClassName")
            module_path, _, class_name = optimizer_cls.rpartition(".")
            if not module_path:
                msg = (
                    "optimizer_cls must be a fully qualified import path such as "
                    f"'torch.optim.SGD', got {optimizer_cls!r}."
                )
                raise ValueError(msg)
            module = importlib.import_module(module_path)
            self.optimizer_init: type[Optimizer] = getattr(module, class_name)
            self.optimizer_path = optimizer_cls
        # Guard with isinstance(..., type) so non-class inputs reach the descriptive TypeError below
        # instead of failing inside issubclass().
        elif isinstance(optimizer_cls, type) and issubclass(optimizer_cls, Optimizer):
            self.optimizer_init = optimizer_cls
            self.optimizer_path = f"{optimizer_cls.__module__}.{optimizer_cls.__qualname__}"
        else:
            msg = (
                "optimizer_cls must be a subclass of torch.optim.Optimizer or its import path string, "
                f"got {optimizer_cls!r}."
            )
            raise TypeError(msg)

        # Copy so later mutation of the caller's dict cannot silently change this callable.
        self.optimizer_kwargs = dict(optimizer_kwargs)
        # Mirror each hyper-parameter as an attribute (e.g. `self.lr`) so it is directly accessible.
        self.__dict__.update(self.optimizer_kwargs)

    def __call__(self, params: params_t) -> Optimizer:
        """Create `torch.optim.Optimizer` instance for the given parameters."""
        return self.optimizer_init(params, **self.optimizer_kwargs)

    def __reduce__(self) -> str | tuple[Any, ...]:
        # Pickle as "call the class again with (import path, kwargs)", so the pickled state holds
        # only the optimizer's import path string and its keyword arguments.
        return self.__class__, (
            self.optimizer_path,
            self.optimizer_kwargs,
        )

    @classmethod
    def from_callable(cls, func: OptimizerCallable) -> OptimizerCallableSupportHPO:
        """Create this class instance from an existing optimizer callable.

        `func` is invoked once on a single empty dummy parameter. The resulting optimizer's
        class and the hyper-parameters of its first param group (every key except ``"params"``)
        are used to build the new instance.

        Args:
            func: Callable that takes model parameters and returns an optimizer instance.

        Returns:
            A new instance of `cls` (so subclasses get their own type back).
        """
        dummy_params = [nn.Parameter()]
        optimizer = func(dummy_params)
        param_group = optimizer.param_groups[0]
        return cls(
            optimizer_cls=optimizer.__class__,
            optimizer_kwargs={key: value for key, value in param_group.items() if key != "params"},
        )