# Source code for otx.core.data.module

# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
"""LightningDataModule extension for OTX."""

from __future__ import annotations

import logging as log
from typing import TYPE_CHECKING, Literal

from datumaro import Dataset as DmDataset
from lightning import LightningDataModule
from omegaconf import DictConfig, OmegaConf
from torch.utils.data import DataLoader, RandomSampler

from otx.core.config.data import TileConfig, VisualPromptingConfig
from otx.core.data.dataset.tile import OTXTileDatasetFactory
from otx.core.data.factory import OTXDatasetFactory
from otx.core.data.mem_cache import (
    MemCacheHandlerSingleton,
    parse_mem_cache_size_to_int,
)
from otx.core.data.pre_filtering import pre_filtering
from otx.core.data.utils import adapt_input_size_to_dataset, adapt_tile_config
from otx.core.types.device import DeviceType
from otx.core.types.image import ImageColorChannel
from otx.core.types.label import LabelInfo
from otx.core.types.task import OTXTaskType
from otx.core.utils.instantiators import instantiate_sampler
from otx.core.utils.utils import get_adaptive_num_workers

if TYPE_CHECKING:
    from lightning.pytorch.utilities.parsing import AttributeDict

    from otx.core.config.data import SubsetConfig
    from otx.core.data.dataset.base import OTXDataset


class OTXDataModule(LightningDataModule):
    """LightningDataModule extension for OTX pipeline.

    The constructor imports the Datumaro dataset found at ``data_root``,
    optionally pre-filters it, adapts the input size / tile configuration /
    number of workers when requested, and finally builds one OTX dataset per
    configured subset. The ``*_dataloader`` methods then only wrap those
    pre-built datasets into ``DataLoader`` instances.

    Args:
        task (OTXTaskType): Task the datasets are created for.
        data_format (str): Dataset format name handed to Datumaro's importer.
        data_root (str): Location the dataset is imported from.
        train_subset (SubsetConfig): Configuration of the training subset.
        val_subset (SubsetConfig): Configuration of the validation subset.
        test_subset (SubsetConfig): Configuration of the test subset. It is
            also used by ``predict_dataloader``.
        tile_config (TileConfig, optional): Tiling configuration. Tiling is
            disabled by default.
        vpm_config (VisualPromptingConfig, optional): Visual prompting
            configuration forwarded to the dataset factory.
        mem_cache_size (str, optional): Memory cache size, parsed by
            ``parse_mem_cache_size_to_int``. Defaults to "1GB".
        mem_cache_img_max_size (tuple[int, int] | None, optional): Forwarded
            to the dataset factory. Defaults to None.
        image_color_channel (ImageColorChannel, optional): Forwarded to the
            dataset factory. Defaults to ``ImageColorChannel.RGB``.
        stack_images (bool, optional): Forwarded to the dataset factory.
            Defaults to True.
        include_polygons (bool, optional): Forwarded to the dataset factory.
            Defaults to False.
        ignore_index (int, optional): Forwarded to the dataset factory, and to
            the pre-filtering step for the semantic segmentation task only.
            Defaults to 255.
        unannotated_items_ratio (float, optional): Forwarded to the
            pre-filtering step. Defaults to 0.0.
        auto_num_workers (bool, optional): If True, ``num_workers`` of every
            subset config is replaced by the value returned from
            ``get_adaptive_num_workers`` (GPU / auto device types only).
            Defaults to False.
        device (DeviceType, optional): Device type, only consulted for
            ``auto_num_workers``. Defaults to ``DeviceType.auto``.
        input_size (int | tuple[int, int] | None, optional):
            Final image or video shape of data after data transformation. It is applied to every
            subset config whose own ``input_size`` is None, if it's not None. Defaults to None.
        adaptive_input_size (Literal["auto", "downscale"] | None, optional):
            The adaptive input size mode. If it's set, an appropriate input size is found by analyzing
            the dataset. "auto" can find both bigger and smaller input sizes than the current input size
            and "downscale" uses only sizes smaller than the default setting. Defaults to None.
        input_size_multiplier (int, optional):
            adaptive_input_size will find a multiple of the input_size_multiplier value if it's set.
            It's useful when a model requires a multiple of a specific value as input_size.
            Defaults to 1.

    Raises:
        ValueError: If none of the dataset's subsets matches a configured
            subset name, or if the label infos of the created subsets differ.
    """

    def __init__(  # noqa: PLR0913
        self,
        task: OTXTaskType,
        data_format: str,
        data_root: str,
        train_subset: SubsetConfig,
        val_subset: SubsetConfig,
        test_subset: SubsetConfig,
        tile_config: TileConfig = TileConfig(enable_tiler=False),  # noqa: B008
        vpm_config: VisualPromptingConfig = VisualPromptingConfig(),  # noqa: B008
        mem_cache_size: str = "1GB",
        mem_cache_img_max_size: tuple[int, int] | None = None,
        image_color_channel: ImageColorChannel = ImageColorChannel.RGB,
        stack_images: bool = True,
        include_polygons: bool = False,
        ignore_index: int = 255,
        unannotated_items_ratio: float = 0.0,
        auto_num_workers: bool = False,
        device: DeviceType = DeviceType.auto,
        input_size: int | tuple[int, int] | None = None,
        adaptive_input_size: Literal["auto", "downscale"] | None = None,
        input_size_multiplier: int = 1,
    ) -> None:
        """Constructor."""
        super().__init__()
        self.task = task
        self.data_format = data_format
        self.data_root = data_root

        self.train_subset = train_subset
        self.val_subset = val_subset
        self.test_subset = test_subset

        self.tile_config = tile_config
        self.vpm_config = vpm_config

        self.mem_cache_size = mem_cache_size
        self.mem_cache_img_max_size = mem_cache_img_max_size

        self.image_color_channel = image_color_channel
        self.stack_images = stack_images
        self.include_polygons = include_polygons
        self.ignore_index = ignore_index
        self.unannotated_items_ratio = unannotated_items_ratio

        self.auto_num_workers = auto_num_workers
        self.device = device

        self.subsets: dict[str, OTXDataset] = {}
        # Must be called directly from __init__ and before `input_size` is
        # re-bound below; `input_size` itself is excluded from the saved
        # hyperparameters.
        self.save_hyperparameters(ignore=["input_size"])

        dataset = DmDataset.import_from(self.data_root, format=self.data_format)
        # Pre-filtering is skipped for hierarchical-label classification and
        # for keypoint detection on "arrow" formatted data.
        if self.task != OTXTaskType.H_LABEL_CLS and not (
            self.task == OTXTaskType.KEYPOINT_DETECTION and self.data_format == "arrow"
        ):
            dataset = pre_filtering(
                dataset,
                self.data_format,
                self.unannotated_items_ratio,
                ignore_index=self.ignore_index if self.task == "SEMANTIC_SEGMENTATION" else None,
            )

        if adaptive_input_size is not None:
            input_size = adapt_input_size_to_dataset(
                dataset,
                self.task,
                input_size,
                adaptive_input_size == "downscale",
                input_size_multiplier,
            )
        if input_size is not None:
            # Only fill in subset configs that did not set their own size.
            for subset_cfg in [train_subset, val_subset, test_subset]:
                if subset_cfg.input_size is None:
                    subset_cfg.input_size = input_size
        # Always set (possibly to None): __reduce__ reads this attribute.
        self.input_size = input_size

        if self.tile_config.enable_tiler and self.tile_config.enable_adaptive_tiling:
            adapt_tile_config(self.tile_config, dataset=dataset, task=self.task)

        config_mapping = {
            self.train_subset.subset_name: self.train_subset,
            self.val_subset.subset_name: self.val_subset,
            self.test_subset.subset_name: self.test_subset,
        }

        if self.auto_num_workers:
            if self.device not in [DeviceType.gpu, DeviceType.auto]:
                log.warning(
                    "Only GPU device type support auto_num_workers. "
                    f"Current deveice type is {self.device!s}. auto_num_workers is skipped.",
                )
            elif (num_workers := get_adaptive_num_workers()) is not None:
                for subset_name, subset_config in config_mapping.items():
                    log.info(
                        f"num_workers of {subset_name} subset is changed : "
                        f"{subset_config.num_workers} -> {num_workers}",
                    )
                    subset_config.num_workers = num_workers

        mem_size = parse_mem_cache_size_to_int(mem_cache_size)
        # The cache runs in single-process mode only when no subset uses
        # dataloader worker processes.
        mem_cache_mode = (
            "singleprocessing"
            if all(config.num_workers == 0 for config in config_mapping.values())
            else "multiprocessing"
        )
        mem_cache_handler = MemCacheHandlerSingleton.create(
            mode=mem_cache_mode,
            mem_size=mem_size,
        )

        label_infos: list[LabelInfo] = []
        for name, dm_subset in dataset.subsets().items():
            if name not in config_mapping:
                log.warning(f"{name} is not available. Skip it")
                continue

            # A distinct local name keeps the Datumaro `dataset` above from
            # being shadowed by the per-subset OTX dataset.
            subset_dataset = OTXDatasetFactory.create(
                task=self.task,
                dm_subset=dm_subset.as_dataset(),
                cfg_subset=config_mapping[name],
                mem_cache_handler=mem_cache_handler,
                data_format=self.data_format,
                mem_cache_img_max_size=mem_cache_img_max_size,
                image_color_channel=image_color_channel,
                stack_images=stack_images,
                include_polygons=include_polygons,
                ignore_index=ignore_index,
                vpm_config=vpm_config,
            )

            if self.tile_config.enable_tiler:
                subset_dataset = OTXTileDatasetFactory.create(
                    task=self.task,
                    dataset=subset_dataset,
                    tile_config=self.tile_config,
                )
            self.subsets[name] = subset_dataset
            label_infos.append(subset_dataset.label_info)
            log.info(f"Add name: {name}, self.subsets: {self.subsets}")

        if not label_infos:
            # Without this guard the `next(iter(...))` below would surface as
            # a bare StopIteration with no hint about the cause.
            msg = (
                f"No subset of the dataset at '{self.data_root}' matches the configured "
                f"subset names: {list(config_mapping)}."
            )
            raise ValueError(msg)

        if not self._is_meta_info_valid(label_infos):
            msg = "All data meta infos of subsets should be the same."
            raise ValueError(msg)

        self.label_info = next(iter(label_infos))

    def _is_meta_info_valid(self, label_infos: list[LabelInfo]) -> bool:
        """Check whether there are mismatches in the metainfo for the all subsets.

        Returns True when every entry equals the first one (and therefore also
        for an empty list).
        """
        return all(label_info == label_infos[0] for label_info in label_infos)

    def _get_dataset(self, subset: str) -> OTXDataset:
        """Return the dataset registered under ``subset``.

        Raises:
            KeyError: If no dataset was created for ``subset``.
        """
        if (dataset := self.subsets.get(subset)) is None:
            msg = f"Dataset has no '{subset}'. Available subsets = {list(self.subsets.keys())}"
            raise KeyError(msg)
        return dataset

    def _build_eval_dataloader(self, config: SubsetConfig) -> DataLoader:
        """Build the non-shuffling dataloader shared by val / test / predict."""
        dataset = self._get_dataset(config.subset_name)

        return DataLoader(
            dataset=dataset,
            batch_size=config.batch_size,
            shuffle=False,
            num_workers=config.num_workers,
            pin_memory=True,
            collate_fn=dataset.collate_fn,
            # Workers can only be kept alive if there are any.
            persistent_workers=config.num_workers > 0,
        )

    def train_dataloader(self) -> DataLoader:
        """Get train dataloader."""
        config = self.train_subset
        dataset = self._get_dataset(config.subset_name)
        sampler = instantiate_sampler(config.sampler, dataset=dataset, batch_size=config.batch_size)

        common_args = {
            "dataset": dataset,
            "batch_size": config.batch_size,
            "num_workers": config.num_workers,
            "pin_memory": True,
            "collate_fn": dataset.collate_fn,
            "persistent_workers": config.num_workers > 0,
            "sampler": sampler,
            # Shuffle only when no sampler was instantiated.
            "shuffle": sampler is None,
        }

        tile_config = self.tile_config
        if tile_config.enable_tiler and tile_config.sampling_ratio < 1:
            # Draw a random fraction of the dataset (at least one sample);
            # this replaces the sampler configured above.
            num_samples = max(1, int(len(dataset) * tile_config.sampling_ratio))
            log.info(f"Using tiled sampling with {num_samples} samples")
            common_args.update(
                {
                    "shuffle": False,
                    "sampler": RandomSampler(dataset, num_samples=num_samples),
                },
            )
        return DataLoader(**common_args)

    def val_dataloader(self) -> DataLoader:
        """Get val dataloader."""
        return self._build_eval_dataloader(self.val_subset)

    def test_dataloader(self) -> DataLoader:
        """Get test dataloader."""
        return self._build_eval_dataloader(self.test_subset)

    def predict_dataloader(self) -> DataLoader:
        """Get predict dataloader (built from the test subset)."""
        return self._build_eval_dataloader(self.test_subset)

    def setup(self, stage: str) -> None:
        """Setup for each stage."""

    def teardown(self, stage: str) -> None:
        """Teardown for each stage."""
        # clean up after fit or test
        # called on every process in DDP

    @property
    def hparams_initial(self) -> AttributeDict:
        """The collection of hyperparameters saved with `save_hyperparameters()`. It is read-only.

        The reason why we override is that we have some custom resolvers for `DictConfig`.
        Some resolved Python objects has not a primitive type, so that is not loggable without errors.
        Therefore, we need to unresolve it this time.
        """
        hp = super().hparams_initial
        for key, value in hp.items():
            if isinstance(value, DictConfig):
                # It should be unresolved to make it loggable
                hp[key] = OmegaConf.to_container(value, resolve=False)

        return hp

    def __reduce__(self) -> tuple[type, tuple]:
        """Re-initialize object when unpickled.

        The tuple lists the constructor arguments positionally, in the order
        of ``__init__``, up to and including ``input_size``.
        """
        return (
            self.__class__,
            (
                self.task,
                self.data_format,
                self.data_root,
                self.train_subset,
                self.val_subset,
                self.test_subset,
                self.tile_config,
                self.vpm_config,
                self.mem_cache_size,
                self.mem_cache_img_max_size,
                self.image_color_channel,
                self.stack_images,
                self.include_polygons,
                self.ignore_index,
                self.unannotated_items_ratio,
                self.auto_num_workers,
                self.device,
                self.input_size,
            ),
        )