Source code for otx.algorithms.segmentation.task

"""Task of OTX Segmentation."""

# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import io
import os
from abc import ABC, abstractmethod
from typing import List, Optional

import numpy as np
import torch
from mmcv.utils import ConfigDict

from otx.algorithms.common.configs.configuration_enums import InputSizePreset
from otx.algorithms.common.configs.training_base import TrainType
from otx.algorithms.common.tasks.base_task import TRAIN_TYPE_DIR_PATH, OTXTask
from otx.algorithms.common.utils.callback import (
    InferenceProgressCallback,
    TrainingProgressCallback,
)
from otx.algorithms.common.utils.ir import embed_ir_model_data
from otx.algorithms.common.utils.utils import embed_onnx_model_data
from otx.algorithms.segmentation.configs.base import SegmentationConfig
from otx.algorithms.segmentation.utils import get_activation_map
from otx.algorithms.segmentation.utils.metadata import get_seg_model_api_configuration
from otx.api.configuration import cfg_helper
from otx.api.configuration.helper.utils import ids_to_strings
from otx.api.entities.datasets import DatasetEntity
from otx.api.entities.explain_parameters import ExplainParameters
from otx.api.entities.inference_parameters import InferenceParameters
from otx.api.entities.inference_parameters import (
    default_progress_callback as default_infer_progress_callback,
)
from otx.api.entities.metrics import (
    CurveMetric,
    InfoMetric,
    LineChartInfo,
    MetricsGroup,
    Performance,
    ScoreMetric,
    VisualizationInfo,
    VisualizationType,
)
from otx.api.entities.model import (
    ModelEntity,
    ModelPrecision,
)
from otx.api.entities.result_media import ResultMediaEntity
from otx.api.entities.resultset import ResultSetEntity
from otx.api.entities.task_environment import TaskEnvironment
from otx.api.entities.tensor import TensorEntity
from otx.api.entities.train_parameters import TrainParameters, default_progress_callback
from otx.api.serialization.label_mapper import label_schema_to_bytes
from otx.api.usecases.evaluation.metrics_helper import MetricsHelper
from otx.api.usecases.tasks.interfaces.export_interface import ExportType
from otx.api.utils.segmentation_utils import (
    create_annotation_from_segmentation_map,
    create_hard_prediction_from_soft_prediction,
)
from otx.cli.utils.multi_gpu import is_multigpu_child_process
from otx.core.data.caching.mem_cache_handler import MemCacheHandlerSingleton
from otx.utils.logger import get_logger

logger = get_logger()
RECIPE_TRAIN_TYPE = {
    TrainType.Semisupervised: "semisl.py",
    TrainType.Incremental: "incremental.py",
    TrainType.Selfsupervised: "selfsl.py",
}


class OTXSegmentationTask(OTXTask, ABC):
    """Task class for OTX segmentation."""

    # pylint: disable=too-many-instance-attributes, too-many-locals
    def __init__(self, task_environment: TaskEnvironment, output_path: Optional[str] = None):
        super().__init__(task_environment, output_path)
        self._task_config = SegmentationConfig
        self._hyperparams: ConfigDict = task_environment.get_hyper_parameters(self._task_config)
        self._model_name = task_environment.model_template.name
        self._train_type = self._hyperparams.algo_backend.train_type
        self.metric = "mDice"
        self._label_dictionary = dict(enumerate(self._labels, 1))  # It should have same order as model class order

        self._model_dir = os.path.join(
            os.path.abspath(os.path.dirname(self._task_environment.model_template.model_template_path)),
            TRAIN_TYPE_DIR_PATH[self._train_type.name],
        )
        if (
            self._train_type in RECIPE_TRAIN_TYPE
            and self._train_type == TrainType.Incremental
            and self._hyperparams.learning_parameters.enable_supcon
            and not self._model_dir.endswith("supcon")
        ):
            self._model_dir = os.path.join(self._model_dir, "supcon")

        if task_environment.model is not None:
            self._load_model()

        self.data_pipeline_path = os.path.join(self._model_dir, "data_pipeline.py")

        if hasattr(self._hyperparams.learning_parameters, "input_size"):
            input_size_cfg = InputSizePreset(self._hyperparams.learning_parameters.input_size.value)
        else:
            input_size_cfg = InputSizePreset.DEFAULT
        self._input_size = input_size_cfg.tuple
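
    # Construction sketch (hypothetical; this class is abstract, so a concrete
    # subclass such as the MMSeg-based train task is assumed, and the template,
    # hyper-parameters, and label schema are assumed to be prepared elsewhere):
    #
    #     env = TaskEnvironment(
    #         model_template=model_template,
    #         model=None,
    #         hyper_parameters=hyper_parameters,
    #         label_schema=label_schema,
    #     )
    #     task = SegmentationTrainTask(env, output_path="./outputs")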

    def infer(
        self,
        dataset: DatasetEntity,
        inference_parameters: Optional[InferenceParameters] = None,
    ) -> DatasetEntity:
        """Main infer function."""
        logger.info("infer()")

        if inference_parameters is not None:
            update_progress_callback = inference_parameters.update_progress  # type: ignore
            dump_soft_prediction = not inference_parameters.is_evaluation
            process_soft_prediction = inference_parameters.process_saliency_maps
        else:
            update_progress_callback = default_infer_progress_callback
            dump_soft_prediction = True
            process_soft_prediction = False

        self._time_monitor = InferenceProgressCallback(len(dataset), update_progress_callback)

        predictions = self._infer_model(dataset, InferenceParameters(is_evaluation=True))
        prediction_results = zip(predictions["eval_predictions"], predictions["feature_vectors"])
        self._add_predictions_to_dataset(prediction_results, dataset, dump_soft_prediction, process_soft_prediction)
        logger.info("Inference completed")
        return dataset
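
    # Usage sketch (hypothetical; `task` and `dataset` are assumed prepared
    # elsewhere). Note that is_evaluation=True suppresses soft-prediction dumps:
    #
    #     predicted_dataset = task.infer(
    #         dataset.with_empty_annotations(),
    #         InferenceParameters(is_evaluation=False),
    #     )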

    def train(
        self,
        dataset: DatasetEntity,
        output_model: ModelEntity,
        train_parameters: Optional[TrainParameters] = None,
        seed: Optional[int] = None,
        deterministic: bool = False,
    ):
        """Train function for OTX segmentation task.

        Actual training is delegated to the _train_model function.
        """
        logger.info("train()")

        # Check for a stop signal before training starts. If should_stop is true,
        # training was cancelled and no new model should be returned.
        if self._should_stop:  # type: ignore
            logger.info("Training cancelled.")
            self._should_stop = False
            self._is_training = False
            return

        self.seed = seed
        self.deterministic = deterministic

        # Set OTX LoggerHook & Time Monitor
        if train_parameters:
            update_progress_callback = train_parameters.update_progress
        else:
            update_progress_callback = default_progress_callback
        self._time_monitor = TrainingProgressCallback(update_progress_callback)

        results = self._train_model(dataset)
        MemCacheHandlerSingleton.delete()

        # Check for a stop signal again after training has stopped. If should_stop is true,
        # training was cancelled and no new model should be returned.
        if self._should_stop:
            logger.info("Training cancelled.")
            self._should_stop = False
            self._is_training = False
            return

        # get output model
        model_ckpt = results.get("final_ckpt")
        if model_ckpt is None:
            logger.error("cannot find final checkpoint from the results.")
            # output_model.model_status = ModelStatus.FAILED
            return
        # update checkpoint to the newly trained model
        self._model_ckpt = model_ckpt

        # get prediction on validation set
        self._is_training = False

        # Get training metrics group from learning curves
        training_metrics, best_score = self._generate_training_metrics(self._learning_curves)
        performance = Performance(
            score=ScoreMetric(value=best_score, name=self.metric),
            dashboard_metrics=training_metrics,
        )
        logger.info(f"Final model performance: {str(performance)}")

        # save resulting model
        self.save_model(output_model)
        output_model.performance = performance
        self._is_training = False
        logger.info("train done.")
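
    # Usage sketch (hypothetical; `env`, `task`, and `dataset` are assumed
    # prepared elsewhere). `train` fills `output_model` in place with the best
    # checkpoint and its mDice-based performance:
    #
    #     output_model = ModelEntity(dataset, env.get_model_configuration())
    #     task.train(dataset, output_model, seed=5, deterministic=False)
    #     print(output_model.performance)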

    def export(
        self,
        export_type: ExportType,
        output_model: ModelEntity,
        precision: ModelPrecision = ModelPrecision.FP32,
        dump_features: bool = True,
    ):
        """Export function of OTX Task."""
        logger.info("Exporting the model")

        self._update_model_export_metadata(output_model, export_type, precision, dump_features)
        results = self._export_model(precision, export_type, dump_features)
        outputs = results.get("outputs")
        logger.debug(f"results of run_task = {outputs}")
        if outputs is None:
            raise RuntimeError(results.get("msg"))

        ir_extra_data = get_seg_model_api_configuration(self._task_environment.label_schema, self._hyperparams)

        if export_type == ExportType.ONNX:
            ir_extra_data[("model_info", "mean_values")] = results.get("inference_parameters").get("mean_values")
            ir_extra_data[("model_info", "scale_values")] = results.get("inference_parameters").get("scale_values")

            onnx_file = outputs.get("onnx")
            embed_onnx_model_data(onnx_file, ir_extra_data)
            with open(onnx_file, "rb") as f:
                output_model.set_data("model.onnx", f.read())
        else:
            bin_file = outputs.get("bin")
            xml_file = outputs.get("xml")

            embed_ir_model_data(xml_file, ir_extra_data)

            with open(bin_file, "rb") as f:
                output_model.set_data("openvino.bin", f.read())
            with open(xml_file, "rb") as f:
                output_model.set_data("openvino.xml", f.read())

        output_model.set_data("label_schema.json", label_schema_to_bytes(self._task_environment.label_schema))
        logger.info("Exporting completed")
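
    # Usage sketch (hypothetical): OpenVINO export stores "openvino.xml"/"openvino.bin"
    # on the output model, while ONNX export stores "model.onnx":
    #
    #     exported_model = ModelEntity(dataset, env.get_model_configuration())
    #     task.export(ExportType.OPENVINO, exported_model, precision=ModelPrecision.FP16)
    #     with open("model.xml", "wb") as f:
    #         f.write(exported_model.get_data("openvino.xml"))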

    def explain(
        self,
        dataset: DatasetEntity,
        explain_parameters: Optional[ExplainParameters] = None,
    ) -> DatasetEntity:
        """Main explain function of OTX Task."""
        raise NotImplementedError

    def evaluate(
        self,
        output_resultset: ResultSetEntity,
        evaluation_metric: Optional[str] = None,
    ):
        """Evaluate function of OTX Segmentation Task."""
        logger.info("called evaluate()")
        if evaluation_metric is not None:
            logger.warning(
                f"Requested to use {evaluation_metric} metric, but the parameter is ignored. Using mDice instead."
            )
        metric = MetricsHelper.compute_dice_averaged_over_pixels(output_resultset)
        logger.info(f"mDice after evaluation: {metric.overall_dice.value}")
        output_resultset.performance = metric.get_performance()
        logger.info("Evaluation completed")
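
    # Usage sketch (hypothetical): evaluation compares ground truth against the
    # dataset returned by `infer` and attaches an mDice performance object:
    #
    #     resultset = ResultSetEntity(
    #         model=output_model,
    #         ground_truth_dataset=dataset,
    #         prediction_dataset=predicted_dataset,
    #     )
    #     task.evaluate(resultset)
    #     print(resultset.performance.score.value)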

    def _add_predictions_to_dataset(self, prediction_results, dataset, dump_soft_prediction, process_soft_prediction):
        """Loop over dataset again to assign predictions. Convert from MMSegmentation format to OTX format."""
        for dataset_item, (prediction, feature_vector) in zip(dataset, prediction_results):
            soft_prediction = np.transpose(prediction[0], axes=(1, 2, 0))
            hard_prediction = create_hard_prediction_from_soft_prediction(
                soft_prediction=soft_prediction,
                soft_threshold=self._hyperparams.postprocessing.soft_threshold,
                blur_strength=self._hyperparams.postprocessing.blur_strength,
            )
            annotations = create_annotation_from_segmentation_map(
                hard_prediction=hard_prediction,
                soft_prediction=soft_prediction,
                label_map=self._label_dictionary,
            )
            dataset_item.append_annotations(annotations=annotations)

            if feature_vector is not None:
                active_score = TensorEntity(name="representation_vector", numpy=feature_vector.reshape(-1))
                dataset_item.append_metadata_item(active_score, model=self._task_environment.model)

            if dump_soft_prediction:
                for label_index, label in self._label_dictionary.items():
                    current_label_soft_prediction = soft_prediction[:, :, label_index]
                    if process_soft_prediction:
                        current_label_soft_prediction = get_activation_map(current_label_soft_prediction)
                    else:
                        current_label_soft_prediction = (current_label_soft_prediction * 255).astype(np.uint8)
                    result_media = ResultMediaEntity(
                        name=label.name,
                        type="soft_prediction",
                        label=label,
                        annotation_scene=dataset_item.annotation_scene,
                        roi=dataset_item.roi,
                        numpy=current_label_soft_prediction,
                    )
                    dataset_item.append_metadata_item(result_media, model=self._task_environment.model)
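
    # Shape sketch (illustrative, made-up sizes): for C label channels and an HxW
    # image, `prediction[0]` is (C, H, W); the transpose above yields (H, W, C),
    # from which channel `label_index` (1-based, channel 0 assumed to be
    # background, matching the 1-based label dictionary) is sliced per label:
    #
    #     soft = np.transpose(np.zeros((4, 32, 32), dtype=np.float32), axes=(1, 2, 0))
    #     assert soft.shape == (32, 32, 4)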

    def save_model(self, output_model: ModelEntity):
        """Save best model weights in SegmentationTrainTask."""
        if is_multigpu_child_process():
            return

        logger.info("called save_model")
        buffer = io.BytesIO()
        hyperparams_str = ids_to_strings(cfg_helper.convert(self._hyperparams, dict, enum_to_str=True))
        labels = {label.name: label.color.rgb_tuple for label in self._labels}
        model_ckpt = torch.load(self._model_ckpt)
        modelinfo = {
            "model": model_ckpt,
            "config": hyperparams_str,
            "labels": labels,
            "input_size": self._input_size,
            "VERSION": 1,
        }

        torch.save(modelinfo, buffer)
        output_model.set_data("weights.pth", buffer.getvalue())
        output_model.set_data(
            "label_schema.json",
            label_schema_to_bytes(self._task_environment.label_schema),
        )
        output_model.precision = self._precision
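
    # Round-trip sketch (hypothetical): the "weights.pth" blob written above can
    # be read back with torch to inspect the stored checkpoint and metadata:
    #
    #     modelinfo = torch.load(io.BytesIO(output_model.get_data("weights.pth")))
    #     print(modelinfo["VERSION"], modelinfo["labels"], modelinfo["input_size"])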

    def _generate_training_metrics(self, learning_curves):
        """Get Training metrics (epochs & scores).

        Parses the mmsegmentation logs to get metrics from the latest training run.

        :return: output List[MetricsGroup]
        """
        output: List[MetricsGroup] = []

        # Model architecture
        architecture = InfoMetric(name="Model architecture", value=self._model_name)
        visualization_info_architecture = VisualizationInfo(
            name="Model architecture", visualisation_type=VisualizationType.TEXT
        )
        output.append(
            MetricsGroup(
                metrics=[architecture],
                visualization_info=visualization_info_architecture,
            )
        )

        # Learning curves
        best_score = -1
        for key, curve in learning_curves.items():
            metric_curve = CurveMetric(xs=curve.x, ys=curve.y, name=key)
            if key == f"val/{self.metric}":
                best_score = max(curve.y)
            visualization_info = LineChartInfo(name=key, x_axis_label="Epoch", y_axis_label=key)
            output.append(MetricsGroup(metrics=[metric_curve], visualization_info=visualization_info))

        return output, best_score

    @abstractmethod
    def _train_model(self, dataset: DatasetEntity):
        """Train model and return the results."""
        raise NotImplementedError

    @abstractmethod
    def _infer_model(
        self,
        dataset: DatasetEntity,
        inference_parameters: Optional[InferenceParameters] = None,
    ):
        """Get inference results from dataset."""
        raise NotImplementedError

    @abstractmethod
    def _export_model(self, precision: ModelPrecision, export_format: ExportType, dump_features: bool):
        """Export model and return the results."""
        raise NotImplementedError

    @abstractmethod
    def _explain_model(self, dataset: DatasetEntity, explain_parameters: Optional[ExplainParameters]):
        """Explain model and return the results."""
        raise NotImplementedError
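
# Subclassing sketch (hypothetical; names and return values are illustrative):
# a concrete backend task only needs to implement the four abstract hooks, and the
# public train/infer/export/evaluate flow above is inherited unchanged:
#
#     class MySegmentationTask(OTXSegmentationTask):
#         def _train_model(self, dataset):
#             return {"final_ckpt": "/path/to/best.pth"}
#
#         def _infer_model(self, dataset, inference_parameters=None):
#             return {"eval_predictions": [], "feature_vectors": []}
#
#         def _export_model(self, precision, export_format, dump_features):
#             return {"outputs": {"bin": "model.bin", "xml": "model.xml"}}
#
#         def _explain_model(self, dataset, explain_parameters):
#             raise NotImplementedError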