Source code for otx.algorithms.common.tasks.nncf_task

"""BaseTask for NNCF."""

# Copyright (C) 2022 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.


import io
import json
import os
from copy import deepcopy
from typing import Dict, List, Optional

import torch
from mmcv.utils import ConfigDict

import otx.algorithms.common.adapters.mmcv.nncf.patches  # noqa: F401  # pylint: disable=unused-import
from otx.algorithms.common.adapters.mmcv.utils import (
    get_configs_by_keys,
    remove_from_config,
    remove_from_configs_by_type,
)
from otx.algorithms.common.adapters.nncf import (
    check_nncf_is_enabled,
    is_accuracy_aware_training_set,
)
from otx.algorithms.common.adapters.nncf.config import compose_nncf_config
from otx.algorithms.common.utils.callback import OptimizationProgressCallback
from otx.algorithms.common.utils.data import get_dataset
from otx.api.configuration import cfg_helper
from otx.api.configuration.helper.utils import ids_to_strings
from otx.api.entities.datasets import DatasetEntity
from otx.api.entities.model import (
    ModelEntity,
    ModelFormat,
    ModelOptimizationType,
    ModelPrecision,
    OptimizationMethod,
)
from otx.api.entities.optimization_parameters import (
    OptimizationParameters,
    default_progress_callback,
)
from otx.api.entities.subset import Subset
from otx.api.serialization.label_mapper import label_schema_to_bytes
from otx.api.usecases.tasks.interfaces.optimization_interface import (
    IOptimizationTask,
    OptimizationType,
)
from otx.utils.logger import get_logger

logger = get_logger()


class NNCFBaseTask(IOptimizationTask):  # pylint: disable=too-many-instance-attributes
    """NNCFBaseTask."""

    def __init__(self):
        check_nncf_is_enabled()

        self._nncf_data_to_build = None
        self._nncf_state_dict_to_build: Dict[str, torch.Tensor] = {}

        self._nncf_preset = None
        self._optimization_methods: List[OptimizationMethod] = []
        self._precision = [ModelPrecision.FP32]

        # Extra control variables.
        self._training_work_dir = None
        self._is_training = False
        self._should_stop = False

        self._optimization_type = ModelOptimizationType.NNCF
        self._time_monitor = None

        # Variables to be set by the training backend task
        self._data_cfg = None
        self._model_ckpt = None
        self._model_dir = None
        self._labels = None
        self._recipe_cfg = None
        self._hyperparams = None
        self._task_environment = None

        logger.info("Task initialization completed")

    def _set_attributes_by_hyperparams(self):
        """Set the NNCF preset, optimization methods and precision from the hyperparameters."""
        quantization = self._hyperparams.nncf_optimization.enable_quantization
        pruning = self._hyperparams.nncf_optimization.enable_pruning
        if quantization and pruning:
            self._nncf_preset = "nncf_quantization_pruning"
            self._optimization_methods = [
                OptimizationMethod.QUANTIZATION,
                OptimizationMethod.FILTER_PRUNING,
            ]
            self._precision = [ModelPrecision.INT8]
            return
        if quantization and not pruning:
            self._nncf_preset = "nncf_quantization"
            self._optimization_methods = [OptimizationMethod.QUANTIZATION]
            self._precision = [ModelPrecision.INT8]
            return
        if not quantization and pruning:
            self._nncf_preset = "nncf_pruning"
            self._optimization_methods = [OptimizationMethod.FILTER_PRUNING]
            self._precision = [ModelPrecision.FP32]
            return
        # FIXME: Error raising should be re-enabled after the Geti issue is resolved
        # raise RuntimeError("No optimization algorithm selected")
        logger.warning("No optimization algorithm selected. Defaulting to quantization")
        self._nncf_preset = "nncf_quantization"
        self._optimization_methods = [OptimizationMethod.QUANTIZATION]
        self._precision = [ModelPrecision.INT8]

    def _init_train_data_cfg(self, dataset: DatasetEntity):
        """Initialize the data config from the training and validation subsets of the dataset."""
        logger.info("init data cfg.")
        data_cfg = ConfigDict(data=ConfigDict())

        for cfg_key, subset in zip(
            ["train", "val"],
            [Subset.TRAINING, Subset.VALIDATION],
        ):
            subset = get_dataset(dataset, subset)
            if subset:
                data_cfg.data[cfg_key] = ConfigDict(
                    otx_dataset=subset,
                    labels=self._labels,
                )
        return data_cfg

    def _init_nncf_cfg(self):
        """Load compression_config.json and compose the NNCF config for the selected preset."""
        nncf_config_path = os.path.join(self._model_dir, "compression_config.json")

        with open(nncf_config_path, encoding="UTF-8") as nncf_config_file:
            common_nncf_config = json.load(nncf_config_file)

        optimization_config = compose_nncf_config(common_nncf_config, [self._nncf_preset])

        max_acc_drop = self._hyperparams.nncf_optimization.maximal_accuracy_degradation / 100
        if "accuracy_aware_training" in optimization_config["nncf_config"]:
            # Update maximal_absolute_accuracy_degradation
            optimization_config["nncf_config"]["accuracy_aware_training"]["params"][
                "maximal_absolute_accuracy_degradation"
            ] = max_acc_drop
            # Force evaluation interval
            self._config.evaluation.interval = 1
        else:
            logger.info("NNCF config has no accuracy_aware_training parameters")
        return ConfigDict(optimization_config)

    def _prepare_optimize(self):
        """Adjust the training config so that it is compatible with NNCF optimization."""
        assert self._config is not None

        # TODO: more delicate configuration change control on the OTX side
        # A last batch of size 1 causes undefined behaviour in batch normalization
        # when initializing and training NNCF
        if self._data_cfg is not None:
            data_loader = self._config.data.get("train_dataloader", ConfigDict())
            samples_per_gpu = data_loader.get("samples_per_gpu", self._config.data.get("samples_per_gpu"))
            otx_dataset = get_configs_by_keys(self._data_cfg.data.train, "otx_dataset")
            assert len(otx_dataset) == 1
            otx_dataset = otx_dataset[0]
            if otx_dataset is not None and len(otx_dataset) % samples_per_gpu == 1:
                data_loader["drop_last"] = True
                self._config.data["train_dataloader"] = data_loader

        # NNCF does not support FP16
        if "fp16" in self._config:
            remove_from_config(self._config, "fp16")
            logger.warning("fp16 option is not supported in NNCF. Switching to fp32.")

        # FIXME: the NNCF quantizer does not work with SAMOptimizer
        optimizer_config = self._config.optimizer_config
        if optimizer_config.get("type", "OptimizerHook") == "SAMOptimizerHook":
            optimizer_config.type = "OptimizerHook"
            logger.warning("Replaced SAMOptimizerHook with OptimizerHook as it is not supported.")

        # merge nncf_cfg
        nncf_cfg = self._init_nncf_cfg()
        self._config.merge_from_dict(nncf_cfg)

        # configure nncf
        nncf_config = self._config.get("nncf_config", {})
        if nncf_config.get("target_metric_name", None) is None:
            metric_name = self._config.evaluation.metric
            if isinstance(metric_name, list):
                metric_name = metric_name[0]
            nncf_config.target_metric_name = metric_name
            logger.info(f"'target_metric_name' not found in nncf config. Using {metric_name} as the target metric")

        if is_accuracy_aware_training_set(nncf_config):
            # Prepare the runner for accuracy-aware training
            self._config.runner = {
                "type": "AccuracyAwareRunner",
                "nncf_config": nncf_config,
            }
            # Unlike other runners, which rely on scores periodically computed by 'EvalHook',
            # AccuracyAwareRunner evaluates the model whenever it needs to. Setting
            # 'interval' to 'max_epoch' makes sure 'EvalHook' does not evaluate
            # during training.
            max_epoch = nncf_config.accuracy_aware_training.params.maximal_total_epochs
            self._config.evaluation.interval = max_epoch
            # Disable 'AdaptiveTrainSchedulingHook' as training is managed by AccuracyAwareRunner
            remove_from_configs_by_type(self._config.custom_hooks, "AdaptiveTrainSchedulingHook")
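    # Illustrative only (not part of the original module): how a preset is
    # composed. The exact layout of compression_config.json is an assumption
    # here; typically it holds a shared "base" part plus per-preset overrides
    # that compose_nncf_config() merges in the order given by "order_of_parts":
    #
    #     common = {
    #         "base": {"nncf_config": {"log_dir": "/tmp"}},
    #         "nncf_quantization": {
    #             "nncf_config": {"compression": {"algorithm": "quantization"}}
    #         },
    #         "order_of_parts": ["nncf_quantization"],
    #     }
    #     merged = compose_nncf_config(common, ["nncf_quantization"])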
    @staticmethod
    def model_builder(
        config,
        *args,
        nncf_model_builder,
        model_config=None,
        data_config=None,
        is_export=False,
        return_compression_ctrl=False,
        **kwargs,
    ):
        """Build an NNCF-wrapped model (and its compression controller) from the given config."""
        if model_config is not None or data_config is not None:
            config = deepcopy(config)
        if model_config is not None:
            config.merge_from_dict(model_config)
        if data_config is not None:
            config.merge_from_dict(data_config)

        compression_ctrl, model = nncf_model_builder(
            config,
            distributed=False,
            *args,
            **kwargs,
        )

        if is_export:
            compression_ctrl.prepare_for_export()
            model.nncf.disable_dynamic_graph_building()

        if return_compression_ctrl:
            return compression_ctrl, model
        return model
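    # A minimal usage sketch (illustrative; `build_nncf_model` stands in for a
    # backend-specific builder, e.g. one provided by the mmcv adapters):
    #
    #     compression_ctrl, model = NNCFBaseTask.model_builder(
    #         config,
    #         nncf_model_builder=build_nncf_model,
    #         is_export=True,
    #         return_compression_ctrl=True,
    #     )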
    def _optimize(
        self,
        dataset: DatasetEntity,
        optimization_parameters: Optional[OptimizationParameters] = None,
    ):
        """Run the actual optimization; implemented by the training backend task."""
        raise NotImplementedError

    def _optimize_post_hook(
        self,
        dataset: DatasetEntity,
        output_model: ModelEntity,
    ):
        """Hook called after optimization; overridden by subclasses when needed."""
    def optimize(
        self,
        optimization_type: OptimizationType,
        dataset: DatasetEntity,
        output_model: ModelEntity,
        optimization_parameters: Optional[OptimizationParameters] = None,
    ):
        """NNCF Optimization."""
        if optimization_type is not OptimizationType.NNCF:
            raise RuntimeError("NNCF is the only supported optimization")

        if optimization_parameters is not None:
            update_progress_callback = optimization_parameters.update_progress
        else:
            update_progress_callback = default_progress_callback
        self._time_monitor = OptimizationProgressCallback(
            update_progress_callback,
            loading_stage_progress_percentage=5,
            initialization_stage_progress_percentage=5,
        )

        self._data_cfg = self._init_train_data_cfg(dataset)
        self._is_training = True

        results = self._optimize(dataset, optimization_parameters)

        # Check for stop signal when training has stopped.
        # If should_stop is True, training was cancelled
        if self._should_stop:
            logger.info("Training cancelled.")
            self._should_stop = False
            self._is_training = False
            return

        model_ckpt = results.get("final_ckpt")
        if model_ckpt is None:
            logger.error("cannot find final checkpoint from the results.")
            # output_model.model_status = ModelStatus.FAILED
            return
        # update checkpoint to the newly trained model
        self._model_ckpt = model_ckpt

        self._optimize_post_hook(dataset, output_model)

        self.save_model(output_model)

        output_model.model_format = ModelFormat.BASE_FRAMEWORK
        output_model.optimization_type = self._optimization_type
        output_model.optimization_methods = self._optimization_methods
        output_model.precision = self._precision

        self._is_training = False
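    # Illustrative call sequence (`task`, `dataset` and `output_model` are
    # assumed to be created by the caller; only NNCF optimization is accepted):
    #
    #     task.optimize(OptimizationType.NNCF, dataset, output_model)
    #     assert output_model.optimization_type == ModelOptimizationType.NNCF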
    def _save_model_post_hook(self, modelinfo):
        """Hook called right before `modelinfo` is serialized; overridden by subclasses when needed."""
    def save_model(self, output_model: ModelEntity):
        """Saving model function for NNCF Task."""
        assert self._recipe_cfg is not None

        buffer = io.BytesIO()
        hyperparams_str = ids_to_strings(cfg_helper.convert(self._hyperparams, dict, enum_to_str=True))
        labels = {label.name: label.color.rgb_tuple for label in self._labels}
        model_ckpt = torch.load(self._model_ckpt, map_location=torch.device("cpu"))
        modelinfo = {
            "model": model_ckpt,
            "config": hyperparams_str,
            "labels": labels,
            "VERSION": 1,
            "meta": {
                "nncf_enable_compression": True,
            },
        }
        self._save_model_post_hook(modelinfo)

        torch.save(modelinfo, buffer)
        output_model.set_data("weights.pth", buffer.getvalue())
        output_model.set_data(
            "label_schema.json",
            label_schema_to_bytes(self._task_environment.label_schema),
        )
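# Reading the saved artifact back (a minimal sketch; `output_model` holds the
# data written by save_model() above, and ModelEntity.get_data() returns the
# raw bytes stored under the given key):
#
#     buffer = io.BytesIO(output_model.get_data("weights.pth"))
#     modelinfo = torch.load(buffer, map_location="cpu")
#     assert modelinfo["meta"]["nncf_enable_compression"]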