Source code for otx.core.config

# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
"""Config data type objects."""
from __future__ import annotations

from typing import TYPE_CHECKING, Any, _SpecialForm

import yaml

from otx.core.types.task import OTXTaskType, OTXTrainType

if TYPE_CHECKING:
    from torch import dtype


def as_int_tuple(*args) -> tuple[int, ...]:
    """Convert the resolver arguments into a tuple of integers.

    Every positional argument is passed through ``int`` and the results are
    returned in the same order.

    Example:
        YAML file example::

        ```yaml
        mem_cache_img_max_size: ${as_int_tuple:500,500}
        ```

    Args:
        *args: Values accepted by ``int`` (e.g. numeric strings or numbers).

    Returns:
        tuple[int, ...]: The converted values; an empty tuple when no argument is given.
    """
    return tuple(map(int, args))
def as_torch_dtype(arg: str) -> dtype:
    """Resolve YAML string into PyTorch dtype.

    The name after the ``torch.`` prefix is looked up as an attribute of the
    ``torch`` module and accepted only if it is a ``torch.dtype`` instance.
    Looking the name up lazily (instead of building a fixed table on every
    call) means a dtype missing from the installed PyTorch build only affects
    requests for that dtype, and dtypes that are not part of a hard-coded list
    are resolved as well.

    Example:
        YAML file example::

        ```yaml
        transforms:
          - class_path: torchvision.transforms.v2.ToDtype
            init_args:
                dtype: ${as_torch_dtype:torch.float32}
                scale: True
        ```

    Args:
        arg (str): Fully qualified dtype name such as ``"torch.float32"``.

    Returns:
        dtype: The matching ``torch.dtype`` object.

    Raises:
        ValueError: If ``arg`` does not start with the ``torch.`` prefix.
        KeyError: If the name after the prefix is not a ``torch.dtype``.
    """
    import torch

    prefix = "torch."
    if not arg.startswith(prefix):
        msg = f"arg={arg} should start with the `torch.` prefix"
        raise ValueError(msg)

    key = arg[len(prefix) :]
    resolved = getattr(torch, key, None)
    # Reject attributes that exist on ``torch`` but are not dtypes (modules,
    # functions, ...). KeyError is kept as the "unknown dtype" exception type.
    if not isinstance(resolved, torch.dtype):
        msg = f"arg={arg} does not name a `torch.dtype`"
        raise KeyError(msg)
    return resolved
def dtype_representer(dumper: yaml.Dumper | yaml.representer.SafeRepresenter, data: dtype) -> yaml.ScalarNode:
    """Represent a dtype object as a YAML string holding an ``as_torch_dtype`` interpolation.

    The emitted text has the form ``${as_torch_dtype:<str(data)>}``.

    Args:
        dumper (yaml.Dumper | yaml.representer.SafeRepresenter): The YAML dumper or safe representer.
        data (dtype): The dtype object to be converted.

    Returns:
        yaml.ScalarNode: The node produced by ``dumper.represent_str`` for the interpolation text.
    """
    interpolation = f"${{as_torch_dtype:{data!s}}}"
    return dumper.represent_str(interpolation)
def any_representer(dumper: yaml.Dumper | yaml.representer.SafeRepresenter, data: Any) -> yaml.ScalarNode:  # noqa: ANN401
    """Represent arbitrary data by delegating to the dumper's ``represent_none``.

    Args:
        dumper (yaml.Dumper | yaml.representer.SafeRepresenter): The YAML dumper or safe representer.
        data (Any): The data to be represented; it is forwarded to ``represent_none`` unchanged.

    Returns:
        yaml.ScalarNode: The node returned by ``dumper.represent_none``.
    """
    null_node = dumper.represent_none(data)
    return null_node
def ignore_aliases(self: yaml.representer.SafeRepresenter, data: Any) -> bool | None:  # noqa: ARG001, ANN401
    """Determine whether to ignore aliases in YAML representation.

    ``register_configs`` installs this function as ``yaml.SafeDumper.ignore_aliases``,
    which is why it takes ``self`` even though the value is not used.

    Args:
        self: The representer the function is bound to (unused).
        data: The data to check.

    Returns:
        bool | None: True if ``data`` is ``None``, an empty tuple, or an instance of
        ``str``, ``bytes``, ``bool``, ``int``, ``float`` or ``torch.dtype``;
        None otherwise.
    """
    # Imported here so that importing this module does not itself require torch.
    from torch import dtype

    if data is None:
        return True
    if isinstance(data, tuple) and data == ():
        return True
    if isinstance(data, (str, bytes, bool, int, float, dtype)):
        return True
    return None
def otx_str_type_representer(
    dumper: yaml.Dumper | yaml.representer.SafeRepresenter,
    data: OTXTaskType | OTXTrainType,
) -> yaml.ScalarNode:
    """Represent an OTXTaskType or OTXTrainType member as a plain YAML string.

    The member's ``value`` is converted with ``str`` and emitted as a scalar
    carrying the standard YAML string tag.

    Args:
        dumper (yaml.Dumper | yaml.representer.SafeRepresenter): The YAML dumper or safe representer object.
        data (OTXTaskType | OTXTrainType): The OTXTaskType or OTXTrainType object to be represented.

    Returns:
        yaml.ScalarNode: The node returned by ``dumper.represent_scalar``.
    """
    str_tag = "tag:yaml.org,2002:str"
    text = str(data.value)
    return dumper.represent_scalar(str_tag, text)
def register_configs() -> None:
    """Register custom OmegaConf resolvers and YAML representers.

    The following registrations are performed:

    - OmegaConf resolvers ``as_int_tuple`` and ``as_torch_dtype`` (replacing any
      resolver already registered under those names).
    - A ``torch.dtype`` representer on the default ``yaml`` dumper and on ``yaml.SafeDumper``.
    - A representer for ``typing.Any`` on ``yaml.SafeDumper``.
    - ``yaml.SafeDumper.ignore_aliases`` is replaced by ``ignore_aliases``.
    - String representers for ``OTXTaskType`` and ``OTXTrainType`` on ``yaml.SafeDumper``.
    """
    from omegaconf import OmegaConf

    OmegaConf.register_new_resolver("as_int_tuple", as_int_tuple, replace=True)
    OmegaConf.register_new_resolver("as_torch_dtype", as_torch_dtype, replace=True)

    from torch import dtype

    yaml.add_representer(dtype, dtype_representer)  # For lightning_logs
    # For jsonargparse's SafeDumper
    yaml.SafeDumper.add_representer(dtype, dtype_representer)
    yaml.SafeDumper.add_representer(_SpecialForm, any_representer)  # typing.Any for DictConfig
    # Since Python 3.11 ``typing.Any`` is a class with its own metaclass instead of
    # a ``_SpecialForm`` instance, so the registration above no longer matches it.
    # On older interpreters ``type(Any)`` is ``_SpecialForm`` and this is a no-op.
    yaml.SafeDumper.add_representer(type(Any), any_representer)
    yaml.SafeDumper.ignore_aliases = ignore_aliases  # type: ignore  # noqa: PGH003
    yaml.SafeDumper.add_representer(OTXTaskType, otx_str_type_representer)
    yaml.SafeDumper.add_representer(OTXTrainType, otx_str_type_representer)
# Executed at import time so the resolvers and representers are registered
# as soon as this module is imported.
register_configs()