Source code for otx.algorithms.classification.task

"""Task of OTX Classification."""

# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import io
import json
import os
from abc import ABC, abstractmethod
from typing import List, Optional

import numpy as np
import torch

from otx.algorithms.classification.configs.base import ClassificationConfig
from otx.algorithms.classification.utils import (
    get_cls_deploy_config,
    get_cls_inferencer_configuration,
    get_cls_model_api_configuration,
    get_hierarchical_label_list,
)
from otx.algorithms.classification.utils import (
    get_multihead_class_info as get_hierarchical_info,
)
from otx.algorithms.common.configs import TrainType
from otx.algorithms.common.configs.configuration_enums import InputSizePreset
from otx.algorithms.common.tasks.base_task import TRAIN_TYPE_DIR_PATH, OTXTask
from otx.algorithms.common.utils import embed_ir_model_data
from otx.algorithms.common.utils.callback import TrainingProgressCallback
from otx.algorithms.common.utils.utils import embed_onnx_model_data
from otx.api.configuration import cfg_helper
from otx.api.configuration.helper.utils import ids_to_strings
from otx.api.entities.datasets import DatasetEntity
from otx.api.entities.explain_parameters import ExplainParameters
from otx.api.entities.inference_parameters import (
    InferenceParameters,
)
from otx.api.entities.inference_parameters import (
    default_progress_callback as default_infer_progress_callback,
)
from otx.api.entities.label import LabelEntity
from otx.api.entities.label_schema import LabelGroup
from otx.api.entities.metadata import FloatMetadata, FloatType
from otx.api.entities.metrics import (
    CurveMetric,
    LineChartInfo,
    LineMetricsGroup,
    MetricsGroup,
    Performance,
    ScoreMetric,
)
from otx.api.entities.model import (
    ModelEntity,
    ModelPrecision,
)
from otx.api.entities.resultset import ResultSetEntity
from otx.api.entities.scored_label import ScoredLabel
from otx.api.entities.task_environment import TaskEnvironment
from otx.api.entities.tensor import TensorEntity
from otx.api.entities.train_parameters import (
    TrainParameters,
)
from otx.api.entities.train_parameters import (
    default_progress_callback as default_train_progress_callback,
)
from otx.api.serialization.label_mapper import label_schema_to_bytes
from otx.api.usecases.evaluation.metrics_helper import MetricsHelper
from otx.api.usecases.tasks.interfaces.export_interface import ExportType
from otx.api.utils.dataset_utils import add_saliency_maps_to_dataset_item
from otx.api.utils.labels_utils import get_empty_label
from otx.cli.utils.multi_gpu import is_multigpu_child_process
from otx.core.data.caching.mem_cache_handler import MemCacheHandlerSingleton
from otx.utils.logger import get_logger

logger = get_logger()
RECIPE_TRAIN_TYPE = {
    TrainType.Semisupervised: "semisl.yaml",
    TrainType.Incremental: "incremental.yaml",
    TrainType.Selfsupervised: "selfsl.yaml",
}
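
# Illustrative note (not part of the original module): RECIPE_TRAIN_TYPE names the recipe file a
# concrete backend picks per train type, while __init__ below resolves the config directory from the
# model template location plus TRAIN_TYPE_DIR_PATH. Assuming a hypothetical template at
# "<templates>/EfficientNet-B0/template.yaml" and TRAIN_TYPE_DIR_PATH resolving the Incremental
# train type to an "incremental" sub-directory, __init__ would set roughly:
#   self._model_dir         -> "<templates>/EfficientNet-B0/incremental"  (plus "/supcon" when SupCon is enabled)
#   self.data_pipeline_path -> "<templates>/EfficientNet-B0/incremental/data_pipeline.py"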


class OTXClassificationTask(OTXTask, ABC):
    """Task class for OTX classification."""

    # pylint: disable=too-many-instance-attributes, too-many-locals, too-many-boolean-expressions
    def __init__(self, task_environment: TaskEnvironment, output_path: Optional[str] = None):
        super().__init__(task_environment, output_path)
        self._task_config = ClassificationConfig
        self._hyperparams = self._task_environment.get_hyper_parameters(self._task_config)
        if len(self._task_environment.get_labels(False)) == 1:
            self._labels = self._task_environment.get_labels(include_empty=True)
        else:
            self._labels = self._task_environment.get_labels(include_empty=False)
        self._empty_label = get_empty_label(self._task_environment.label_schema)

        self._multilabel = False
        self._hierarchical = False
        self._hierarchical_info = None
        self._selfsl = False
        self._set_train_mode()

        self._train_type = self._hyperparams.algo_backend.train_type
        self._model_dir = os.path.join(
            os.path.abspath(os.path.dirname(self._task_environment.model_template.model_template_path)),
            TRAIN_TYPE_DIR_PATH[self._train_type.name],
        )
        if (
            self._train_type in RECIPE_TRAIN_TYPE
            and self._train_type == TrainType.Incremental
            and not self._multilabel
            and not self._hierarchical
            and self._hyperparams.learning_parameters.enable_supcon
            and not self._model_dir.endswith("supcon")
        ):
            self._model_dir = os.path.join(self._model_dir, "supcon")

        self.data_pipeline_path = os.path.join(self._model_dir, "data_pipeline.py")

        if self._task_environment.model is not None:
            self._load_model()

        if hasattr(self._hyperparams.learning_parameters, "input_size"):
            input_size_cfg = InputSizePreset(self._hyperparams.learning_parameters.input_size.value)
        else:
            input_size_cfg = InputSizePreset.DEFAULT
        self._input_size = input_size_cfg.tuple

    def _is_multi_label(self, label_groups: List[LabelGroup], all_labels: List[LabelEntity]):
        """Check whether the current training mode is multi-label or not."""
        # NOTE: In the current Geti, multi-label should have `___` symbol for all group names.
        find_multilabel_symbol = ["___" in getattr(i, "name", "") for i in label_groups]
        return (
            (len(label_groups) > 1)
            and (len(label_groups) == len(all_labels))
            and (False not in find_multilabel_symbol)
        )

    def _set_train_mode(self):
        label_groups = self._task_environment.label_schema.get_groups(include_empty=False)
        all_labels = self._task_environment.label_schema.get_labels(include_empty=False)

        self._multilabel = self._is_multi_label(label_groups, all_labels)
        if self._multilabel:
            logger.info("Classification mode: multilabel")
        elif len(label_groups) > 1:
            logger.info("Classification mode: hierarchical")
            self._hierarchical = True
            self._hierarchical_info = get_hierarchical_info(self._task_environment.label_schema)
        if not self._multilabel and not self._hierarchical:
            logger.info("Classification mode: multiclass")

        if self._hyperparams.algo_backend.train_type == TrainType.Selfsupervised:
            self._selfsl = True
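
    # Illustrative sketch (not part of the original module): how the training mode is derived from
    # the label schema by _is_multi_label() / _set_train_mode() above. The group and label names are
    # hypothetical.
    #
    #   groups = ["cat___group", "dog___group", "bird___group"]   # one group per label, all names contain "___"
    #   labels = ["cat", "dog", "bird"]
    #   -> multi-label mode (len(groups) > 1, len(groups) == len(labels), "___" in every group name)
    #
    #   groups = ["animal_kind", "animal_color"]                  # several groups, heuristic above not satisfied
    #   -> hierarchical mode
    #
    #   a single label group -> multiclass mode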

    def infer(
        self,
        dataset: DatasetEntity,
        inference_parameters: Optional[InferenceParameters] = None,
    ) -> DatasetEntity:
        """Main infer function of OTX Classification."""
        logger.info("infer()")

        results = self._infer_model(dataset, inference_parameters)
        prediction_results = zip(
            results["eval_predictions"],
            results["feature_vectors"],
            results["saliency_maps"],
        )

        update_progress_callback = default_infer_progress_callback
        process_saliency_maps = False
        explain_predicted_classes = True
        if inference_parameters is not None:
            update_progress_callback = inference_parameters.update_progress  # type: ignore
            process_saliency_maps = inference_parameters.process_saliency_maps
            explain_predicted_classes = inference_parameters.explain_predicted_classes

        self._add_predictions_to_dataset(
            prediction_results, dataset, update_progress_callback, process_saliency_maps, explain_predicted_classes
        )
        return dataset
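
    # Hypothetical usage sketch (not part of the original module): calling infer() on a concrete
    # subclass. `task` and `val_dataset` are assumptions, not names defined in this module.
    #
    #   params = InferenceParameters(is_evaluation=True)
    #   predicted_dataset = task.infer(val_dataset.with_empty_annotations(), params)
    #   # each dataset item now carries ScoredLabel predictions plus "probabilities",
    #   # "representation_vector" and optional saliency-map metadata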

    def train(
        self,
        dataset: DatasetEntity,
        output_model: ModelEntity,
        train_parameters: Optional[TrainParameters] = None,
        seed: Optional[int] = None,
        deterministic: bool = False,
    ):
        """Train function for OTX classification task.

        Actual training is processed by the _train_model function.
        """
        logger.info("train()")
        # Check for stop signal when training has stopped.
        # If should_stop is true, training was cancelled and no new model should be returned.
        if self._should_stop:
            logger.info("Training cancelled.")
            self._should_stop = False
            self._is_training = False
            return

        self.seed = seed
        self.deterministic = deterministic

        # Set OTX LoggerHook & Time Monitor
        if train_parameters:
            update_progress_callback = train_parameters.update_progress
        else:
            update_progress_callback = default_train_progress_callback
        self._time_monitor = TrainingProgressCallback(update_progress_callback)

        results = self._train_model(dataset)

        MemCacheHandlerSingleton.delete()

        # Check for stop signal when training has stopped.
        # If should_stop is true, training was cancelled and no new model should be returned.
        if self._should_stop:
            logger.info("Training cancelled.")
            self._should_stop = False
            self._is_training = False
            return

        # get output model
        model_ckpt = results.get("final_ckpt")
        if model_ckpt is None:
            logger.error("cannot find final checkpoint from the results.")
            return
        # update checkpoint to the newly trained model
        self._model_ckpt = model_ckpt

        # compose performance statistics
        training_metrics, final_acc = self._generate_training_metrics(self._learning_curves)
        # save resulting model
        self.save_model(output_model)
        performance = Performance(
            score=ScoreMetric(value=final_acc, name="accuracy"),
            dashboard_metrics=training_metrics,
        )
        logger.info(f"Final model performance: {str(performance)}")
        output_model.performance = performance
        self._is_training = False
        logger.info("train done.")
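
    # Hypothetical usage sketch (not part of the original module): a typical training call on a
    # concrete subclass. `task`, `dataset` and `environment` are assumptions.
    #
    #   output_model = ModelEntity(dataset, environment.get_model_configuration())
    #   task.train(dataset, output_model, seed=5, deterministic=False)
    #   # output_model now holds "weights.pth" and "label_schema.json", and the final accuracy
    #   # is available via output_model.performance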

    def export(
        self,
        export_type: ExportType,
        output_model: ModelEntity,
        precision: ModelPrecision = ModelPrecision.FP32,
        dump_features: bool = True,
    ):
        """Export function of OTX Classification Task."""
        logger.info("Exporting the model")

        self._update_model_export_metadata(output_model, export_type, precision, dump_features)
        results = self._export_model(precision, export_type, dump_features)
        outputs = results.get("outputs")
        logger.debug(f"results of run_task = {outputs}")
        if outputs is None:
            raise RuntimeError(results.get("msg"))

        inference_config = get_cls_inferencer_configuration(self._task_environment.label_schema)
        extra_model_data = get_cls_model_api_configuration(self._task_environment.label_schema, inference_config)

        if export_type == ExportType.ONNX:
            extra_model_data[("model_info", "mean_values")] = results.get("inference_parameters").get("mean_values")
            extra_model_data[("model_info", "scale_values")] = results.get("inference_parameters").get("scale_values")

            onnx_file = outputs.get("onnx")
            embed_onnx_model_data(onnx_file, extra_model_data)
            with open(onnx_file, "rb") as f:
                output_model.set_data("model.onnx", f.read())
        else:
            bin_file = outputs.get("bin")
            xml_file = outputs.get("xml")

            deploy_cfg = get_cls_deploy_config(self._task_environment.label_schema, inference_config)
            extra_model_data[("otx_config",)] = json.dumps(deploy_cfg, ensure_ascii=False)
            embed_ir_model_data(xml_file, extra_model_data)

            with open(bin_file, "rb") as f:
                output_model.set_data("openvino.bin", f.read())
            with open(xml_file, "rb") as f:
                output_model.set_data("openvino.xml", f.read())

        output_model.set_data(
            "label_schema.json",
            label_schema_to_bytes(self._task_environment.label_schema),
        )
        logger.info("Exporting completed")
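
    # Hypothetical usage sketch (not part of the original module): exporting to OpenVINO IR and
    # reading the embedded artifacts back from the output model. `task`, `dataset` and `environment`
    # are assumptions.
    #
    #   exported_model = ModelEntity(dataset, environment.get_model_configuration())
    #   task.export(ExportType.OPENVINO, exported_model, precision=ModelPrecision.FP32, dump_features=True)
    #   xml_bytes = exported_model.get_data("openvino.xml")
    #   bin_bytes = exported_model.get_data("openvino.bin")
    #   schema_bytes = exported_model.get_data("label_schema.json")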

    def explain(
        self,
        dataset: DatasetEntity,
        explain_parameters: Optional[ExplainParameters] = None,
    ) -> DatasetEntity:
        """Main explain function of OTX Classification Task."""
        predictions, saliency_maps = self._explain_model(
            dataset,
            explain_parameters=explain_parameters,
        )

        update_progress_callback = default_infer_progress_callback
        process_saliency_maps = False
        explain_predicted_classes = True
        if explain_parameters is not None:
            update_progress_callback = explain_parameters.update_progress  # type: ignore
            process_saliency_maps = explain_parameters.process_saliency_maps
            explain_predicted_classes = explain_parameters.explain_predicted_classes

        self._add_explanations_to_dataset(
            predictions,
            saliency_maps,
            dataset,
            update_progress_callback,
            process_saliency_maps,
            explain_predicted_classes,
        )
        logger.info("Explain completed")
        return dataset

    def evaluate(
        self,
        output_resultset: ResultSetEntity,
        evaluation_metric: Optional[str] = None,
    ):
        """Evaluate function of OTX Classification Task."""
        logger.info("called evaluate()")
        if evaluation_metric is not None:
            logger.warning(
                f"Requested to use {evaluation_metric} metric, but parameter is ignored. Use accuracy instead."
            )
        metric = MetricsHelper.compute_accuracy(output_resultset)
        logger.info(f"Accuracy after evaluation: {metric.accuracy.value}")
        output_resultset.performance = metric.get_performance()
        logger.info("Evaluation completed")

    # pylint: disable=too-many-branches, too-many-locals
    def _add_predictions_to_dataset(
        self,
        prediction_results,
        dataset,
        update_progress_callback,
        process_saliency_maps=False,
        explain_predicted_classes=True,
    ):
        """Loop over dataset again to assign predictions. Convert from MMClassification format to OTX format."""
        dataset_size = len(dataset)
        pos_thr = 0.5
        label_list = self._labels
        # Fix the order for hierarchical labels to adjust classes with model outputs
        if self._hierarchical:
            label_list = get_hierarchical_label_list(self._hierarchical_info, label_list)
        for i, (dataset_item, prediction_items) in enumerate(zip(dataset, prediction_results)):
            prediction_item, feature_vector, saliency_map = prediction_items
            if any(np.isnan(prediction_item)):
                logger.info("Nan in prediction_item.")

            item_labels = self._get_item_labels(prediction_item, pos_thr)

            dataset_item.append_labels(item_labels)

            probs = TensorEntity(name="probabilities", numpy=prediction_item.reshape(-1))
            dataset_item.append_metadata_item(probs, model=self._task_environment.model)

            top_idxs = np.argpartition(prediction_item, -2)[-2:]
            top_probs = prediction_item[top_idxs]
            active_score_media = FloatMetadata(
                name="active_score", value=top_probs[1] - top_probs[0], float_type=FloatType.ACTIVE_SCORE
            )
            dataset_item.append_metadata_item(active_score_media, model=self._task_environment.model)

            if feature_vector is not None:
                active_score = TensorEntity(name="representation_vector", numpy=feature_vector.reshape(-1))
                dataset_item.append_metadata_item(active_score, model=self._task_environment.model)

            if saliency_map is not None:
                add_saliency_maps_to_dataset_item(
                    dataset_item=dataset_item,
                    saliency_map=saliency_map,
                    model=self._task_environment.model,
                    labels=label_list,
                    predicted_scored_labels=item_labels,
                    explain_predicted_classes=explain_predicted_classes,
                    process_saliency_maps=process_saliency_maps,
                )

            update_progress_callback(int(i / dataset_size * 100))

    # pylint: disable=too-many-locals
    def _get_item_labels(self, prediction_item, pos_thr):
        item_labels = []
        if self._multilabel:
            if max(prediction_item) < pos_thr:
                logger.info("Confidence is smaller than pos_thr, empty_label will be appended to item_labels.")
                item_labels.append(ScoredLabel(self._empty_label, probability=1.0))
            else:
                for cls_idx, pred_item in enumerate(prediction_item):
                    if pred_item > pos_thr:
                        cls_label = ScoredLabel(self._labels[cls_idx], probability=float(pred_item))
                        item_labels.append(cls_label)

        elif self._hierarchical:
            for head_idx in range(self._hierarchical_info["num_multiclass_heads"]):
                logits_begin, logits_end = self._hierarchical_info["head_idx_to_logits_range"][str(head_idx)]
                head_logits = prediction_item[logits_begin:logits_end]
                head_pred = np.argmax(head_logits)  # Assume logits already passed softmax
                label_str = self._hierarchical_info["all_groups"][head_idx][head_pred]
                otx_label = next(x for x in self._labels if x.name == label_str)
                item_labels.append(ScoredLabel(label=otx_label, probability=float(head_logits[head_pred])))

            if self._hierarchical_info["num_multilabel_classes"]:
                head_logits = prediction_item[self._hierarchical_info["num_single_label_classes"] :]
                for logit_idx, logit in enumerate(head_logits):
                    if logit > pos_thr:  # Assume logits already passed sigmoid
                        label_str_idx = self._hierarchical_info["num_multiclass_heads"] + logit_idx
                        label_str = self._hierarchical_info["all_groups"][label_str_idx][0]
                        otx_label = next(x for x in self._labels if x.name == label_str)
                        item_labels.append(ScoredLabel(label=otx_label, probability=float(logit)))

            item_labels = self._task_environment.label_schema.resolve_labels_greedily(item_labels)
            if not item_labels:
                logger.info("item_labels is empty.")
                item_labels.append(ScoredLabel(self._empty_label, probability=1.0))

        else:
            label_idx = prediction_item.argmax()
            cls_label = ScoredLabel(
                self._labels[label_idx],
                probability=float(prediction_item[label_idx]),
            )
            item_labels.append(cls_label)
        return item_labels

    def _add_explanations_to_dataset(
        self,
        predictions,
        saliency_maps,
        dataset,
        update_progress_callback,
        process_saliency_maps,
        explain_predicted_classes,
    ):
        """Loop over dataset again and assign saliency maps."""
        dataset_size = len(dataset)
        label_list = self._labels
        # Fix the order for hierarchical labels to adjust classes with model outputs
        if self._hierarchical:
            label_list = get_hierarchical_label_list(self._hierarchical_info, label_list)
        for i, (dataset_item, prediction_item, saliency_map) in enumerate(zip(dataset, predictions, saliency_maps)):
            item_labels = self._get_item_labels(prediction_item, pos_thr=0.5)
            add_saliency_maps_to_dataset_item(
                dataset_item=dataset_item,
                saliency_map=saliency_map,
                model=self._task_environment.model,
                labels=label_list,
                predicted_scored_labels=item_labels,
                explain_predicted_classes=explain_predicted_classes,
                process_saliency_maps=process_saliency_maps,
            )
            update_progress_callback(int(i / dataset_size * 100))
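
    # Worked example (illustrative only, not part of the original module): multi-label decoding in
    # _get_item_labels() with pos_thr = 0.5. The scores and class order are hypothetical.
    #
    #   prediction_item = np.array([0.91, 0.12, 0.55])   # per-class sigmoid scores
    #   -> classes 0 and 2 exceed 0.5 and are appended as ScoredLabel(prob=0.91) and ScoredLabel(prob=0.55)
    #
    #   prediction_item = np.array([0.20, 0.12, 0.31])   # max score below the threshold
    #   -> the empty label is appended with probability 1.0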

    def save_model(self, output_model: ModelEntity):
        """Save best model weights in ClassificationTrainTask."""
        if is_multigpu_child_process():
            return

        logger.info("called save_model")
        buffer = io.BytesIO()
        hyperparams_str = ids_to_strings(cfg_helper.convert(self._hyperparams, dict, enum_to_str=True))
        labels = {label.name: label.color.rgb_tuple for label in self._labels}
        model_ckpt = torch.load(self._model_ckpt)
        modelinfo = {
            "model": model_ckpt,
            "config": hyperparams_str,
            "labels": labels,
            "input_size": self._input_size,
            "VERSION": 1,
        }

        torch.save(modelinfo, buffer)
        output_model.set_data("weights.pth", buffer.getvalue())
        output_model.set_data(
            "label_schema.json",
            label_schema_to_bytes(self._task_environment.label_schema),
        )
        output_model.precision = self._precision
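
    # Illustrative sketch (not part of the original module): the "weights.pth" blob written by
    # save_model() can be loaded back with torch.load to inspect the stored fields.
    #
    #   blob = output_model.get_data("weights.pth")
    #   modelinfo = torch.load(io.BytesIO(blob), map_location="cpu")
    #   modelinfo.keys()  # dict_keys(['model', 'config', 'labels', 'input_size', 'VERSION'])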

    def _generate_training_metrics(self, learning_curves):  # pylint: disable=arguments-renamed
        """Parses the classification logs to get metrics from the latest training run.

        :return output List[MetricsGroup]
        """
        output: List[MetricsGroup] = []

        if self._multilabel:
            metric_key = "val/accuracy-mlc"
        elif self._hierarchical:
            metric_key = "val/MHAcc"
        else:
            metric_key = "val/accuracy (%)"

        # Learning curves
        best_acc = -1
        if learning_curves is None:
            return output

        for key, curve in learning_curves.items():
            metric_curve = CurveMetric(xs=curve.x, ys=curve.y, name=key)
            if key == metric_key:
                best_acc = max(curve.y)
            visualization_info = LineChartInfo(name=key, x_axis_label="Timestamp", y_axis_label=key)
            output.append(LineMetricsGroup(metrics=[metric_curve], visualization_info=visualization_info))

        return output, best_acc

    @abstractmethod
    def _train_model(self, dataset: DatasetEntity):
        """Train model and return the results."""
        raise NotImplementedError

    @abstractmethod
    def _export_model(self, precision: ModelPrecision, export_format: ExportType, dump_features: bool):
        """Export model and return the results."""
        raise NotImplementedError

    @abstractmethod
    def _explain_model(self, dataset: DatasetEntity, explain_parameters: Optional[ExplainParameters]):
        """Explain model and return the results."""
        raise NotImplementedError

    @abstractmethod
    def _infer_model(
        self,
        dataset: DatasetEntity,
        inference_parameters: Optional[InferenceParameters] = None,
    ):
        """Get inference results from dataset."""
        raise NotImplementedError
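

# --------------------------------------------------------------------------------------------------
# Illustrative sketch, not part of the original module: a minimal subclass outline showing which
# abstract hooks a concrete backend has to provide and what the base class above expects them to
# return. The class name and placeholder bodies are assumptions, not the real MMClassification adapter.
class _ExampleClassificationTask(OTXClassificationTask):
    """Hypothetical skeleton of a concrete classification task backend."""

    def _train_model(self, dataset: DatasetEntity):
        # Expected to run the actual training loop and return a dict containing at least "final_ckpt".
        raise NotImplementedError

    def _infer_model(self, dataset: DatasetEntity, inference_parameters: Optional[InferenceParameters] = None):
        # Expected to return a dict with "eval_predictions", "feature_vectors" and "saliency_maps".
        raise NotImplementedError

    def _export_model(self, precision: ModelPrecision, export_format: ExportType, dump_features: bool):
        # Expected to return a dict with "outputs" ({"xml"/"bin"} or {"onnx"}) and, for ONNX,
        # "inference_parameters" carrying "mean_values" / "scale_values".
        raise NotImplementedError

    def _explain_model(self, dataset: DatasetEntity, explain_parameters: Optional[ExplainParameters]):
        # Expected to return a (predictions, saliency_maps) pair.
        raise NotImplementedError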