Source code for otx.algorithms.common.adapters.mmcv.nncf.hooks

"""NNCF task related hooks."""
# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#

from mmcv.runner.hooks.hook import HOOKS, Hook


[docs] @HOOKS.register_module() class CompressionHook(Hook): """CompressionHook.""" COMPRESSION_STATE_FILE_NAME = "meta_state.pth" def __init__(self, compression_ctrl=None): self.compression_ctrl = compression_ctrl
[docs] def after_train_iter(self, runner): """Called after train iter.""" self.compression_ctrl.scheduler.step()
[docs] def after_train_epoch(self, runner): """Called after train epoch.""" self.compression_ctrl.scheduler.epoch_step() if runner.rank == 0: runner.logger.info(self.compression_ctrl.statistics().to_str())
[docs] def before_run(self, runner): """Called before run.""" runner.compression_ctrl = self.compression_ctrl if runner.rank == 0: runner.logger.info(self.compression_ctrl.statistics().to_str())