"""OTX Core Data Adapter."""
# Copyright (C) 2022 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.
# pylint: disable=too-many-return-statements, too-many-arguments
import importlib
import os
from otx.algorithms.common.configs.training_base import TrainType
from otx.api.entities.model_template import TaskType
# Registry of dataset adapters, keyed first by task type and then by train type.
# Each leaf names the module (relative to ``otx.core.data.adapter``) and the adapter
# class inside it. The table below lists, per task, its
# (train type, module name, class name) variants.
ADAPTERS = {
    task: {
        train_key: {"module_name": module_name, "class": class_name}
        for train_key, module_name, class_name in variants
    }
    for task, variants in (
        (
            TaskType.CLASSIFICATION,
            (
                ("Incremental", "classification_dataset_adapter", "ClassificationDatasetAdapter"),
                ("Selfsupervised", "classification_dataset_adapter", "SelfSLClassificationDatasetAdapter"),
            ),
        ),
        # Detection, rotated detection and instance segmentation share one adapter class.
        (
            TaskType.DETECTION,
            (("Incremental", "detection_dataset_adapter", "DetectionDatasetAdapter"),),
        ),
        (
            TaskType.ROTATED_DETECTION,
            (("Incremental", "detection_dataset_adapter", "DetectionDatasetAdapter"),),
        ),
        (
            TaskType.INSTANCE_SEGMENTATION,
            (("Incremental", "detection_dataset_adapter", "DetectionDatasetAdapter"),),
        ),
        (
            TaskType.SEGMENTATION,
            (
                ("Incremental", "segmentation_dataset_adapter", "SegmentationDatasetAdapter"),
                ("Selfsupervised", "segmentation_dataset_adapter", "SelfSLSegmentationDatasetAdapter"),
            ),
        ),
        (
            TaskType.ANOMALY_CLASSIFICATION,
            (("Incremental", "anomaly_dataset_adapter", "AnomalyClassificationDatasetAdapter"),),
        ),
        (
            TaskType.ANOMALY_DETECTION,
            (("Incremental", "anomaly_dataset_adapter", "AnomalyDetectionDatasetAdapter"),),
        ),
        (
            TaskType.ANOMALY_SEGMENTATION,
            (("Incremental", "anomaly_dataset_adapter", "AnomalySegmentationDatasetAdapter"),),
        ),
    )
}
# Action tasks are opt-in: their adapters are registered only when the
# FEATURE_FLAGS_OTX_ACTION_TASKS environment variable is exactly "1".
if os.environ.get("FEATURE_FLAGS_OTX_ACTION_TASKS", "0") == "1":
    ADAPTERS.update(
        {
            action_task: {
                "Incremental": {
                    "module_name": "action_dataset_adapter",
                    "class": adapter_class,
                }
            }
            for action_task, adapter_class in (
                (TaskType.ACTION_CLASSIFICATION, "ActionClassificationDatasetAdapter"),
                (TaskType.ACTION_DETECTION, "ActionDetectionDatasetAdapter"),
            )
        }
    )
# TODO: update to real template
# Visual prompting is opt-in: its adapter is registered only when the
# FEATURE_FLAGS_OTX_VISUAL_PROMPTING_TASKS environment variable is exactly "1".
if os.environ.get("FEATURE_FLAGS_OTX_VISUAL_PROMPTING_TASKS", "0") == "1":
    ADAPTERS[TaskType.VISUAL_PROMPTING] = {
        "Incremental": {
            "module_name": "visual_prompting_dataset_adapter",
            "class": "VisualPromptingDatasetAdapter",
        }
    }
def get_dataset_adapter(
    task_type: TaskType,
    train_type: TrainType,
    train_data_roots: str = None,
    train_ann_files: str = None,
    val_data_roots: str = None,
    val_ann_files: str = None,
    test_data_roots: str = None,
    test_ann_files: str = None,
    unlabeled_data_roots: str = None,
    unlabeled_file_list: str = None,
    **kwargs,
):
    """Create the dataset adapter registered for a task type and train type.

    The adapter class is looked up in ``ADAPTERS``, its module is imported lazily from
    ``otx.core.data.adapter``, and the class is instantiated with the given data paths.

    Args:
        task_type: A task type registered in ``ADAPTERS``, such as ANOMALY_CLASSIFICATION,
            ANOMALY_DETECTION, ANOMALY_SEGMENTATION, CLASSIFICATION, DETECTION,
            INSTANCE_SEGMENTATION, ROTATED_DETECTION, SEGMENTATION. ACTION_CLASSIFICATION,
            ACTION_DETECTION and VISUAL_PROMPTING are registered only when their
            feature-flag environment variable is set to "1".
        train_type: Either a ``TrainType`` member or its string value. Selfsupervised
            selects the self-supervised adapter, which is registered for CLASSIFICATION
            and SEGMENTATION only; any other train type uses the Incremental adapter.
        train_data_roots: the path of data root for training data
        train_ann_files: the path of annotation file for training data
        val_data_roots: the path of data root for validation data
        val_ann_files: the path of annotation file for validation data
        test_data_roots: the path of data root for test data
        test_ann_files: the path of annotation file for test data
        unlabeled_data_roots: the path of data root for unlabeled data
        unlabeled_file_list: the path of unlabeled file list
        kwargs: extra keyword arguments forwarded unchanged to the adapter constructor

    Returns:
        An instance of the adapter class resolved from ``ADAPTERS``.

    Raises:
        KeyError: if no adapter is registered for the requested task type / train type.
    """
    # Accept the enum member as well as its raw value. If TrainType is a plain Enum
    # (its definition is not in this file), a member never compares equal to its value,
    # so comparing against ``.value`` alone would silently route
    # ``TrainType.Selfsupervised`` to the Incremental adapter.
    is_selfsupervised = train_type in (TrainType.Selfsupervised, TrainType.Selfsupervised.value)
    selected_train_type = TrainType.Selfsupervised if is_selfsupervised else TrainType.Incremental
    adapter_key = str(selected_train_type.value)

    try:
        adapter_meta = ADAPTERS[task_type][adapter_key]
    except KeyError as err:
        # Same exception type as a raw dict lookup, but the message names the combination.
        raise KeyError(
            f"No dataset adapter is registered for task_type={task_type} with train_type={adapter_key}"
        ) from err

    # Import lazily so that only the adapter module needed for this task is loaded.
    module = importlib.import_module("otx.core.data.adapter." + adapter_meta["module_name"])
    adapter_class = getattr(module, adapter_meta["class"])
    return adapter_class(
        task_type=task_type,
        train_data_roots=train_data_roots,
        train_ann_files=train_ann_files,
        val_data_roots=val_data_roots,
        val_ann_files=val_ann_files,
        test_data_roots=test_data_roots,
        test_ann_files=test_ann_files,
        unlabeled_data_roots=unlabeled_data_roots,
        unlabeled_file_list=unlabeled_file_list,
        **kwargs,
    )