# Source code for otx.core.config.data

# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
"""Config data type objects for data."""
# NOTE: omegaconf would fail to parse dataclass with `from __future__ import annotations` in Python 3.8, 3.9
# ruff: noqa: FA100

from __future__ import annotations

from copy import deepcopy
from dataclasses import dataclass, field
from typing import Any

from otx.core.types.transformer_libs import TransformLibType


@dataclass
class SubsetConfig:
    """DTO for dataset subset configuration.

    Attributes:
        batch_size (int): Batch size produced.
        subset_name (str): Datumaro Dataset's subset name for this subset config.
            It can differ from the actual usage (e.g., 'val' for the validation subset config).
        transforms (list[dict[str, Any] | Transform] | Compose): List of actually used transforms.
            It accepts a list of `torchvision.transforms.v2.*` Python objects
            or `torchvision.transforms.v2.Compose` for `TransformLibType.TORCHVISION`.
            Otherwise, it takes a Python dictionary that fits the configuration style used in mmcv
            (`TransformLibType.MMCV`, `TransformLibType.MMPRETRAIN`, ...).
        transform_lib_type (TransformLibType): Transform library type used by this subset.
            Defaults to `TransformLibType.TORCHVISION`.
        num_workers (int): Number of workers for the dataloader of this subset. Defaults to 2.
        sampler (SamplerConfig): Sampler configuration for this subset.
            Defaults to a fresh `SamplerConfig()` for every instance.
        to_tv_image (bool): Boolean option for this subset, defaults to True.
            (Presumably controls conversion of samples to torchvision image types — TODO confirm.)
        input_size (int | tuple[int, int] | None): Input size the model expects.
            If `$(input_size)` exists in `transforms`, it will be replaced with this value.

    Example:
        ```python
        train_subset_config = SubsetConfig(
            batch_size=64,
            subset_name="train",
            transforms=v2.Compose(
                [
                    v2.RandomResizedCrop(size=(224, 224), antialias=True),
                    v2.RandomHorizontalFlip(p=0.5),
                    v2.ToDtype(torch.float32, scale=True),
                    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
                ],
            ),
            transform_lib_type=TransformLibType.TORCHVISION,
            num_workers=2,
        )
        ```
    """

    batch_size: int
    subset_name: str

    # TODO (vinnamki): Revisit data configuration objects to support a union type in structured config
    # Omegaconf does not allow to have a union type, https://github.com/omry/omegaconf/issues/144
    transforms: list[dict[str, Any]]

    transform_lib_type: TransformLibType = TransformLibType.TORCHVISION
    num_workers: int = 2
    # The lambda is intentional: `SamplerConfig` is defined further down in this module, so
    # `default_factory=SamplerConfig` would raise NameError when this class body executes.
    # The lambda defers the name lookup until an instance is actually created.
    sampler: SamplerConfig = field(default_factory=lambda: SamplerConfig())
    to_tv_image: bool = True
    # type is `int | tuple[int, int] | None` TODO (eunwoosh): Revisit after error above is solved
    input_size: Any = None
@dataclass
class TileConfig:
    """Data-transfer object that bundles the tiling options.

    Field meanings below are inferred from their names; the actual semantics are
    defined by whichever components consume this configuration.

    Attributes:
        enable_tiler (bool): Tiler on/off flag. Defaults to False.
        enable_adaptive_tiling (bool): Adaptive-tiling flag. Defaults to True.
        tile_size (tuple[int, int]): Tile size. Defaults to (400, 400).
        overlap (float): Tile overlap setting. Defaults to 0.2.
        iou_threshold (float): IoU threshold. Defaults to 0.45.
        max_num_instances (int): Instance-count limit. Defaults to 1500.
        object_tile_ratio (float): Object-to-tile ratio. Defaults to 0.03.
        sampling_ratio (float): Sampling ratio. Defaults to 1.0.
        with_full_img (bool): Full-image flag. Defaults to False.
    """

    # Switches.
    enable_tiler: bool = False
    enable_adaptive_tiling: bool = True

    # Tile geometry.
    tile_size: tuple[int, int] = (400, 400)
    overlap: float = 0.2

    # Remaining numeric knobs.
    iou_threshold: float = 0.45
    max_num_instances: int = 1500
    object_tile_ratio: float = 0.03
    sampling_ratio: float = 1.0

    with_full_img: bool = False

    def clone(self) -> TileConfig:
        """Build an independent copy of this configuration.

        Returns:
            TileConfig: A new instance created with `copy.deepcopy`, so it shares
            no mutable state with `self`.
        """
        duplicate = deepcopy(self)
        return duplicate
@dataclass
class VisualPromptingConfig:
    """Options for the visual prompting data module.

    Attributes:
        use_bbox (bool): Bounding-box prompt flag. Defaults to False.
        use_point (bool): Point prompt flag. Defaults to False.
    """

    # Both prompt kinds are disabled unless explicitly requested.
    use_bbox: bool = False
    use_point: bool = False
@dataclass
class UnlabeledDataConfig(SubsetConfig):
    """Subset configuration for unlabeled data.

    Extends `SubsetConfig` with data-location fields. It also gives every
    parent field that was required (`batch_size`, `subset_name`, `transforms`)
    a default, so an instance can be built without any arguments.

    Attributes:
        data_root (str | None): Root location of the unlabeled data. Defaults to None.
        data_format (str): Format of the unlabeled data. Defaults to "image_dir".
        batch_size (int): Overrides the parent field with a default of 0.
        subset_name (str): Overrides the parent field with a default of "unlabeled".
        transforms (dict[str, list[dict[str, Any]]]): Named groups of transform
            configurations (a dict, unlike the parent's plain list). Defaults to a new
            empty dict per instance.
        transform_lib_type, num_workers, to_tv_image: Re-declared with the same
            defaults as in `SubsetConfig`.
    """

    # New fields introduced by this subclass (kept in this order: it fixes their
    # positional order in the generated __init__).
    data_root: str | None = None
    data_format: str = "image_dir"

    # Parent fields re-declared with defaults.
    batch_size: int = 0
    subset_name: str = "unlabeled"
    # TODO (harimkang): If not multi-transform, support for list type, as should support for other subsets.
    transforms: dict[str, list[dict[str, Any]]] = field(default_factory=dict)  # type: ignore[assignment]
    transform_lib_type: TransformLibType = TransformLibType.TORCHVISION
    num_workers: int = 2
    to_tv_image: bool = True
@dataclass
class SamplerConfig:
    """Sampler settings used when building a dataloader.

    The sampler is described by an import path (`class_path`) plus keyword
    arguments (`init_args`). The sampler itself is instantiated when the
    dataloader is created.

    TODO: Replace this with a proper Sampler class. `SamplerConfig` lives inside
    `SubsetConfig` (a dataclass nested in a dataclass), which makes it hard to
    instantiate from the CLI. For now, this lightweight dataclass mimics the
    `class_path` / `init_args` configuration style used for other objects and
    offers only limited functionality.

    Attributes:
        class_path (str): Import path of the sampler class.
            Defaults to "torch.utils.data.RandomSampler".
        init_args (dict[str, Any]): Keyword arguments for the sampler.
            Defaults to a new empty dict per instance.
    """

    class_path: str = "torch.utils.data.RandomSampler"
    # `default_factory` ensures each instance owns a distinct dict.
    init_args: dict[str, Any] = field(default_factory=dict)