# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
"""LightningDataModule extension for OTX."""
from __future__ import annotations
import logging as log
from typing import TYPE_CHECKING, Literal
from datumaro import Dataset as DmDataset
from lightning import LightningDataModule
from omegaconf import DictConfig, OmegaConf
from torch.utils.data import DataLoader, RandomSampler
from otx.core.config.data import TileConfig, VisualPromptingConfig
from otx.core.data.dataset.tile import OTXTileDatasetFactory
from otx.core.data.factory import OTXDatasetFactory
from otx.core.data.mem_cache import (
MemCacheHandlerSingleton,
parse_mem_cache_size_to_int,
)
from otx.core.data.pre_filtering import pre_filtering
from otx.core.data.utils import adapt_input_size_to_dataset, adapt_tile_config
from otx.core.types.device import DeviceType
from otx.core.types.image import ImageColorChannel
from otx.core.types.label import LabelInfo
from otx.core.types.task import OTXTaskType
from otx.core.utils.instantiators import instantiate_sampler
from otx.core.utils.utils import get_adaptive_num_workers
if TYPE_CHECKING:
from lightning.pytorch.utilities.parsing import AttributeDict
from otx.core.config.data import SubsetConfig
from otx.core.data.dataset.base import OTXDataset
[docs]
class OTXDataModule(LightningDataModule):
"""LightningDataModule extension for OTX pipeline.
Args:
input_size (int | tuple[int, int] | None, optional):
Final image or video shape of data after data transformation. It'll be applied to all subset configs
If it's not None. Defaults to None.
adaptive_input_size (Literal["auto", "downscale"] | None, optional):
The adaptive input size mode. If it's set, appropriate input size is found by analyzing dataset.
"auto" can find both bigger and smaller input size than current input size and "downscale" uses only
smaller size than default setting. Defaults to None.
input_size_multiplier (int, optional):
adaptive_input_size will finds multiple of input_size_multiplier value if it's set. It's usefull when
a model requries multiple of specific value as input_size. Defaults to 1.
"""
def __init__( # noqa: PLR0913
self,
task: OTXTaskType,
data_format: str,
data_root: str,
train_subset: SubsetConfig,
val_subset: SubsetConfig,
test_subset: SubsetConfig,
tile_config: TileConfig = TileConfig(enable_tiler=False),
vpm_config: VisualPromptingConfig = VisualPromptingConfig(), # noqa: B008
mem_cache_size: str = "1GB",
mem_cache_img_max_size: tuple[int, int] | None = None,
image_color_channel: ImageColorChannel = ImageColorChannel.RGB,
stack_images: bool = True,
include_polygons: bool = False,
ignore_index: int = 255,
unannotated_items_ratio: float = 0.0,
auto_num_workers: bool = False,
device: DeviceType = DeviceType.auto,
input_size: int | tuple[int, int] | None = None,
adaptive_input_size: Literal["auto", "downscale"] | None = None,
input_size_multiplier: int = 1,
) -> None:
"""Constructor."""
super().__init__()
self.task = task
self.data_format = data_format
self.data_root = data_root
self.train_subset = train_subset
self.val_subset = val_subset
self.test_subset = test_subset
self.tile_config = tile_config
self.vpm_config = vpm_config
self.mem_cache_size = mem_cache_size
self.mem_cache_img_max_size = mem_cache_img_max_size
self.image_color_channel = image_color_channel
self.stack_images = stack_images
self.include_polygons = include_polygons
self.ignore_index = ignore_index
self.unannotated_items_ratio = unannotated_items_ratio
self.auto_num_workers = auto_num_workers
self.device = device
self.subsets: dict[str, OTXDataset] = {}
self.save_hyperparameters(ignore=["input_size"])
dataset = DmDataset.import_from(self.data_root, format=self.data_format)
if self.task != OTXTaskType.H_LABEL_CLS and not (
self.task == OTXTaskType.KEYPOINT_DETECTION and self.data_format == "arrow"
):
dataset = pre_filtering(
dataset,
self.data_format,
self.unannotated_items_ratio,
ignore_index=self.ignore_index if self.task == "SEMANTIC_SEGMENTATION" else None,
)
if adaptive_input_size is not None:
input_size = adapt_input_size_to_dataset(
dataset,
self.task,
input_size,
adaptive_input_size == "downscale",
input_size_multiplier,
)
if input_size is not None:
for subset_cfg in [train_subset, val_subset, test_subset]:
if subset_cfg.input_size is None:
subset_cfg.input_size = input_size
self.input_size = input_size
if self.tile_config.enable_tiler and self.tile_config.enable_adaptive_tiling:
adapt_tile_config(self.tile_config, dataset=dataset, task=self.task)
config_mapping = {
self.train_subset.subset_name: self.train_subset,
self.val_subset.subset_name: self.val_subset,
self.test_subset.subset_name: self.test_subset,
}
if self.auto_num_workers:
if self.device not in [DeviceType.gpu, DeviceType.auto]:
log.warning(
"Only GPU device type support auto_num_workers. "
f"Current deveice type is {self.device!s}. auto_num_workers is skipped.",
)
elif (num_workers := get_adaptive_num_workers()) is not None:
for subset_name, subset_config in config_mapping.items():
log.info(
f"num_workers of {subset_name} subset is changed : "
f"{subset_config.num_workers} -> {num_workers}",
)
subset_config.num_workers = num_workers
mem_size = parse_mem_cache_size_to_int(mem_cache_size)
mem_cache_mode = (
"singleprocessing"
if all(config.num_workers == 0 for config in config_mapping.values())
else "multiprocessing"
)
mem_cache_handler = MemCacheHandlerSingleton.create(
mode=mem_cache_mode,
mem_size=mem_size,
)
label_infos: list[LabelInfo] = []
for name, dm_subset in dataset.subsets().items():
if name not in config_mapping:
log.warning(f"{name} is not available. Skip it")
continue
dataset = OTXDatasetFactory.create(
task=self.task,
dm_subset=dm_subset.as_dataset(),
cfg_subset=config_mapping[name],
mem_cache_handler=mem_cache_handler,
data_format=self.data_format,
mem_cache_img_max_size=mem_cache_img_max_size,
image_color_channel=image_color_channel,
stack_images=stack_images,
include_polygons=include_polygons,
ignore_index=ignore_index,
vpm_config=vpm_config,
)
if self.tile_config.enable_tiler:
dataset = OTXTileDatasetFactory.create(
task=self.task,
dataset=dataset,
tile_config=self.tile_config,
)
self.subsets[name] = dataset
label_infos += [self.subsets[name].label_info]
log.info(f"Add name: {name}, self.subsets: {self.subsets}")
if self._is_meta_info_valid(label_infos) is False:
msg = "All data meta infos of subsets should be the same."
raise ValueError(msg)
self.label_info = next(iter(label_infos))
def _is_meta_info_valid(self, label_infos: list[LabelInfo]) -> bool:
"""Check whether there are mismatches in the metainfo for the all subsets."""
return bool(all(label_info == label_infos[0] for label_info in label_infos))
def _get_dataset(self, subset: str) -> OTXDataset:
if (dataset := self.subsets.get(subset)) is None:
msg = f"Dataset has no '{subset}'. Available subsets = {list(self.subsets.keys())}"
raise KeyError(msg)
return dataset
[docs]
def train_dataloader(self) -> DataLoader:
"""Get train dataloader."""
config = self.train_subset
dataset = self._get_dataset(config.subset_name)
sampler = instantiate_sampler(config.sampler, dataset=dataset, batch_size=config.batch_size)
common_args = {
"dataset": dataset,
"batch_size": config.batch_size,
"num_workers": config.num_workers,
"pin_memory": True,
"collate_fn": dataset.collate_fn,
"persistent_workers": config.num_workers > 0,
"sampler": sampler,
"shuffle": sampler is None,
}
tile_config = self.tile_config
if tile_config.enable_tiler and tile_config.sampling_ratio < 1:
num_samples = max(1, int(len(dataset) * tile_config.sampling_ratio))
log.info(f"Using tiled sampling with {num_samples} samples")
common_args.update(
{
"shuffle": False,
"sampler": RandomSampler(dataset, num_samples=num_samples),
},
)
return DataLoader(**common_args)
[docs]
def val_dataloader(self) -> DataLoader:
"""Get val dataloader."""
config = self.val_subset
dataset = self._get_dataset(config.subset_name)
return DataLoader(
dataset=dataset,
batch_size=config.batch_size,
shuffle=False,
num_workers=config.num_workers,
pin_memory=True,
collate_fn=dataset.collate_fn,
persistent_workers=config.num_workers > 0,
)
[docs]
def test_dataloader(self) -> DataLoader:
"""Get test dataloader."""
config = self.test_subset
dataset = self._get_dataset(config.subset_name)
return DataLoader(
dataset=dataset,
batch_size=config.batch_size,
shuffle=False,
num_workers=config.num_workers,
pin_memory=True,
collate_fn=dataset.collate_fn,
persistent_workers=config.num_workers > 0,
)
[docs]
def predict_dataloader(self) -> DataLoader:
"""Get test dataloader."""
config = self.test_subset
dataset = self._get_dataset(config.subset_name)
return DataLoader(
dataset=dataset,
batch_size=config.batch_size,
shuffle=False,
num_workers=config.num_workers,
pin_memory=True,
collate_fn=dataset.collate_fn,
persistent_workers=config.num_workers > 0,
)
[docs]
def setup(self, stage: str) -> None:
"""Setup for each stage."""
[docs]
def teardown(self, stage: str) -> None:
"""Teardown for each stage."""
# clean up after fit or test
# called on every process in DDP
@property
def hparams_initial(self) -> AttributeDict:
"""The collection of hyperparameters saved with `save_hyperparameters()`. It is read-only.
The reason why we override is that we have some custom resolvers for `DictConfig`.
Some resolved Python objects has not a primitive type, so that is not loggable without errors.
Therefore, we need to unresolve it this time.
"""
hp = super().hparams_initial
for key, value in hp.items():
if isinstance(value, DictConfig):
# It should be unresolved to make it loggable
hp[key] = OmegaConf.to_container(value, resolve=False)
return hp
def __reduce__(self):
"""Re-initialize object when unpickled."""
return (
self.__class__,
(
self.task,
self.data_format,
self.data_root,
self.train_subset,
self.val_subset,
self.test_subset,
self.tile_config,
self.vpm_config,
self.mem_cache_size,
self.mem_cache_img_max_size,
self.image_color_channel,
self.stack_images,
self.include_polygons,
self.ignore_index,
self.unannotated_items_ratio,
self.auto_num_workers,
self.device,
self.input_size,
),
)