Source code for otx.algorithms.common.adapters.mmcv.nncf.runners

"""AccuracyAwareRunner for NNCF task."""
# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#

import time
from dataclasses import asdict

from mmcv.runner import RUNNERS
from mmcv.runner.hooks.evaluation import EvalHook
from mmcv.runner.hooks.lr_updater import LrUpdaterHook
from mmcv.runner.utils import get_host_info

from otx.algorithms.common.adapters.mmcv.nncf.hooks import CompressionHook
from otx.algorithms.common.adapters.mmcv.runner import EpochRunnerWithCancel
from otx.algorithms.common.adapters.nncf import (
    AccuracyAwareLrUpdater,
    check_nncf_is_enabled,
)
from otx.algorithms.common.adapters.nncf.compression import NNCFMetaState

NNCF_META_KEY = "nncf_meta"


# TODO: refactoring
@RUNNERS.register_module()
class AccuracyAwareRunner(EpochRunnerWithCancel):  # pylint: disable=too-many-instance-attributes
    """AccuracyAwareRunner for NNCF task.

    An mmcv training runner to be used with NNCF-based accuracy-aware training.
    It derives from ``EpochRunnerWithCancel`` and overrides the ``run`` method.

    This runner does not use the "workflow" and "max_epochs" parameters of the
    regular epoch-based runner: the training schedule is driven by the NNCF
    accuracy-aware training loop, configured through the
    "accuracy_aware_training" section of the NNCF config.
    """

    def __init__(self, *args, nncf_config, nncf_meta=None, **kwargs):
        """Initialize the runner.

        Args:
            *args: Positional arguments forwarded to the parent runner.
            nncf_config: NNCF configuration. Must provide the
                "target_metric_name" key; ``run`` additionally reads
                ``["accuracy_aware_training"]["params"]["maximal_total_epochs"]``.
            nncf_meta: Optional NNCF meta state stored into checkpoints.
                A default ``NNCFMetaState()`` is created when omitted.
            **kwargs: Keyword arguments forwarded to the parent runner.
        """
        super().__init__(*args, **kwargs)
        self.nncf_config = nncf_config
        if nncf_meta is None:
            nncf_meta = NNCFMetaState()
        self.nncf_meta = nncf_meta
        # Not created here; it has to be assigned from outside before
        # `run` / `save_checkpoint` are called.
        self.compression_ctrl = None
        self._target_metric_name = nncf_config["target_metric_name"]
        # Both are populated by `run`.
        self._train_data_loader = None
        self._eval_hook = None

    def run(self, data_loaders, *args, **kwargs):  # pylint: disable=unused-argument
        """Run NNCF accuracy-aware training.

        Args:
            data_loaders (list): Data loaders; only the first one is used
                (as the training loader).
            *args: Ignored.
            **kwargs: Ignored.

        Returns:
            The model returned by the NNCF accuracy-aware training loop.

        Raises:
            AssertionError: If ``data_loaders`` is not a list, no
                ``CompressionHook`` is registered, more than one
                ``LrUpdaterHook`` is registered, the number of registered
                ``EvalHook`` instances is not exactly one, or the eval hook's
                ``save_best`` differs from the configured target metric name.
        """
        check_nncf_is_enabled()

        from nncf.common.accuracy_aware_training import (
            create_accuracy_aware_training_loop,
        )

        assert isinstance(data_loaders, list)

        lr_update_hooks = [hook for hook in self.hooks if isinstance(hook, LrUpdaterHook)]
        eval_hooks = [hook for hook in self.hooks if isinstance(hook, EvalHook)]
        found_compression_hook = any(isinstance(hook, CompressionHook) for hook in self.hooks)

        assert found_compression_hook, f"{CompressionHook} must be registered to {self}."
        assert len(lr_update_hooks) <= 1, (
            f"More than 1 lr update hooks ({len(lr_update_hooks)}) " f"are registered to {self}"
        )
        assert len(eval_hooks) == 1, f"{EvalHook} must be registered to {self}"
        self._eval_hook = eval_hooks[0]
        # Compare against the very name `validation_fn` uses for the metric lookup.
        assert self._eval_hook.save_best == self._target_metric_name, (
            "'target_metric_name' from nncf_config is not identical to 'save_best' in 'EvalHook'. "
            f"({self._eval_hook.save_best} != {self._target_metric_name})"
        )

        work_dir = self.work_dir if self.work_dir is not None else "NONE"
        self.logger.info("Start running, host: %s, work_dir: %s", get_host_info(), work_dir)
        self.logger.warning(
            "Note that the workflow and max_epochs parameters are not used in NNCF-based accuracy-aware training"
        )

        # taking only the first data loader for NNCF training
        self._train_data_loader = data_loaders[0]

        # Maximum possible number of iterations, needed for progress tracking
        params = self.nncf_config["accuracy_aware_training"]["params"]
        self._max_epochs = params["maximal_total_epochs"]
        self._max_iters = self._max_epochs * len(self._train_data_loader)

        self.logger.info("Hooks will be executed in the following order:\n%s", self.get_hook_info())
        self.call_hook("before_run")

        lr_update_hook = lr_update_hooks[0] if lr_update_hooks else None

        def configure_optimizers_fn():
            # A fresh LR updater wrapper is built on every call; without a
            # registered LR hook no scheduler is handed to NNCF.
            if lr_update_hook is None:
                return self.optimizer, None
            return self.optimizer, AccuracyAwareLrUpdater(lr_update_hook)

        # pylint: disable-next=unused-argument
        def dump_checkpoint_fn(model, compression_ctrl, nncf_runner, save_dir):
            # Checkpoint writing is delegated to the eval hook so that the
            # reported path is the hook's best-checkpoint path.
            # pylint: disable-next=protected-access
            self._eval_hook._save_ckpt(self, nncf_runner.best_val_metric_value)
            return self._eval_hook.best_ckpt_path

        # The model may be wrapped (e.g. exposes the real model as `.module`).
        nncf_model = self.model.module if hasattr(self.model, "module") else self.model
        # pylint: disable-next=protected-access
        uncompressed_model_accuracy = nncf_model.nncf._uncompressed_model_accuracy

        acc_aware_training_loop = create_accuracy_aware_training_loop(
            self.nncf_config,
            self.compression_ctrl,
            verbose=False,
            uncompressed_model_accuracy=uncompressed_model_accuracy,
        )

        model = acc_aware_training_loop.run(
            self.model,
            train_epoch_fn=self.train_fn,
            validate_fn=self.validation_fn,
            configure_optimizers_fn=configure_optimizers_fn,
            dump_checkpoint_fn=dump_checkpoint_fn,
            log_dir=self.work_dir,
        )

        time.sleep(1)  # wait for some hooks like loggers to finish
        self.call_hook("after_run")
        return model

    def train_fn(self, *args, **kwargs):  # pylint: disable=unused-argument
        """Train the model for a single epoch.

        This method is used in NNCF-based accuracy-aware training. All
        arguments are ignored; the data loader selected in ``run`` is used.
        """
        self.train(self._train_data_loader)

    def validation_fn(self, *args, **kwargs):  # pylint: disable=unused-argument
        """Return the target metric value on the validation dataset.

        This method is used in NNCF-based accuracy-aware training. All
        arguments are ignored.

        Returns:
            The value stored under the target metric name in the runner's
            ``all_metrics`` attribute.

        Raises:
            RuntimeError: If the target metric is missing (or ``None``) after
                evaluation.
        """
        # make sure evaluation hook is in a 'should_evaluate' state
        interval_bak = self._eval_hook.interval
        self._eval_hook.interval = 1
        try:
            self._eval_hook._do_evaluate(self)  # pylint: disable=protected-access
        finally:
            # Restore the configured interval even when evaluation fails.
            self._eval_hook.interval = interval_bak

        # Get metric from the runner's attributes (expected to be set during evaluation)
        all_metrics = getattr(self, "all_metrics", {})
        metric = all_metrics.get(self._target_metric_name, None)
        if metric is None:
            raise RuntimeError(f"Could not find the {self._target_metric_name} key")
        return metric

    def save_checkpoint(self, *args, **kwargs) -> None:
        """Save checkpoint with NNCF meta state.

        Args:
            *args: Positional arguments forwarded to the parent implementation.
            **kwargs: Keyword arguments forwarded to the parent implementation.
                An optional ``meta`` mapping is extended with the NNCF meta
                state; the caller's mapping itself is left unmodified.
        """
        compression_state = self.compression_ctrl.get_compression_state()
        for algo_state in compression_state.get("ctrl_state", {}).values():
            # Fill in a zeroed scheduler state where it is missing or empty.
            if not algo_state.get("scheduler_state"):
                algo_state["scheduler_state"] = {"current_step": 0, "current_epoch": 0}

        nncf_meta = NNCFMetaState(
            **{**asdict(self.nncf_meta), "compression_ctrl": compression_state},
        )

        meta = kwargs.pop("meta", None)
        # An explicit `meta=None` is treated like an absent one; a shallow copy
        # keeps the NNCF entries out of the caller's dictionary.
        meta = {} if meta is None else {**meta}
        meta[NNCF_META_KEY] = nncf_meta
        meta["nncf_enable_compression"] = True
        super().save_checkpoint(*args, **kwargs, meta=meta)