"""Visual Prompting Task."""

# Copyright (C) 2023 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.

import ctypes
import io
import json
import os
import shutil
import tempfile
import time
import warnings
from collections import OrderedDict
from typing import Any, Dict, List, Optional, Union

import openvino as ov
import torch
from omegaconf import DictConfig, ListConfig
from openvino.tools import mo
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import TQDMProgressBar
from pytorch_lightning.loggers import CSVLogger

from otx.algorithms.common.configs.training_base import TrainType
from otx.algorithms.common.utils import set_random_seed
from otx.algorithms.visual_prompting.adapters.pytorch_lightning.callbacks import (
    InferenceCallback,
    ZeroShotInferenceCallback,
)
from otx.algorithms.visual_prompting.adapters.pytorch_lightning.config import (
    get_visual_promtping_config,
)
from otx.algorithms.visual_prompting.adapters.pytorch_lightning.datasets import (
    OTXVisualPromptingDataModule,
)
from otx.algorithms.visual_prompting.configs.base.configuration import (
    VisualPromptingBaseConfig,
)
from otx.api.entities.datasets import DatasetEntity
from otx.api.entities.inference_parameters import InferenceParameters
from otx.api.entities.model import (
    ModelEntity,
    ModelFormat,
    ModelOptimizationType,
    ModelPrecision,
    OptimizationMethod,
)
from otx.api.entities.resultset import ResultSetEntity
from otx.api.entities.task_environment import TaskEnvironment
from otx.api.entities.train_parameters import TrainParameters
from otx.api.serialization.label_mapper import label_schema_to_bytes
from otx.api.usecases.evaluation.metrics_helper import MetricsHelper
from otx.api.usecases.tasks.interfaces.evaluate_interface import IEvaluationTask
from otx.api.usecases.tasks.interfaces.export_interface import ExportType, IExportTask
from otx.api.usecases.tasks.interfaces.inference_interface import IInferenceTask
from otx.api.usecases.tasks.interfaces.unload_interface import IUnload
from otx.utils.logger import get_logger

logger = get_logger()


# pylint: disable=too-many-instance-attributes
class InferenceTask(IInferenceTask, IEvaluationTask, IExportTask, IUnload):
    """Base Visual Prompting Task.

    Train, Infer, and Export a Visual Prompting Task.

    Args:
        task_environment (TaskEnvironment): OTX Task environment.
        output_path (Optional[str]): Output path where task outputs are saved.
    """

    def __init__(self, task_environment: TaskEnvironment, output_path: Optional[str] = None) -> None:
        torch.backends.cudnn.enabled = True
        logger.info("Initializing the task environment.")
        self.task_environment = task_environment
        self.task_type = task_environment.model_template.task_type
        self.model_name = task_environment.model_template.name
        self.labels = task_environment.get_labels()

        self.hyper_parameters: VisualPromptingBaseConfig = self.task_environment.get_hyper_parameters()
        self.train_type = self.hyper_parameters.algo_backend.train_type  # type: ignore[attr-defined]

        template_file_path = task_environment.model_template.model_template_path
        self.base_dir = os.path.abspath(os.path.dirname(template_file_path))

        # Hyperparameters.
        self._work_dir_is_temp = False
        self.output_path = output_path
        self.mode = "train"
        if task_environment.model is not None and task_environment.model.train_dataset is None:
            self.mode = "export"
        if self.output_path is None:
            self.output_path = tempfile.mkdtemp(prefix="otx-visual_prompting")
            self._work_dir_is_temp = True
            self.mode = "inference"
        self.config = self.get_config()

        # Set default model attributes.
        self.optimization_methods: List[OptimizationMethod] = []
        self.precision = [ModelPrecision.FP32]
        self.optimization_type = ModelOptimizationType.MO

        self.trainer: Trainer

        self._model_ckpt: Optional[str] = None

        self.timestamp = time.strftime("%Y%m%d_%H%M%S", time.localtime())

    def set_seed(self):
        """Set seed and deterministic flag."""
        if self.seed is None:
            # If the seed is not passed via task.train, it is taken from the recipe.
            self.seed = self.config.get("seed", 5)
        if not self.deterministic:
            # The deterministic flag follows the same rule.
            self.deterministic = self.config.get("deterministic", False)
        self.config["seed"] = self.seed
        self.config["deterministic"] = self.deterministic
        set_random_seed(self.seed, logger, self.deterministic)

    def get_config(self) -> Union[DictConfig, ListConfig]:
        """Get Visual Prompting Config from task environment.

        Returns:
            Union[DictConfig, ListConfig]: Visual Prompting config.
        """
        # set checkpoints
        model_checkpoint: Optional[str] = None
        resume_from_checkpoint: Optional[str] = None
        if self.mode == "train" and self.task_environment.model is not None:
            # when args.load_weights or args.resume_from is set
            checkpoint_path = str(self.task_environment.model.model_adapters.get("path", None))
            if self.task_environment.model.model_adapters.get("resume", False):
                resume_from_checkpoint = checkpoint_path
            else:
                model_checkpoint = checkpoint_path

        config = get_visual_promtping_config(
            task_name=self.model_name,
            otx_config=self.hyper_parameters,
            config_dir=self.base_dir,
            mode=self.mode,
            model_checkpoint=model_checkpoint,
            resume_from_checkpoint=resume_from_checkpoint,
        )
        config.dataset.task = "visual_prompting"
        return config

    def load_model(self, otx_model: Optional[ModelEntity] = None) -> LightningModule:
        """Create and Load Visual Prompting Module.

        Currently, the model is loaded through `sam_model_registry` because only SAM is supported.
        If another visual prompting model is added, the model loading process must be changed.

        Args:
            otx_model (Optional[ModelEntity]): OTX Model from the task environment.

        Returns:
            LightningModule: Visual prompting model with/without weights.
        """

        def get_model(config: DictConfig, train_type: TrainType, state_dict: Optional[OrderedDict] = None):
            if config.model.name == "SAM":
                if train_type == TrainType.Incremental:
                    from otx.algorithms.visual_prompting.adapters.pytorch_lightning.models import (
                        SegmentAnything as VisualPrompter,
                    )
                elif train_type == TrainType.Zeroshot:
                    from otx.algorithms.visual_prompting.adapters.pytorch_lightning.models import (  # type: ignore[assignment] # noqa: E501
                        ZeroShotSegmentAnything as VisualPrompter,
                    )
                model = VisualPrompter(config=config, state_dict=state_dict)
            else:
                raise NotImplementedError(
                    f"Currently selected model {config.model.name} is not implemented. Use SAM instead."
                )
            return model

        state_dict = None
        if otx_model is None:
            logger.info(
                "No trained model in project yet. Created new model with '%s'",
                self.model_name,
            )
        elif otx_model.model_adapters.get("resume", False):
            # When resuming, skip this part; the checkpoint is loaded in the Trainer.
            logger.info(
                f"To resume {otx_model.model_adapters.get('path')}, the checkpoint will be loaded in the Trainer."
            )
        else:
            # Load state_dict
            buffer = io.BytesIO(otx_model.get_data("weights.pth"))
            model_data = torch.load(buffer, map_location=torch.device("cpu"))

            if model_data.get("state_dict", None) and model_data.get("pytorch-lightning_version", None):
                # Load state_dict from a pytorch lightning checkpoint or a weights.pth saved by the visual
                # prompting task. A pytorch lightning checkpoint contains metadata such as epoch, global_step,
                # pytorch-lightning_version, state_dict, loops, callbacks, optimizer_states, lr_schedulers,
                # hparams_name, and hyper_parameters. To confirm it comes from pytorch lightning, check whether
                # these keys are present in model_data.
                state_dict = model_data["state_dict"]
            elif model_data.get("model", None) and model_data.get("config", None):
                # Load state_dict from a checkpoint saved by other otx tasks
                if model_data["config"]["model"]["backbone"] != self.config["model"]["backbone"]:
                    logger.warning(
                        "Backbone of the model in the Task Environment is different from the one in the template. "
                        f"creating model with backbone={model_data['config']['model']['backbone']}"
                    )
                    self.config["model"]["backbone"] = model_data["config"]["model"]["backbone"]
                state_dict = model_data["model"]
            else:
                # Load state_dict from a plain pytorch checkpoint
                state_dict = model_data

        try:
            model = get_model(config=self.config, train_type=self.train_type, state_dict=state_dict)
            logger.info("Completed loading the model.")
        except BaseException as exception:
            raise ValueError("Could not load the saved model. The model file structure is invalid.") from exception

        return model

    def cancel_training(self) -> None:  # noqa: D102
        raise NotImplementedError

    def infer(self, dataset: DatasetEntity, inference_parameters: InferenceParameters) -> DatasetEntity:
        """Perform inference on a dataset.

        Args:
            dataset (DatasetEntity): Dataset to infer.
            inference_parameters (InferenceParameters): Inference parameters.

        Returns:
            DatasetEntity: Output dataset with predictions.
        """
        logger.info("Performing inference on the validation set using the base torch model.")
        self.model = self.load_model(otx_model=self.task_environment.model)
        datamodule = OTXVisualPromptingDataModule(
            config=self.config.dataset, dataset=dataset, train_type=self.train_type
        )

        logger.info("Inference Configs '%s'", self.config)

        # Callbacks
        inference_callback = InferenceCallback(otx_dataset=dataset)
        callbacks = [TQDMProgressBar(), inference_callback]

        self.trainer = Trainer(**self.config.trainer, logger=False, callbacks=callbacks)
        self.trainer.predict(model=self.model, datamodule=datamodule)

        return inference_callback.otx_dataset
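
    # Illustrative usage sketch (not part of the task implementation): assuming a populated
    # `task_environment` and a `dataset` already exist, inference can be driven roughly like this.
    # The variable names below are hypothetical.
    #
    #   task = InferenceTask(task_environment=task_environment)
    #   predicted_dataset = task.infer(
    #       dataset=dataset.with_empty_annotations(),
    #       inference_parameters=InferenceParameters(is_evaluation=True),
    #   )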

    def evaluate(self, output_resultset: ResultSetEntity, evaluation_metric: Optional[str] = None) -> None:
        """Evaluate the performance on a result set.

        Args:
            output_resultset (ResultSetEntity): Result Set from which the performance is evaluated.
            evaluation_metric (Optional[str], optional): Evaluation metric. Defaults to None.
                Instead, the metric is chosen depending on the task type.
        """
        metric = MetricsHelper.compute_dice_averaged_over_pixels(output_resultset)
        logger.info(f"mDice after evaluation: {metric.overall_dice.value}")
        output_resultset.performance = metric.get_performance()
        logger.info("Evaluation completed")
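
    # Illustrative evaluation sketch (hypothetical variable names, continuing the example above):
    # a ResultSetEntity pairs the ground-truth dataset with the prediction dataset returned by
    # `infer`, and `evaluate` fills in its `performance` with the mDice metric.
    #
    #   result_set = ResultSetEntity(
    #       model=task_environment.model,
    #       ground_truth_dataset=dataset,
    #       prediction_dataset=predicted_dataset,
    #   )
    #   task.evaluate(result_set)
    #   print(result_set.performance)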

    def _export_to_onnx(self, onnx_path: Dict[str, str]):
        """Export model to ONNX.

        Args:
            onnx_path (Dict[str, str]): Paths to save ONNX models.
        """
        height = width = self.config.model.image_size
        for module, path in onnx_path.items():
            if module == "visual_prompting_image_encoder":
                dummy_inputs = {"images": torch.randn(1, 3, height, width, dtype=torch.float32)}
                output_names = ["image_embeddings"]
                dynamic_axes = None
                model_to_export = self.model.image_encoder
            else:
                # SAM without the image encoder, i.e. prompt encoder + mask decoder
                embed_dim = self.model.prompt_encoder.embed_dim
                embed_size = self.model.prompt_encoder.image_embedding_size
                mask_input_size = [4 * x for x in embed_size]
                dynamic_axes = {
                    "point_coords": {1: "num_points"},
                    "point_labels": {1: "num_points"},
                }
                dummy_inputs = {
                    "image_embeddings": torch.zeros(1, embed_dim, *embed_size, dtype=torch.float32),
                    "point_coords": torch.randint(low=0, high=1024, size=(1, 2, 2), dtype=torch.float32),
                    "point_labels": torch.randint(low=0, high=4, size=(1, 2), dtype=torch.float32),
                    "mask_input": torch.randn(1, 1, *mask_input_size, dtype=torch.float32),
                    "has_mask_input": torch.tensor([[1]], dtype=torch.float32),
                    "orig_size": torch.randint(low=256, high=2048, size=(1, 2), dtype=torch.int64),
                }
                output_names = ["upscaled_masks", "iou_predictions", "low_res_masks"]
                model_to_export = self.model

            with warnings.catch_warnings():
                warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
                warnings.filterwarnings("ignore", category=UserWarning)
                with open(path, "wb") as f:
                    torch.onnx.export(
                        model_to_export,
                        tuple(dummy_inputs.values()),
                        f,
                        export_params=True,
                        verbose=False,
                        opset_version=13,
                        do_constant_folding=True,
                        input_names=list(dummy_inputs.keys()),
                        output_names=output_names,
                        dynamic_axes=dynamic_axes,
                    )

    def export(  # noqa: D102
        self,
        export_type: ExportType,
        output_model: ModelEntity,
        precision: ModelPrecision = ModelPrecision.FP32,
        dump_features: bool = False,
    ) -> None:
        """Export model to OpenVINO IR.

        When SAM gets an image for inference, the image encoder runs just once to produce the image embedding.
        After that, the prompt encoder + mask decoder run repeatedly to produce mask predictions.
        For this reason, SAM is split into two parts: the image encoder and the prompt encoder + mask decoder.

        Args:
            export_type (ExportType): Export type, either ExportType.OPENVINO or ExportType.ONNX.
            output_model (ModelEntity): The model entity in which to write the exported model data.
            precision (ModelPrecision): Output model weights and inference precision.
            dump_features (bool): Flag to return "feature_vector" and "saliency_map".

        Raises:
            RuntimeError: If the export type or precision is not supported.
        """
        if dump_features:
            logger.warning(
                "Feature dumping is not implemented for the visual prompting task. "
                "The saliency maps and representation vector outputs will not be dumped in the exported model."
            )

        self.model = self.load_model(otx_model=self.task_environment.model)
        if export_type == ExportType.ONNX:
            output_model.model_format = ModelFormat.ONNX
            output_model.optimization_type = ModelOptimizationType.ONNX
            if precision == ModelPrecision.FP16:
                raise RuntimeError("Export to FP16 ONNX is not supported")
        elif export_type == ExportType.OPENVINO:
            output_model.model_format = ModelFormat.OPENVINO
            output_model.optimization_type = ModelOptimizationType.MO
        else:
            raise RuntimeError(f"not supported export type {export_type}")

        self.precision[0] = precision
        output_model.has_xai = dump_features

        logger.info("Exporting to the OpenVINO model.")
        onnx_path = {
            "visual_prompting_image_encoder": os.path.join(self.output_path, "visual_prompting_image_encoder.onnx"),
            "visual_prompting_decoder": os.path.join(self.output_path, "visual_prompting_decoder.onnx"),
        }
        self._export_to_onnx(onnx_path)

        if export_type == ExportType.ONNX:
            for module, path in onnx_path.items():
                with open(path, "rb") as file:
                    output_model.set_data(f"{module}.onnx", file.read())
        else:
            for module, path in onnx_path.items():
                mo_args: Dict[str, Any] = {"input_model": path}
                if module == "visual_prompting_image_encoder":
                    mo_args.update(
                        {
                            "mean_values": list(self.config.dataset.normalize.mean),
                            "scale_values": list(self.config.dataset.normalize.std),
                        }
                    )
                if precision == ModelPrecision.FP16:
                    mo_args.update({"compress_to_fp16": True})

                ov_model = mo.convert_model(**mo_args)
                ov.save_model(ov_model, os.path.join(self.output_path, f"{module}.xml"))
                with open(path.replace(".onnx", ".bin"), "rb") as file:
                    output_model.set_data(f"{module}.bin", file.read())
                with open(path.replace(".onnx", ".xml"), "rb") as file:
                    output_model.set_data(f"{module}.xml", file.read())

        output_model.precision = self.precision
        output_model.optimization_methods = self.optimization_methods
        output_model.set_data("label_schema.json", label_schema_to_bytes(self.task_environment.label_schema))
        self._set_metadata(output_model)
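
    # Illustrative export sketch (hypothetical names, assuming `task`, `dataset`, and `task_environment`
    # from the examples above). The exported artifacts are written into `output_model` as
    # `visual_prompting_image_encoder.xml/.bin` and `visual_prompting_decoder.xml/.bin`.
    #
    #   exported_model = ModelEntity(dataset, task_environment.get_model_configuration())
    #   task.export(ExportType.OPENVINO, exported_model, precision=ModelPrecision.FP32)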

    def model_info(self) -> Dict:
        """Return model info to save the model weights.

        Returns:
            Dict: Model info.
        """
        if not self._model_ckpt:
            logger.warning("Model checkpoint is not set; returning an empty dictionary.")
            return {}
        return torch.load(self._model_ckpt, map_location="cpu")

    def save_model(self, output_model: ModelEntity) -> None:
        """Save the model after training is completed.

        Args:
            output_model (ModelEntity): Output model onto which the weights are saved.
        """
        logger.info("Saving the model weights.")
        model_info = self.model_info()

        buffer = io.BytesIO()
        torch.save(model_info, buffer)
        output_model.set_data("weights.pth", buffer.getvalue())
        output_model.set_data("label_schema.json", label_schema_to_bytes(self.task_environment.label_schema))
        output_model.precision = self.precision
        output_model.optimization_methods = self.optimization_methods

    def _set_metadata(self, output_model: ModelEntity) -> None:
        """Set metadata to the output model."""
        metadata = {"image_size": int(self.config.dataset.image_size)}

        # Set the task type for the inferencer
        metadata["task"] = str(self.task_type).lower().split("_")[-1]  # type: ignore
        output_model.set_data("metadata", json.dumps(metadata).encode())

    @staticmethod
    def _is_docker() -> bool:
        raise NotImplementedError
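
    # For reference, `_set_metadata` above stores a small JSON blob; its content looks roughly like the
    # following (the image size depends on the template; 1024 is the usual SAM input resolution):
    #
    #   {"image_size": 1024, "task": "prompting"}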

    def unload(self) -> None:
        """Unload the task."""
        self.cleanup()

        if self._is_docker():
            logger.warning("Got unload request. Unloading models. Throwing Segmentation Fault on purpose")
            ctypes.string_at(0)
        else:
            logger.warning("Got unload request, but not on Docker. Only clearing CUDA cache")
            torch.cuda.empty_cache()
            logger.warning(
                "Done unloading. Torch is still occupying %f bytes of GPU memory",
                torch.cuda.memory_allocated(),
            )

    def cleanup(self) -> None:
        """Clean up the work directory."""
        if self._work_dir_is_temp:
            self._delete_scratch_space()

    def _delete_scratch_space(self) -> None:
        """Remove model checkpoints and otx logs."""
        if os.path.exists(self.output_path):
            shutil.rmtree(self.output_path, ignore_errors=False)


class ZeroShotTask(InferenceTask):
    """Learn task for Zero-shot learning.

    There are two ways this task could be used (to be decided):
        1. Use it independently.  <-- temporarily the current setting
        2. Use it depending on the template.

    The objective of this task is to get reference features and export them together with the decoder modules.
    """

    def train(  # noqa: D102
        self,
        dataset: DatasetEntity,
        output_model: ModelEntity,
        train_parameters: TrainParameters,
        seed: Optional[int] = None,
        deterministic: bool = False,
    ) -> None:
        logger.info("Training the model.")

        self.seed = seed
        self.deterministic = deterministic
        self.set_seed()
        self.config.trainer.deterministic = "warn" if deterministic else deterministic

        logger.info(f"Training Configs {self.config}")

        self.model = self.load_model(otx_model=self.task_environment.model)

        datamodule = OTXVisualPromptingDataModule(
            config=self.config.dataset, dataset=dataset, train_type=self.train_type
        )

        self.trainer = Trainer(
            logger=CSVLogger(save_dir=self.output_path, name=".", version=self.timestamp), **self.config.trainer
        )
        self.trainer.fit(model=self.model, datamodule=datamodule)

        # save the resulting model
        self.save_model(output_model)
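
    # Illustrative zero-shot flow sketch (hypothetical variable names): "training" here only runs the
    # reference-feature learning pass over the given dataset; the learned model can then be saved or
    # used for inference via `infer` below.
    #
    #   task = ZeroShotTask(task_environment=task_environment)
    #   learned_model = ModelEntity(dataset, task_environment.get_model_configuration())
    #   task.train(dataset, learned_model, TrainParameters())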

    def infer(self, dataset: DatasetEntity, inference_parameters: InferenceParameters) -> DatasetEntity:
        """Perform inference on a dataset.

        Args:
            dataset (DatasetEntity): Dataset to infer.
            inference_parameters (InferenceParameters): Inference parameters.

        Returns:
            DatasetEntity: Output dataset with predictions.
        """
        logger.info("Performing inference on the validation set using the base torch model.")
        self.model = self.load_model(otx_model=self.task_environment.model)
        datamodule = OTXVisualPromptingDataModule(
            config=self.config.dataset, dataset=dataset, train_type=self.train_type
        )

        logger.info("Inference Configs '%s'", self.config)

        # Callbacks
        inference_callback = ZeroShotInferenceCallback(
            otx_dataset=dataset, label_schema=self.task_environment.label_schema
        )
        callbacks = [TQDMProgressBar(), inference_callback]

        self.trainer = Trainer(**self.config.trainer, logger=False, callbacks=callbacks)
        self.trainer.predict(model=self.model, datamodule=datamodule)

        return inference_callback.otx_dataset

    def save_model(self, output_model: ModelEntity) -> None:
        """Save the model after training is completed.

        Args:
            output_model (ModelEntity): Output model onto which the weights are saved.
        """
        logger.info("Saving the model weights and reference features.")
        model_info = self.model.state_dict()
        model_info.pop("reference_info.reference_feats")
        model_info.pop("reference_info.used_indices")

        buffer = io.BytesIO()
        torch.save(model_info, buffer)
        output_model.set_data("weights.pth", buffer.getvalue())
        output_model.set_data("label_schema.json", label_schema_to_bytes(self.task_environment.label_schema))
        output_model.precision = self.precision
        output_model.optimization_methods = self.optimization_methods