Source code for otx.core.utils.cache

# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
"""Cache Class for Trainer kwargs."""

from __future__ import annotations

import inspect
import logging
from typing import Any

logger = logging.getLogger(__name__)


class TrainerArgumentsCache:
    """Keep Trainer keyword arguments around until a Trainer is created.

    The Engine class accepts PyTorch Lightning Trainer arguments, so they are
    stored here before the trainer is instantiated and can be overridden later.

    Args:
        **kwargs: Trainer arguments to cache.

    Attributes:
        is_trainer_args_identical: Starts as ``False``. While it is ``False``,
            :meth:`requires_update` always reports ``True``. This class never
            sets it to ``True`` itself; the owner of the cache is expected to
            do so.

    Example:
        >>> conf = OmegaConf.load("config.yaml")
        >>> cache = TrainerArgumentsCache(**conf)
        >>> cache.args
        {
            ...
            'max_epochs': 100,
            'val_check_interval': 0
        }
        >>> config = {"max_epochs": 1, "val_check_interval": 1.0}
        >>> cache.update(**config)
        Overriding max_epochs from 100 with 1
        Overriding val_check_interval from 0 with 1.0
        >>> cache.args
        {
            ...
            'max_epochs': 1,
            'val_check_interval': 1.0
        }
    """

    def __init__(self, **kwargs) -> None:
        # Shallow copy of the received keyword arguments.
        self._cached_args = dict(kwargs)
        self.is_trainer_args_identical = False

    def update(self, **kwargs) -> None:
        """Merge the given keyword arguments into the cache.

        Entries whose value is ``None`` are skipped, so they never overwrite
        or create a cached entry. When an existing entry is replaced by a
        different value, the override is logged at INFO level.

        Args:
            **kwargs: Argument names and their new values.
        """
        for name, incoming in kwargs.items():
            if incoming is not None:
                if name in self._cached_args:
                    previous = self._cached_args[name]
                    if previous != incoming:
                        logger.info(
                            f"Overriding {name} from {previous} with {incoming}",
                        )
                self._cached_args[name] = incoming

    def requires_update(self, **kwargs) -> bool:
        """Tell whether the cache is out of date with respect to ``kwargs``.

        Args:
            **kwargs: Argument names and values to compare with the cache.

        Returns:
            bool: ``True`` if ``is_trainer_args_identical`` is ``False``, or if
            any given name is already cached with a different value. Names
            that are not cached are ignored. ``False`` otherwise.
        """
        if not self.is_trainer_args_identical:
            return True
        for name, candidate in kwargs.items():
            if name in self._cached_args and self._cached_args[name] != candidate:
                return True
        return False

    @property
    def args(self) -> dict[str, Any]:
        """Return the cached arguments.

        Returns:
            dict[str, Any]: The internal dictionary itself, not a copy.
        """
        return self._cached_args

    @staticmethod
    def get_trainer_constructor_args() -> set[str]:
        """Collect the parameter names of the Lightning Trainer constructor.

        Returns:
            set[str]: Parameter names reported by ``inspect.signature`` for
            ``Trainer.__init__``.
        """
        # Imported lazily, only when this method is called.
        from lightning import Trainer

        # NOTE(review): the signature of the unbound ``__init__`` presumably
        # includes ``self`` - confirm callers tolerate that.
        return set(inspect.signature(Trainer.__init__).parameters)