"""Task of OTX Video Recognition."""
# Copyright (C) 2023 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.
import io
import os
from abc import ABC, abstractmethod
from typing import Any, Dict, Iterable, List, Optional, Union
import numpy as np
import torch
from mmcv.utils import ConfigDict
from otx.algorithms.action.configs.base import ActionConfig
from otx.algorithms.common.tasks.base_task import TRAIN_TYPE_DIR_PATH, OTXTask
from otx.algorithms.common.utils.callback import (
InferenceProgressCallback,
TrainingProgressCallback,
)
from otx.api.configuration import cfg_helper
from otx.api.configuration.helper.utils import config_to_bytes, ids_to_strings
from otx.api.entities.annotation import Annotation
from otx.api.entities.datasets import DatasetEntity
from otx.api.entities.explain_parameters import ExplainParameters
from otx.api.entities.inference_parameters import InferenceParameters
from otx.api.entities.metrics import (
BarChartInfo,
BarMetricsGroup,
CurveMetric,
LineChartInfo,
LineMetricsGroup,
MetricsGroup,
ScoreMetric,
VisualizationType,
)
from otx.api.entities.model import (
ModelEntity,
ModelPrecision,
)
from otx.api.entities.model_template import TaskType
from otx.api.entities.result_media import ResultMediaEntity
from otx.api.entities.resultset import ResultSetEntity
from otx.api.entities.scored_label import ScoredLabel
from otx.api.entities.shapes.rectangle import Rectangle
from otx.api.entities.subset import Subset
from otx.api.entities.task_environment import TaskEnvironment
from otx.api.entities.tensor import TensorEntity
from otx.api.entities.train_parameters import TrainParameters, default_progress_callback
from otx.api.serialization.label_mapper import label_schema_to_bytes
from otx.api.usecases.evaluation.accuracy import Accuracy
from otx.api.usecases.evaluation.f_measure import FMeasure
from otx.api.usecases.evaluation.metrics_helper import MetricsHelper
from otx.api.usecases.tasks.interfaces.export_interface import ExportType
from otx.api.utils.vis_utils import get_actmap
from otx.cli.utils.multi_gpu import is_multigpu_child_process
from otx.utils.logger import get_logger
logger = get_logger()
class OTXActionTask(OTXTask, ABC):
"""Task class for OTX action."""
# pylint: disable=too-many-instance-attributes, too-many-locals
def __init__(self, task_environment: TaskEnvironment, output_path: Optional[str] = None):
super().__init__(task_environment, output_path)
self._task_config = ActionConfig
self._hyperparams: ConfigDict = task_environment.get_hyper_parameters(self._task_config)
self._train_type = self._hyperparams.algo_backend.train_type
self._model_dir = os.path.join(
os.path.abspath(os.path.dirname(self._task_environment.model_template.model_template_path)),
TRAIN_TYPE_DIR_PATH[self._train_type.name],
)
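        # Templates that define a postprocessing confidence threshold (e.g. detection)
        # supply the initial value; otherwise fall back to 0.0.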
if hasattr(self._hyperparams, "postprocessing") and hasattr(
self._hyperparams.postprocessing, "confidence_threshold"
):
self.confidence_threshold = self._hyperparams.postprocessing.confidence_threshold
else:
self.confidence_threshold = 0.0
if task_environment.model is not None:
self._load_model()
self.data_pipeline_path = os.path.join(self._model_dir, "data_pipeline.py")
def train(
self,
dataset: DatasetEntity,
output_model: ModelEntity,
train_parameters: Optional[TrainParameters] = None,
seed: Optional[int] = None,
deterministic: bool = False,
):
"""Train function for OTX action task.
Actual training is processed by _train_model fucntion
"""
logger.info("train()")
        # Check for a stop signal before training starts. If should_stop is true,
        # training was cancelled and no new model should be returned.
if self._should_stop:
logger.info("Training cancelled.")
self._should_stop = False
self._is_training = False
return
self.seed = seed
self.deterministic = deterministic
# Set OTX LoggerHook & Time Monitor
if train_parameters:
update_progress_callback = train_parameters.update_progress
else:
update_progress_callback = default_progress_callback
self._time_monitor = TrainingProgressCallback(update_progress_callback)
results = self._train_model(dataset)
        # Check for stop signal when training has stopped. If should_stop is true,
        # training was cancelled and no new model should be returned.
if self._should_stop:
logger.info("Training cancelled.")
self._should_stop = False
self._is_training = False
return
# get output model
model_ckpt = results.get("final_ckpt")
if model_ckpt is None:
logger.error("cannot find final checkpoint from the results.")
return
# update checkpoint to the newly trained model
self._model_ckpt = model_ckpt
# get prediction on validation set
self._is_training = False
val_dataset = dataset.get_subset(Subset.VALIDATION)
val_preds, val_performance = self._infer_model(val_dataset, InferenceParameters(is_evaluation=True))
preds_val_dataset = val_dataset.with_empty_annotations()
if self._task_type == TaskType.ACTION_CLASSIFICATION:
self._add_cls_predictions_to_dataset(val_preds, preds_val_dataset)
elif self._task_type == TaskType.ACTION_DETECTION:
self._add_det_predictions_to_dataset(val_preds, preds_val_dataset, 0.0)
result_set = ResultSetEntity(
model=output_model,
ground_truth_dataset=val_dataset,
prediction_dataset=preds_val_dataset,
)
metric: Union[Accuracy, FMeasure]
if self._task_type == TaskType.ACTION_CLASSIFICATION:
metric = MetricsHelper.compute_accuracy(result_set)
if self._task_type == TaskType.ACTION_DETECTION:
if self._hyperparams.postprocessing.result_based_confidence_threshold:
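                # Sweep confidence thresholds on the validation predictions and keep
                # the one that maximizes the F-measure.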
best_confidence_threshold = None
logger.info("Adjusting the confidence threshold")
metric = MetricsHelper.compute_f_measure(result_set, vary_confidence_threshold=True)
if metric.best_confidence_threshold:
best_confidence_threshold = metric.best_confidence_threshold.value
if best_confidence_threshold is None:
raise ValueError("Cannot compute metrics: Invalid confidence threshold!")
logger.info(f"Setting confidence threshold to {best_confidence_threshold} based on results")
self.confidence_threshold = best_confidence_threshold
else:
metric = MetricsHelper.compute_f_measure(result_set, vary_confidence_threshold=False)
# compose performance statistics
performance = metric.get_performance()
performance.dashboard_metrics.extend(self._generate_training_metrics(self._learning_curves, val_performance))
logger.info(f"Final model performance: {performance}")
# save resulting model
self.save_model(output_model)
output_model.performance = performance
logger.info("train done.")
@abstractmethod
def _train_model(self, dataset: DatasetEntity):
"""Train model and return the results."""
raise NotImplementedError
def infer(
self,
dataset: DatasetEntity,
inference_parameters: Optional[InferenceParameters] = None,
) -> DatasetEntity:
"""Main infer function."""
logger.info("infer()")
update_progress_callback = default_progress_callback
if inference_parameters is not None:
update_progress_callback = inference_parameters.update_progress # type: ignore
self._time_monitor = InferenceProgressCallback(len(dataset), update_progress_callback)
# If confidence threshold is adaptive then up-to-date value should be stored in the model
# and should not be changed during inference. Otherwise user-specified value should be taken.
if not self._hyperparams.postprocessing.result_based_confidence_threshold:
self.confidence_threshold = self._hyperparams.postprocessing.confidence_threshold
logger.info(f"Confidence threshold {self.confidence_threshold}")
prediction_results, _ = self._infer_model(dataset, inference_parameters)
if self._task_type == TaskType.ACTION_CLASSIFICATION:
self._add_cls_predictions_to_dataset(prediction_results, dataset)
elif self._task_type == TaskType.ACTION_DETECTION:
self._add_det_predictions_to_dataset(prediction_results, dataset, self.confidence_threshold)
logger.info("Inference completed")
return dataset
@abstractmethod
def _infer_model(
self,
dataset: DatasetEntity,
inference_parameters: Optional[InferenceParameters] = None,
):
"""Get inference results from dataset."""
raise NotImplementedError
def export(
self,
export_type: ExportType,
output_model: ModelEntity,
precision: ModelPrecision = ModelPrecision.FP32,
dump_features: bool = True,
):
"""Export function of OTX Task."""
if dump_features:
raise NotImplementedError(
"Feature dumping is not implemented for the action task."
"The saliency maps and representation vector outputs will not be dumped in the exported model."
)
self._update_model_export_metadata(output_model, export_type, precision, dump_features)
results = self._export_model(precision, export_type, dump_features)
outputs = results.get("outputs")
logger.debug(f"results of run_task = {outputs}")
if outputs is None:
raise RuntimeError(results.get("msg"))
if export_type == ExportType.ONNX:
onnx_file = outputs.get("onnx")
with open(onnx_file, "rb") as f:
output_model.set_data("model.onnx", f.read())
else:
bin_file = outputs.get("bin")
xml_file = outputs.get("xml")
with open(bin_file, "rb") as f:
output_model.set_data("openvino.bin", f.read())
with open(xml_file, "rb") as f:
output_model.set_data("openvino.xml", f.read())
output_model.set_data(
"confidence_threshold",
np.array([self.confidence_threshold], dtype=np.float32).tobytes(),
)
output_model.set_data("config.json", config_to_bytes(self._hyperparams))
output_model.set_data(
"label_schema.json",
label_schema_to_bytes(self._task_environment.label_schema),
)
logger.info("Exporting completed")
@abstractmethod
def _export_model(self, precision: ModelPrecision, export_format: ExportType, dump_features: bool):
raise NotImplementedError
def explain(
self,
dataset: DatasetEntity,
explain_parameters: Optional[ExplainParameters] = None,
) -> DatasetEntity:
"""Main explain function of OTX Task."""
raise NotImplementedError("Video recognition task don't support otx explain yet.")
def evaluate(
self,
output_resultset: ResultSetEntity,
evaluation_metric: Optional[str] = None,
):
"""Evaluate function of OTX Action Task."""
logger.info("called evaluate()")
if evaluation_metric is not None:
            logger.warning(
                f"Requested to use {evaluation_metric} metric, but this parameter is ignored. "
                "The default metric for the task is used instead."
            )
self._remove_empty_frames(output_resultset.ground_truth_dataset)
metric: Union[Accuracy, FMeasure]
if self._task_type == TaskType.ACTION_CLASSIFICATION:
metric = MetricsHelper.compute_accuracy(output_resultset)
if self._task_type == TaskType.ACTION_DETECTION:
metric = MetricsHelper.compute_f_measure(output_resultset)
performance = metric.get_performance()
logger.info(f"Final model performance: {str(performance)}")
output_resultset.performance = metric.get_performance()
logger.info("Evaluation completed")
def _remove_empty_frames(self, dataset: DatasetEntity):
"""Remove empty frame for action detection dataset."""
remove_indices = []
for idx, item in enumerate(dataset):
if item.get_metadata()[0].data.is_empty_frame:
remove_indices.append(idx)
dataset.remove_at_indices(remove_indices)
def _add_cls_predictions_to_dataset(self, prediction_results: Iterable, dataset: DatasetEntity):
"""Loop over dataset again to assign predictions. Convert from MM format to OTX format."""
prediction_results = list(prediction_results)
video_info: Dict[str, int] = {}
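        # Predictions are produced per video, while dataset items are per clip/frame.
        # Build a video_id -> index map (in order of first appearance) so every item of
        # a video looks up that video's prediction by index.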
for dataset_item in dataset:
video_id = dataset_item.get_metadata()[0].data.video_id
if video_id not in video_info:
video_info[video_id] = len(video_info)
for dataset_item in dataset:
video_id = dataset_item.get_metadata()[0].data.video_id
all_results, feature_vector, saliency_map = prediction_results[video_info[video_id]]
item_labels = []
label = ScoredLabel(label=self._labels[all_results.argmax()], probability=all_results.max())
item_labels.append(label)
dataset_item.append_labels(item_labels)
if feature_vector is not None:
active_score = TensorEntity(name="representation_vector", numpy=feature_vector.reshape(-1))
dataset_item.append_metadata_item(active_score, model=self._task_environment.model)
if saliency_map is not None:
saliency_map = get_actmap(saliency_map, (dataset_item.width, dataset_item.height))
saliency_map_media = ResultMediaEntity(
name="Saliency Map",
type="saliency_map",
annotation_scene=dataset_item.annotation_scene,
numpy=saliency_map,
roi=dataset_item.roi,
)
dataset_item.append_metadata_item(saliency_map_media, model=self._task_environment.model)
def _add_det_predictions_to_dataset(
self, prediction_results: Iterable, dataset: DatasetEntity, confidence_threshold: float = 0.05
):
self._remove_empty_frames(dataset)
for dataset_item, (all_results, feature_vector, saliency_map) in zip(dataset, prediction_results):
shapes = []
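            # all_results follows the MM detection convention: one (N, 5) array per label,
            # each row holding four box coordinates followed by a confidence score.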
for label_idx, detections in enumerate(all_results):
for i in range(detections.shape[0]):
probability = float(detections[i, 4])
coords = detections[i, :4]
if probability < confidence_threshold:
continue
if coords[3] - coords[1] <= 0 or coords[2] - coords[0] <= 0:
continue
assigned_label = [ScoredLabel(self._labels[label_idx], probability=probability)]
shapes.append(
Annotation(
Rectangle(x1=coords[0], y1=coords[1], x2=coords[2], y2=coords[3]),
labels=assigned_label,
)
)
dataset_item.append_annotations(shapes)
if feature_vector is not None:
active_score = TensorEntity(name="representation_vector", numpy=feature_vector.reshape(-1))
dataset_item.append_metadata_item(active_score, model=self._task_environment.model)
if saliency_map is not None:
saliency_map = get_actmap(saliency_map, (dataset_item.width, dataset_item.height))
saliency_map_media = ResultMediaEntity(
name="Saliency Map",
type="saliency_map",
annotation_scene=dataset_item.annotation_scene,
numpy=saliency_map,
roi=dataset_item.roi,
)
dataset_item.append_metadata_item(saliency_map_media, model=self._task_environment.model)
@staticmethod
# TODO Implement proper function for action classification
def _generate_training_metrics(learning_curves, scores, metric_name="mAP") -> Iterable[MetricsGroup[Any, Any]]:
"""Get Training metrics (epochs & scores).
Parses the mmaction logs to get metrics from the latest training run
:return output List[MetricsGroup]
"""
output: List[MetricsGroup] = []
# Learning curves.
for key, curve in learning_curves.items():
len_x, len_y = len(curve.x), len(curve.y)
if len_x != len_y:
logger.warning(f"Learning curve {key} has inconsistent number of coordinates ({len_x} vs {len_y}.")
len_x = min(len_x, len_y)
curve.x = curve.x[:len_x]
curve.y = curve.y[:len_x]
metric_curve = CurveMetric(
xs=np.nan_to_num(curve.x).tolist(),
ys=np.nan_to_num(curve.y).tolist(),
name=key,
)
visualization_info = LineChartInfo(name=key, x_axis_label="Epoch", y_axis_label=key)
output.append(LineMetricsGroup(metrics=[metric_curve], visualization_info=visualization_info))
# Final mAP value on the validation set.
output.append(
BarMetricsGroup(
                metrics=[ScoreMetric(value=scores, name=metric_name)],
visualization_info=BarChartInfo("Validation score", visualization_type=VisualizationType.RADIAL_BAR),
)
)
return output
def save_model(self, output_model: ModelEntity):
"""Save best model weights in ActionTrainTask."""
if is_multigpu_child_process():
return
logger.info("called save_model")
buffer = io.BytesIO()
hyperparams_str = ids_to_strings(cfg_helper.convert(self._hyperparams, dict, enum_to_str=True))
labels = {label.name: label.color.rgb_tuple for label in self._labels}
model_ckpt = torch.load(self._model_ckpt)
modelinfo = {
"model": model_ckpt,
"config": hyperparams_str,
"labels": labels,
"confidence_threshold": self.confidence_threshold,
"VERSION": 1,
}
torch.save(modelinfo, buffer)
output_model.set_data("weights.pth", buffer.getvalue())
output_model.set_data(
"label_schema.json",
label_schema_to_bytes(self._task_environment.label_schema),
)
output_model.precision = self._precision