Source code for otx.algorithms.common.tasks.base_task

"""Base task of OTX."""

# Copyright (C) 2023 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.

import io
import logging
import os
import shutil
import tempfile
from abc import ABC, abstractmethod
from datetime import timedelta
from typing import Any, Dict, Iterable, List, Optional

import torch
from torch import distributed as dist

from otx.algorithms.common.adapters.mmcv.hooks import OTXLoggerHook
from otx.algorithms.common.adapters.mmcv.hooks.cancel_hook import CancelInterfaceHook
from otx.algorithms.common.configs.training_base import TrainType
from otx.algorithms.common.utils import UncopiableDefaultDict, append_dist_rank_suffix, set_random_seed
from otx.api.entities.datasets import DatasetEntity
from otx.api.entities.explain_parameters import ExplainParameters
from otx.api.entities.inference_parameters import InferenceParameters
from otx.api.entities.label import LabelEntity
from otx.api.entities.metrics import MetricsGroup
from otx.api.entities.model import ModelEntity, ModelFormat, ModelOptimizationType, ModelPrecision, OptimizationMethod
from otx.api.entities.resultset import ResultSetEntity
from otx.api.entities.task_environment import TaskEnvironment
from otx.api.entities.train_parameters import TrainParameters
from otx.api.serialization.label_mapper import LabelSchemaMapper
from otx.api.usecases.reporting.time_monitor_callback import TimeMonitorCallback
from otx.api.usecases.tasks.interfaces.evaluate_interface import IEvaluationTask
from otx.api.usecases.tasks.interfaces.export_interface import ExportType, IExportTask
from otx.api.usecases.tasks.interfaces.inference_interface import IInferenceTask
from otx.api.usecases.tasks.interfaces.unload_interface import IUnload
from otx.utils.logger import get_logger

TRAIN_TYPE_DIR_PATH = {
    TrainType.Incremental.name: ".",
    TrainType.Selfsupervised.name: "selfsl",
    TrainType.Semisupervised.name: "semisl",
}

logger = get_logger()


class OnHookInitialized:
    """OnHookInitialized class."""

    def __init__(self, task_instance):
        self.task_instance = task_instance
        self.__findable = False  # a barrier to block segmentation fault

    def __call__(self, cancel_interface):
        """Function call in OnHookInitialized."""
        if isinstance(self.task_instance, int) and self.__findable:
            import ctypes  # NOTE: BE AWARE OF SEGMENTATION FAULT

            self.task_instance = ctypes.cast(self.task_instance, ctypes.py_object).value
        self.task_instance.cancel_hook_initialized(cancel_interface)

    def __repr__(self):
        """Function repr in OnHookInitialized."""
        return f"'{__name__}.OnHookInitialized'"

    def __deepcopy__(self, memo):
        """Function deepcopy in OnHookInitialized."""
        cls = self.__class__
        result = cls.__new__(cls)
        memo[id(self)] = result
        result.task_instance = self.task_instance
        result.__findable = True  # pylint: disable=unused-private-member, protected-access
        return result

    def __reduce__(self):
        """Function reduce in OnHookInitialized."""
        return (self.__class__, (id(self.task_instance),))
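

# Illustration only (`_DummyTask` below is an assumption, not part of this
# module): why `OnHookInitialized` serializes a bare `id()` and only casts it
# back after a deepcopy. After `pickle.loads`, `task_instance` is a plain int
# and the `__findable` barrier is down; a later `copy.deepcopy` (which mmcv
# performs when instantiating hooks from a config) raises the barrier, so the
# unsafe ctypes cast can only happen inside the trusted training pipeline and
# only within the original process, where the id is still a live address.
#
#   import copy, pickle
#
#   class _DummyTask:
#       def cancel_hook_initialized(self, cancel_interface):
#           print("got", cancel_interface)
#
#   task = _DummyTask()
#   restored = pickle.loads(pickle.dumps(OnHookInitialized(task)))
#   assert isinstance(restored.task_instance, int)  # raw id(), cast is blocked
#   copy.deepcopy(restored)("cancel-interface")     # barrier up -> cast -> call

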
# pylint: disable=too-many-instance-attributes
class OTXTask(IInferenceTask, IExportTask, IEvaluationTask, IUnload, ABC):
    """Base task of OTX."""

    def __init__(self, task_environment: TaskEnvironment, output_path: Optional[str] = None):
        self._config: Dict[Any, Any] = {}
        self._task_environment = task_environment
        self._task_type = task_environment.model_template.task_type
        self._labels = task_environment.get_labels(include_empty=False)
        self._work_dir_is_temp = False
        self._output_path = output_path if output_path is not None else self._get_tmp_dir()
        self._time_monitor: Optional[TimeMonitorCallback] = None
        self.on_hook_initialized = OnHookInitialized(self)
        self._learning_curves = UncopiableDefaultDict(OTXLoggerHook.Curve)
        self._model_label_schema: List[LabelEntity] = []
        self._resume = False
        self._should_stop: bool = False
        self.cancel_interface: Optional[CancelInterfaceHook] = None
        self.reserved_cancel = False
        self._model_ckpt = None
        self._precision = [ModelPrecision.FP32]
        self._optimization_methods: List[OptimizationMethod] = []
        self._is_training = False
        self.seed: Optional[int] = None
        self.deterministic: bool = False
        self.override_configs: Dict[str, str] = {}  # This is for HPO and should be removed
        self.project_path = self._output_path

        if self._is_distributed_training():
            self._setup_distributed_training()

    @staticmethod
    def _is_distributed_training():
        multi_gpu_env = ["MASTER_ADDR", "MASTER_PORT", "LOCAL_WORLD_SIZE", "WORLD_SIZE", "LOCAL_RANK", "RANK"]
        for env in multi_gpu_env:
            if env not in os.environ:
                return False

        return torch.cuda.is_available()

    @staticmethod
    def _setup_distributed_training():
        if not dist.is_initialized():
            torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
            dist.init_process_group(
                backend="nccl",
                init_method="env://",
                timeout=timedelta(seconds=int(os.environ.get("TORCH_DIST_TIMEOUT", 60))),
            )
            rank = dist.get_rank()
            logger.info(f"Dist info: rank {rank} / {dist.get_world_size()} world_size")
            if rank != 0:
                logging.disable(logging.WARNING)

    def _get_tmp_dir(self):
        self._work_dir_is_temp = True
        # If training is executed with torchrun, give all trainings the same output directory
        if "TORCHELASTIC_RUN_ID" in os.environ:
            return os.path.join(tempfile.gettempdir(), f"OTX-task-torchelastic-{os.environ['TORCHELASTIC_RUN_ID']}")
        return tempfile.mkdtemp(prefix="OTX-task-")

    def _load_model(self):
        """Load the model from a checkpoint."""

        def _load_model_label_schema(model: Optional[ModelEntity]):
            # If a model has been trained and saved for the task already, create an empty model and load weights here
            if model and "label_schema.json" in model.model_adapters:
                import json

                buffer = json.loads(model.get_data("label_schema.json").decode("utf-8"))
                model_label_schema = LabelSchemaMapper().backward(buffer)
                return model_label_schema.get_labels(include_empty=False)
            return self._labels

        logger.info("loading the model from the task env.")
        model = self._task_environment.model
        state_dict = self._load_model_ckpt(model)
        if state_dict:
            self._model_ckpt = append_dist_rank_suffix(os.path.join(self._output_path, "env_model_ckpt.pth"))
            if os.path.exists(self._model_ckpt):
                os.remove(self._model_ckpt)
            torch.save(state_dict, self._model_ckpt)
        self._model_label_schema = _load_model_label_schema(model)
        if model is not None:
            self._resume = model.model_adapters.get("resume", False)

    def _load_model_ckpt(self, model: Optional[ModelEntity]):
        if model and "weights.pth" in model.model_adapters:
            # If a model has been trained and saved for the task already, create an empty model and load weights here
            buffer = io.BytesIO(model.get_data("weights.pth"))
            model_data = torch.load(buffer, map_location=torch.device("cpu"))
            return model_data
        return None
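
    # Illustration only (not part of the original module): the environment that
    # `_is_distributed_training` checks is what `torchrun` exports to each
    # worker process. For example, launching two workers on one node,
    #
    #   torchrun --nproc_per_node=2 train.py
    #
    # sets MASTER_ADDR, MASTER_PORT, WORLD_SIZE, LOCAL_WORLD_SIZE, and a
    # per-worker RANK/LOCAL_RANK, plus TORCHELASTIC_RUN_ID, which
    # `_get_tmp_dir` above uses so that all workers share one scratch directory.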

    @abstractmethod
    def train(
        self,
        dataset: DatasetEntity,
        output_model: ModelEntity,
        train_parameters: Optional[TrainParameters] = None,
        seed: Optional[int] = None,
        deterministic: bool = False,
    ):
        """Train function for OTX task."""
        raise NotImplementedError

    @abstractmethod
    def infer(
        self,
        dataset: DatasetEntity,
        inference_parameters: Optional[InferenceParameters] = None,
    ) -> DatasetEntity:
        """Main infer function."""
        raise NotImplementedError

    @abstractmethod
    def export(
        self,
        export_type: ExportType,
        output_model: ModelEntity,
        precision: ModelPrecision = ModelPrecision.FP32,
        dump_features: bool = True,
    ):
        """Export function of OTX Task."""
        raise NotImplementedError

    @abstractmethod
    def explain(
        self,
        dataset: DatasetEntity,
        explain_parameters: Optional[ExplainParameters] = None,
    ) -> DatasetEntity:
        """Main explain function of OTX Task."""
        raise NotImplementedError

    @abstractmethod
    def evaluate(
        self,
        output_resultset: ResultSetEntity,
        evaluation_metric: Optional[str] = None,
    ):
        """Evaluate function of OTX Task."""
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def _generate_training_metrics(learning_curves, scores) -> Iterable[MetricsGroup[Any, Any]]:
        """Get Training metrics (epochs & scores).

        Parses the training logs to get metrics from the latest training run.

        :return: output List[MetricsGroup]
        """
        raise NotImplementedError

    @abstractmethod
    def save_model(self, output_model: ModelEntity):
        """Save best model weights in training task."""
        raise NotImplementedError

    def cancel_training(self):
        """Cancel training function in training task.

        Sends a cancel training signal to gracefully stop the optimizer.
        The signal consists of creating a '.stop_training' file in the current work_dir.
        The runner checks for this file periodically.
        The stopping mechanism allows stopping after each iteration, but validation will still be carried out.
        Stopping will therefore take some time.
        """
        logger.info("Cancel training requested.")
        self._should_stop = True
        if self.cancel_interface is not None:
            self.cancel_interface.cancel()
        else:
            logger.info("but training was not started yet. reserved it to cancel")
            self.reserved_cancel = True
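
    # Hedged sketch (assumed helper, not defined in this module): the kind of
    # per-iteration check a runner-side hook could use to honour the
    # '.stop_training' file described in the docstring above.
    #
    #   def _should_stop(work_dir: str) -> bool:
    #       # Cheap stat() call, safe to run once per training iteration.
    #       return os.path.exists(os.path.join(work_dir, ".stop_training"))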

    def cancel_hook_initialized(self, cancel_interface: CancelInterfaceHook):
        """Initialization of cancel_interface hook."""
        logger.info("cancel hook is initialized")
        self.cancel_interface = cancel_interface
        if self.reserved_cancel and self.cancel_interface:
            self.cancel_interface.cancel()

    def cleanup(self):
        """Clean up the work directory if it was created as a temporary one."""
        if self._work_dir_is_temp:
            self._delete_scratch_space()

    def _delete_scratch_space(self):
        """Remove model checkpoints and otx logs."""
        if os.path.exists(self._output_path):
            shutil.rmtree(self._output_path, ignore_errors=False)

    def unload(self):
        """Unload the task."""
        self.cleanup()
        if self._is_docker():
            logger.warning("Got unload request. Unloading models. Throwing Segmentation Fault on purpose")
            import ctypes

            ctypes.string_at(0)
        else:
            logger.warning("Got unload request, but not on Docker. Only clearing CUDA cache")
            torch.cuda.empty_cache()
            logger.warning(
                f"Done unloading. Torch is still occupying {torch.cuda.memory_allocated()} bytes of GPU memory"
            )

    @staticmethod
    def _is_docker():
        """Check whether the task runs in a docker container.

        :return bool: True if task runs in docker
        """
        path = "/proc/self/cgroup"
        is_in_docker = False
        if os.path.isfile(path):
            with open(path, encoding="UTF-8") as f:
                is_in_docker = is_in_docker or any("docker" in line for line in f)
        is_in_docker = is_in_docker or os.path.exists("/.dockerenv")
        return is_in_docker
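
    # Illustration only (not part of the original module): inside a Docker
    # container, /proc/self/cgroup (cgroup v1) typically contains lines such as
    #
    #   12:cpuset:/docker/4f7d8a...
    #
    # which is why a substring match on "docker" suffices; the /.dockerenv
    # marker file covers setups where the cgroup path does not mention docker.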

    def set_seed(self):
        """Set seed and deterministic."""
        if self.seed is None:
            # If the seed was not passed via task.train, take it from the recipe.
            self.seed = self.config.get("seed", 5)
        if not self.deterministic:
            # Likewise for the deterministic flag.
            self.deterministic = self.config.get("deterministic", False)
        self.config["seed"] = self.seed
        self.config["deterministic"] = self.deterministic
        set_random_seed(self.seed, logger, self.deterministic)

    @property
    def config(self):
        """Config of OTX task."""
        return self._config

    @config.setter
    def config(self, config: Dict[Any, Any]):
        self._config = config

    def _update_model_export_metadata(
        self, output_model: ModelEntity, export_type: ExportType, precision: ModelPrecision, dump_features: bool
    ) -> None:
        """Update a model entity with format- and optimization-related attributes."""
        if export_type == ExportType.ONNX:
            output_model.model_format = ModelFormat.ONNX
            output_model.optimization_type = ModelOptimizationType.ONNX
            if precision == ModelPrecision.FP16:
                raise RuntimeError("Export to FP16 ONNX is not supported")
        elif export_type == ExportType.OPENVINO:
            output_model.model_format = ModelFormat.OPENVINO
            output_model.optimization_type = ModelOptimizationType.MO
        else:
            raise RuntimeError(f"Unsupported export type: {export_type}")

        output_model.has_xai = dump_features
        output_model.optimization_methods = self._optimization_methods
        output_model.precision = [precision]
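

# Hedged usage sketch (illustration only; `MyTask`, `environment`, and
# `dataset` are assumptions, not defined in this module). A concrete subclass
# implements the abstract methods, after which the export flow above applies:
#
#   task = MyTask(task_environment=environment)
#   exported_model = ModelEntity(
#       train_dataset=dataset,
#       configuration=environment.get_model_configuration(),
#   )
#   task.export(ExportType.OPENVINO, exported_model, ModelPrecision.FP16, dump_features=False)
#   # ExportType.ONNX combined with ModelPrecision.FP16 would raise RuntimeError
#   # in _update_model_export_metadata.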