# Source code for otx.core.data.manager.dataset_manager
"""Datumaro Helper."""
# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
# pylint: disable=invalid-name
import os
from typing import List, Optional, Tuple, Union
import datumaro
from datumaro.components.dataset import Dataset, DatasetSubset
from datumaro.components.dataset_base import DatasetItem
from datumaro.plugins.splitter import Split
# [docs]
class DatasetManager:
    """Thin convenience wrapper around Datumaro dataset functions.

    Every method only forwards to Datumaro (or inspects the file system), so
    no state is kept and all methods are implemented as static methods.
    """

    @staticmethod
    def get_train_dataset(dataset: "Dataset") -> "DatasetSubset":
        """Return the training subset of ``dataset``.

        A subset named exactly ``"train"`` is preferred. Otherwise the first
        subset whose name contains ``"train"`` or ``"default"`` is returned.

        Args:
            dataset: Datumaro dataset to look the subset up in.

        Returns:
            The subset used for training.

        Raises:
            ValueError: If no subset name matches.
        """
        subsets = dataset.subsets()
        train_dataset = subsets.get("train", None)
        if train_dataset is not None:
            return train_dataset
        # Fall back to loosely matching names such as "train_0" or "default".
        for name, subset in subsets.items():
            if "train" in name or "default" in name:
                return subset
        raise ValueError("Can't find training data.")

    @staticmethod
    def get_val_dataset(dataset: "Dataset") -> Union["DatasetSubset", None]:
        """Return the validation subset of ``dataset``.

        A subset named exactly ``"val"`` is preferred. Otherwise the first
        subset whose name contains ``"val"`` is returned.

        Args:
            dataset: Datumaro dataset to look the subset up in.

        Returns:
            The validation subset, or ``None`` when no subset name matches.
        """
        subsets = dataset.subsets()
        val_dataset = subsets.get("val", None)
        if val_dataset is not None:
            return val_dataset
        # Fall back to loosely matching names such as "validation".
        for name, subset in subsets.items():
            if "val" in name:
                return subset
        return None

    @staticmethod
    def get_data_format(data_root: str) -> str:
        """Find the format of the dataset stored under ``data_root``.

        Args:
            data_root: Path to the dataset root; it is converted to an
                absolute path before inspection.

        Returns:
            ``"multi-cvat"``, ``"mvtec"``, or a format name reported by
            Datumaro's detection (``"imagenet"`` wins whenever it is among the
            reported candidates, otherwise the first candidate is used).

        Note:
            The first candidate is taken by index, so an empty detection
            result would surface as ``IndexError`` (presumably possible when
            Datumaro recognizes nothing -- confirm against Datumaro).
        """
        data_root = os.path.abspath(data_root)
        data_format: str = ""

        # TODO #
        # Currently, below `if/else` statements is mandatory
        # because Datumaro can't detect the multi-cvat and mvtec.
        # After, the upgrade of Datumaro, below codes will be changed.
        if DatasetManager.is_cvat_format(data_root):
            data_format = "multi-cvat"
        elif DatasetManager.is_mvtec_format(data_root):
            data_format = "mvtec"
        else:
            data_formats = datumaro.Environment().detect_dataset(data_root)
            # TODO: how to avoid hard-coded part
            data_format = data_formats[0] if "imagenet" not in data_formats else "imagenet"
        print(f"[*] Detected dataset format: {data_format}")
        return data_format

    @staticmethod
    def get_image_path(data_item: "DatasetItem") -> Optional[str]:
        """Return the path of the media attached to ``data_item``.

        Args:
            data_item: Datumaro dataset item.

        Returns:
            ``data_item.media.path`` when the media object has a ``path``
            attribute, otherwise ``None``.
        """
        if hasattr(data_item.media, "path"):
            return data_item.media.path
        return None

    @staticmethod
    def export_dataset(dataset: "Dataset", output_dir: str, data_format: str, save_media: bool = True):
        """Export the Datumaro dataset.

        Args:
            dataset: Dataset to export.
            output_dir: Destination directory passed to ``Dataset.export``.
            data_format: Name of the target format.
            save_media: Forwarded to ``Dataset.export`` as ``save_media``.

        Returns:
            Whatever ``Dataset.export`` returns.
        """
        return dataset.export(output_dir, data_format, save_media=save_media)

    @staticmethod
    def import_dataset(data_root: str, data_format: str, subset: Optional[str] = None) -> "Dataset":
        """Import a dataset through ``Dataset.import_from``.

        Args:
            data_root: Path to the dataset root.
            data_format: Name of the format to import with.
            subset: Optional subset name forwarded to the importer.

        Returns:
            The result of ``Dataset.import_from``.
        """
        return Dataset.import_from(data_root, format=data_format, subset=subset)

    @staticmethod
    def auto_split(task: str, dataset: "Dataset", split_ratio: List[Tuple[str, float]]) -> dict:
        """Automatically split the dataset: train --> train/val.

        Args:
            task: Task name; it is lower-cased before being handed to
                Datumaro's ``Split``.
            dataset: Dataset to split.
            split_ratio: Pairs forwarded to ``Split`` unchanged, e.g.
                ``[("train", 0.8), ("val", 0.2)]``.

        Returns:
            The subsets produced by the splitter.
        """
        splitter = Split(dataset, task.lower(), split_ratio)
        return splitter.subsets()

    @staticmethod
    def is_cvat_format(path: str) -> bool:
        """Detect whether data path is multi-video CVAT format or not.

        Currently, the multi-video CVAT format is used for Action tasks.
        Every sub-folder of ``path`` must contain exactly ``images`` and
        ``annotations.xml``; plain files directly under ``path`` are ignored.

        Multi-video CVAT format::

            root
            |--video_0
                |--images
                    |--frame0001.png
                |--annotations.xml
            |--video_1
            |--video_2

        will be deprecated soon.

        Args:
            path: Dataset root directory.

        Returns:
            ``True`` when at least one sub-folder exists and every sub-folder
            matches the layout above, ``False`` otherwise.
        """
        cvat_format = sorted(["images", "annotations.xml"])
        found_video_dir = False
        for sub_folder in os.listdir(path):
            # video_0, video_1, ...
            sub_folder_path = os.path.join(path, sub_folder)
            if not os.path.isdir(sub_folder_path):
                continue
            # files must be same with cvat_format
            if sorted(os.listdir(sub_folder_path)) != cvat_format:
                return False
            found_video_dir = True
        # A root without any video folder is not a multi-video CVAT dataset;
        # answering True here would shadow every other format detection.
        return found_video_dir

    @staticmethod
    def is_mvtec_format(path: str) -> bool:
        """Detect whether data path is MVTec format or not.

        Only the first-level folders are checked; plain files directly under
        ``path`` are ignored.

        MVTec default structure::

            root
            |--ground_truth
            |--train
            |--test

        will be deprecated soon.

        Args:
            path: Dataset root directory.

        Returns:
            ``True`` when the first-level folders are exactly
            ``ground_truth``, ``train`` and ``test``.
        """
        mvtec_format = sorted(["ground_truth", "train", "test"])
        # only use the folder name.
        folder_list = [
            sub_folder for sub_folder in os.listdir(path) if os.path.isdir(os.path.join(path, sub_folder))
        ]
        return sorted(folder_list) == mvtec_format