Source code for otx.algorithms.anomaly.tasks.nncf

"""Anomaly Classification Task."""

# Copyright (C) 2021 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.

from __future__ import annotations

import io
import json
import os
import re
from typing import Dict, Optional

import torch
from anomalib.models import AnomalyModule, get_model
from anomalib.post_processing import NormalizationMethod, ThresholdMethod
from anomalib.utils.callbacks import (
    MetricsConfigurationCallback,
    MinMaxNormalizationCallback,
    PostProcessingConfigurationCallback,
)
from anomalib.utils.callbacks.nncf.callback import NNCFCallback
from anomalib.utils.callbacks.nncf.utils import (
    compose_nncf_config,
    is_state_nncf,
    wrap_nncf_model,
)
from pytorch_lightning import Trainer
from torch.utils.data.dataloader import DataLoader

from otx.algorithms.anomaly.adapters.anomalib.callbacks import ProgressCallback
from otx.algorithms.anomaly.adapters.anomalib.data import OTXAnomalyDataModule
from otx.api.entities.datasets import DatasetEntity
from otx.api.entities.model import (
    ModelEntity,
    ModelFormat,
    ModelOptimizationType,
    ModelPrecision,
    OptimizationMethod,
)
from otx.api.entities.optimization_parameters import OptimizationParameters
from otx.api.entities.task_environment import TaskEnvironment
from otx.api.usecases.tasks.interfaces.optimization_interface import (
    IOptimizationTask,
    OptimizationType,
)
from otx.utils.logger import get_logger

from .inference import InferenceTask

logger = get_logger()


class NNCFTask(InferenceTask, IOptimizationTask):
    """Base Anomaly Task."""

    def __init__(self, task_environment: TaskEnvironment, **kwargs) -> None:
        """Task for compressing models using NNCF.

        Args:
            task_environment (TaskEnvironment): OTX Task environment.
            **kwargs: Additional keyword arguments.
        """
        self.compression_ctrl = None
        self.nncf_preset = "nncf_quantization"
        super().__init__(task_environment, **kwargs)
        self.optimization_type = ModelOptimizationType.NNCF

    def _set_attributes_by_hyperparams(self):
        quantization = self.hyper_parameters.nncf_optimization.enable_quantization
        pruning = self.hyper_parameters.nncf_optimization.enable_pruning
        if quantization and pruning:
            self.nncf_preset = "nncf_quantization_pruning"
            self.optimization_methods = [
                OptimizationMethod.QUANTIZATION,
                OptimizationMethod.FILTER_PRUNING,
            ]
            self.precision = [ModelPrecision.INT8]
            return
        if quantization and not pruning:
            self.nncf_preset = "nncf_quantization"
            self.optimization_methods = [OptimizationMethod.QUANTIZATION]
            self.precision = [ModelPrecision.INT8]
            return
        if not quantization and pruning:
            self.nncf_preset = "nncf_pruning"
            self.optimization_methods = [OptimizationMethod.FILTER_PRUNING]
            self.precision = [ModelPrecision.FP32]
            return
        raise RuntimeError("No optimization algorithm selected")
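    # A minimal summary of how the two NNCF hyperparameter flags above map onto
    # presets (illustrative only; the actual flags live in the task's
    # hyper_parameters.nncf_optimization section):
    #
    #   enable_quantization  enable_pruning  ->  nncf_preset                   precision
    #   True                 True                "nncf_quantization_pruning"   INT8
    #   True                 False               "nncf_quantization"           INT8
    #   False                True                "nncf_pruning"                FP32
    #   False                False               RuntimeError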
    def load_model(self, otx_model: Optional[ModelEntity]) -> AnomalyModule:
        """Create and load an Anomalib module from an OTX Model.

        This method creates an Anomalib module and, if the task environment
        contains a trained OTX model, loads the saved weights into it.

        Args:
            otx_model (Optional[ModelEntity]): OTX Model from the task environment.

        Returns:
            AnomalyModule: Anomalib classification or segmentation model with/without weights.
        """
        nncf_config_path = os.path.join(self.base_dir, "compression_config.json")

        with open(nncf_config_path, encoding="utf8") as nncf_config_file:
            common_nncf_config = json.load(nncf_config_file)

        self._set_attributes_by_hyperparams()

        self.optimization_config = compose_nncf_config(common_nncf_config, [self.nncf_preset])
        self.config.merge_with(self.optimization_config)
        model = get_model(config=self.config)
        if otx_model is None:
            raise ValueError("No trained model in project. NNCF requires pretrained weights to compress the model.")

        buffer = io.BytesIO(otx_model.get_data("weights.pth"))  # type: ignore
        model_data = torch.load(buffer, map_location=torch.device("cpu"))

        if is_state_nncf(model_data):
            logger.info("Loaded model weights from Task Environment and wrapped by NNCF")

            # Fix name mismatch for the model wrapped by pytorch_lightning
            nncf_modules = {}
            pl_modules = {}
            for key in model_data["model"].keys():
                if key.startswith("model."):
                    new_key = key.replace("model.", "")
                    res = re.search(r"(\w+)_feature_extractor\.(.*)", new_key)
                    if res:
                        new_key = f"{res.group(1)}_model.feature_extractor.{res.group(2)}"
                    nncf_modules[new_key] = model_data["model"][key]
                else:
                    pl_modules[key] = model_data["model"][key]
            model_data["model"] = nncf_modules

            dataloader: DataLoader | None = None
            if hasattr(self, "trainer") and hasattr(self.trainer, "datamodule"):
                if self.trainer.datamodule.train_dataset is not None:
                    dataloader = self.trainer.datamodule.train_dataloader()
                elif self.trainer.datamodule.test_dataset is not None:
                    dataloader = self.trainer.datamodule.test_dataloader()

            self.compression_ctrl, model.model = wrap_nncf_model(
                model.model,
                self.optimization_config["nncf_config"],
                dataloader=dataloader,  # type: ignore
                init_state_dict=model_data,
            )
            # Load extra parameters of the pytorch_lightning model
            model.load_state_dict(pl_modules, strict=False)
        else:
            try:
                model.load_state_dict(model_data["model"])
                logger.info("Loaded model weights from Task Environment")
            except BaseException as exception:
                raise ValueError("Could not load the saved model. The model file structure is invalid.") from exception
        return model
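    # For illustration, the key-renaming step above turns a checkpoint key such
    # as "model.teacher_feature_extractor.layer1.weight" (hypothetical key
    # name) into the layout expected by the unwrapped Anomalib module:
    #
    #   >>> import re
    #   >>> key = "model.teacher_feature_extractor.layer1.weight"
    #   >>> new_key = key.replace("model.", "")
    #   >>> res = re.search(r"(\w+)_feature_extractor\.(.*)", new_key)
    #   >>> f"{res.group(1)}_model.feature_extractor.{res.group(2)}"
    #   'teacher_model.feature_extractor.layer1.weight'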
    def optimize(
        self,
        optimization_type: OptimizationType,
        dataset: DatasetEntity,
        output_model: ModelEntity,
        optimization_parameters: Optional[OptimizationParameters] = None,
    ):
        """Optimize the anomaly classification model via NNCF.

        Args:
            optimization_type (OptimizationType): Type of optimization.
            dataset (DatasetEntity): Input dataset.
            output_model (ModelEntity): Output model to save the model weights.
            optimization_parameters (OptimizationParameters): Optimization parameters.
        """
        logger.info("Optimizing the model.")

        if optimization_type is not OptimizationType.NNCF:
            raise RuntimeError("NNCF is the only supported optimization")

        datamodule = OTXAnomalyDataModule(config=self.config, dataset=dataset, task_type=self.task_type)
        nncf_callback = NNCFCallback(config=self.optimization_config["nncf_config"])
        metrics_configuration = MetricsConfigurationCallback(
            task=self.config.dataset.task,
            image_metrics=self.config.metrics.image,
            pixel_metrics=self.config.metrics.get("pixel"),
        )
        post_processing_configuration = PostProcessingConfigurationCallback(
            normalization_method=NormalizationMethod.MIN_MAX,
            threshold_method=ThresholdMethod.ADAPTIVE,
            manual_image_threshold=self.config.metrics.threshold.manual_image,
            manual_pixel_threshold=self.config.metrics.threshold.manual_pixel,
        )
        callbacks = [
            ProgressCallback(parameters=optimization_parameters),
            MinMaxNormalizationCallback(),
            nncf_callback,
            metrics_configuration,
            post_processing_configuration,
        ]

        self.trainer = Trainer(**self.config.trainer, logger=False, callbacks=callbacks)
        self.trainer.fit(model=self.model, datamodule=datamodule)
        self.compression_ctrl = nncf_callback.nncf_ctrl
        output_model.model_format = ModelFormat.BASE_FRAMEWORK
        output_model.optimization_type = ModelOptimizationType.NNCF
        self.save_model(output_model)

        logger.info("Optimization completed.")
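    # A hedged usage sketch. The names `task_environment` and `dataset` are
    # assumed to come from the surrounding OTX workflow and are not defined in
    # this module:
    #
    #   task = NNCFTask(task_environment=task_environment)
    #   output_model = ModelEntity(dataset, task_environment.get_model_configuration())
    #   task.optimize(
    #       optimization_type=OptimizationType.NNCF,
    #       dataset=dataset,
    #       output_model=output_model,
    #       optimization_parameters=OptimizationParameters(),
    #   )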
    def model_info(self) -> Dict:
        """Return model info to save the model weights.

        Returns:
            Dict: Model info.
        """
        return {
            "compression_state": self.compression_ctrl.get_compression_state(),  # type: ignore
            "meta": {
                "config": self.config,
                "nncf_enable_compression": True,
            },
            "model": self.model.state_dict(),
            "config": self.get_config(),
            "VERSION": 1,
        }
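    # For context, `save_model` (inherited from InferenceTask) serializes this
    # dictionary into the output ModelEntity, roughly as follows (a sketch of
    # the inherited behaviour, not code from this module):
    #
    #   buffer = io.BytesIO()
    #   torch.save(self.model_info(), buffer)
    #   output_model.set_data("weights.pth", buffer.getvalue())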
    def _export_to_onnx(self, onnx_path: str):
        """Export model to ONNX.

        Args:
            onnx_path (str): Path to save the ONNX file.
        """
        self.compression_ctrl.export_model(onnx_path, "onnx_11")  # type: ignore