# Source code for otx.core.utils.instantiators

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
"""Instantiator functions for OTX engine components."""

from __future__ import annotations

import importlib
import inspect
from functools import partial
from typing import TYPE_CHECKING

from lightning.pytorch.cli import instantiate_class

from . import pylogger

if TYPE_CHECKING:
    from lightning import Callback
    from lightning.pytorch.loggers import Logger
    from torch.utils.data import Dataset, Sampler

    from otx.core.config.data import SamplerConfig


log = pylogger.get_pylogger(__name__)


def instantiate_callbacks(callbacks_cfg: list | dict | None) -> list[Callback]:
    """Instantiate a list of callbacks based on the provided configuration.

    Every entry that is a ``dict`` containing a ``"class_path"`` key is passed to
    ``lightning.pytorch.cli.instantiate_class``. Entries of any other shape are
    skipped, and a warning is logged for each one.

    Args:
        callbacks_cfg (list | dict | None): A list of callback configurations, or a
            single callback configuration dict. ``None`` or an empty value yields
            no callbacks.

    Returns:
        list[Callback]: A list of instantiated callbacks, in configuration order.
    """
    callbacks: list[Callback] = []

    if not callbacks_cfg:
        log.warning("No callback configs found! Skipping..")
        return callbacks

    # Iterating a bare config dict would walk its keys and silently build nothing,
    # so treat a single dict as a one-element list.
    if isinstance(callbacks_cfg, dict):
        callbacks_cfg = [callbacks_cfg]

    for cb_conf in callbacks_cfg:
        if isinstance(cb_conf, dict) and "class_path" in cb_conf:
            log.info(f"Instantiating callback <{cb_conf['class_path']}>")
            callbacks.append(instantiate_class(args=(), init=cb_conf))
        else:
            # Report only the type, so that config values are not echoed into the logs.
            log.warning(
                f"Skipping callback config of type {type(cb_conf).__name__}: "
                "expected a dict with a 'class_path' key.",
            )

    return callbacks
def instantiate_loggers(logger_cfg: list | dict | None) -> list[Logger]:
    """Instantiate loggers based on the provided logger configuration.

    Every entry that is a ``dict`` containing a ``"class_path"`` key is passed to
    ``lightning.pytorch.cli.instantiate_class``. Entries of any other shape are
    skipped, and a warning is logged for each one.

    Args:
        logger_cfg (list | dict | None): A list of logger configurations, or a
            single logger configuration dict. ``None`` or an empty value yields
            no loggers.

    Returns:
        list[Logger]: The list of instantiated loggers, in configuration order.
    """
    logger: list[Logger] = []

    if not logger_cfg:
        log.warning("No logger configs found! Skipping...")
        return logger

    # Iterating a bare config dict would walk its keys and silently build nothing,
    # so treat a single dict as a one-element list.
    if isinstance(logger_cfg, dict):
        logger_cfg = [logger_cfg]

    for lg_conf in logger_cfg:
        if isinstance(lg_conf, dict) and "class_path" in lg_conf:
            log.info(f"Instantiating logger <{lg_conf['class_path']}>")
            logger.append(instantiate_class(args=(), init=lg_conf))
        else:
            # Report only the type; logger init_args may hold credentials.
            log.warning(
                f"Skipping logger config of type {type(lg_conf).__name__}: "
                "expected a dict with a 'class_path' key.",
            )

    return logger
def partial_instantiate_class(init: list | dict | None) -> list[partial] | None:
    """Partially instantiate one or more classes from ``class_path``/``init_args`` configs.

    Adapted from ``lightning.pytorch.cli.instantiate_class``. Instead of
    constructing each object, this returns a ``functools.partial`` with
    ``init_args`` bound as keyword arguments. The caller supplies the remaining
    arguments later.

    Args:
        init (list | dict | None): A single config dict or a list of them. Each
            dict must contain:
              - "class_path" (str): The fully qualified dotted path of the class.
            and may contain:
              - "init_args" (dict | None): Keyword arguments to bind to the class
                constructor. A missing or ``None`` value binds no arguments.

    Returns:
        list[partial] | None: One partial per config, in input order, or ``None``
        when ``init`` is ``None`` or empty.

    Raises:
        KeyError: If a config has no "class_path" key.
        ValueError: If a "class_path" contains no "." separator.
        ImportError: If the module part of a "class_path" cannot be imported.
        AttributeError: If the module has no attribute with the class name.
    """
    if not init:
        return None
    if not isinstance(init, list):
        init = [init]

    items: list[partial] = []
    for item in init:
        # ``or {}`` also covers an explicit ``init_args: null`` coming from YAML.
        kwargs = item.get("init_args") or {}
        class_module, class_name = item["class_path"].rsplit(".", 1)
        module = importlib.import_module(class_module)
        args_class = getattr(module, class_name)
        items.append(partial(args_class, **kwargs))
    return items
def instantiate_sampler(sampler_config: SamplerConfig, dataset: Dataset, **kwargs) -> Sampler:
    """Build a sampler for ``dataset`` from ``sampler_config``.

    The class named by ``sampler_config.class_path`` is imported and called as
    ``sampler_class(dataset, **merged)``. Here ``merged`` is
    ``sampler_config.init_args`` overlaid with ``kwargs``; on a key clash the
    value from ``kwargs`` wins.

    Args:
        sampler_config (SamplerConfig): Configuration that provides ``class_path``
            (the dotted path of the sampler class) and ``init_args``.
        dataset (Dataset): Passed as the first positional argument of the
            sampler constructor.
        **kwargs: Extra constructor keyword arguments. ``batch_size`` is dropped
            when the sampler's ``__init__`` has no parameter with that name.

    Returns:
        Sampler: The constructed sampler instance.
    """
    module_name, class_name = sampler_config.class_path.rsplit(".", 1)
    sampler_class = getattr(__import__(module_name, fromlist=[class_name]), class_name)

    # Forward ``batch_size`` only to samplers whose constructor declares it.
    declared_params = inspect.signature(sampler_class.__init__).parameters
    if "batch_size" not in declared_params:
        kwargs.pop("batch_size", None)

    merged = {**sampler_config.init_args, **kwargs}
    return sampler_class(dataset, **merged)