Source code for otx.algorithms.common.adapters.mmcv.hooks.checkpoint_hook

"""CheckpointHook with validation results for classification task."""
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#

# Copyright (c) Open-MMLab. All rights reserved.
from pathlib import Path
from typing import Optional

from mmcv.runner import BaseRunner
from mmcv.runner.dist_utils import allreduce_params, master_only
from mmcv.runner.hooks.hook import HOOKS, Hook


@HOOKS.register_module()
class CheckpointHookWithValResults(Hook):  # pylint: disable=too-many-instance-attributes
    """Save checkpoints periodically.

    Args:
        interval (int): The saving period. If ``by_epoch=True``, interval
            indicates epochs, otherwise it indicates iterations.
            Default: -1, which means "never".
        by_epoch (bool): Saving checkpoints by epoch or by iteration.
            Default: True.
        save_optimizer (bool): Whether to save optimizer state_dict in the
            checkpoint. It is usually used for resuming experiments.
            Default: True.
        out_dir (str, optional): The directory to save checkpoints. If not
            specified, ``runner.work_dir`` will be used by default.
        max_keep_ckpts (int, optional): The maximum checkpoints to keep.
            In some cases we want only the latest few checkpoints and would
            like to delete old ones to save the disk space.
            Default: -1, which means unlimited.
        sync_buffer (bool): Whether to synchronize buffers in different gpus.
            Default: False.
    """

    def __init__(
        self,
        interval=-1,
        by_epoch=True,
        save_optimizer=True,
        out_dir=None,
        max_keep_ckpts=-1,
        sync_buffer=False,
        **kwargs,
    ) -> None:
        self.interval = interval
        self.by_epoch = by_epoch
        self.save_optimizer = save_optimizer
        self.out_dir = out_dir
        self.max_keep_ckpts = max_keep_ckpts
        self.args = kwargs
        self.sync_buffer = sync_buffer
        self._best_model_weight: Optional[Path] = None
    def before_run(self, runner):
        """Set the output directory if it is not set."""
        if not self.out_dir:
            self.out_dir = runner.work_dir
        # Keep out_dir as a Path so the "/" path arithmetic below works.
        self.out_dir = Path(self.out_dir)
    def after_train_epoch(self, runner):
        """Save checkpoints after a training epoch."""
        if not self.by_epoch or not self.every_n_epochs(runner, self.interval):
            return

        if self.sync_buffer:
            allreduce_params(runner.model.buffers())

        save_ema_model = hasattr(runner, "save_ema_model") and runner.save_ema_model
        if save_ema_model:
            # Temporarily swap in the EMA model so its weights are what gets saved.
            backup_model = runner.model
            runner.model = runner.ema_model

        if getattr(runner, "save_ckpt", False):
            runner.logger.info(f"Saving best checkpoint at {runner.epoch + 1} epochs")
            self._save_best_checkpoint(runner)
            runner.save_ckpt = False

        self._save_latest_checkpoint(runner)

        if save_ema_model:
            runner.model = backup_model
            runner.save_ema_model = False
    @master_only
    def _save_best_checkpoint(self, runner):
        """Save the current best checkpoint and delete the previous best one."""
        if self._best_model_weight is not None:  # remove previous best model weight
            prev_model_weight = self.out_dir / self._best_model_weight
            if prev_model_weight.exists():
                prev_model_weight.unlink()

        if self.by_epoch:
            weight_name = f"best_epoch_{runner.epoch + 1}.pth"
        else:
            weight_name = f"best_iter_{runner.iter + 1}.pth"
        runner.save_checkpoint(
            self.out_dir, filename_tmpl=weight_name, save_optimizer=self.save_optimizer, **self.args
        )

        self._best_model_weight = Path(weight_name)
        if runner.meta is not None:
            runner.meta.setdefault("hook_msgs", dict())
            runner.meta["hook_msgs"]["best_ckpt"] = str(self.out_dir / self._best_model_weight)

    @master_only
    def _save_latest_checkpoint(self, runner):
        """Save the current checkpoint and delete outdated checkpoints."""
        if self.by_epoch:
            weight_name_format = "epoch_{}.pth"
            cur_step = runner.epoch + 1
        else:
            weight_name_format = "iter_{}.pth"
            cur_step = runner.iter + 1

        runner.save_checkpoint(
            self.out_dir,
            filename_tmpl=weight_name_format.format(cur_step),
            save_optimizer=self.save_optimizer,
            **self.args,
        )

        # Remove checkpoints that fall outside the max_keep_ckpts window.
        if self.max_keep_ckpts > 0:
            for _step in range(cur_step - self.max_keep_ckpts * self.interval, 0, -self.interval):
                ckpt_path = self.out_dir / Path(weight_name_format.format(_step))
                if ckpt_path.exists():
                    ckpt_path.unlink()

        if runner.meta is not None:
            cur_ckpt_filename = Path(self.args.get("filename_tmpl", weight_name_format.format(cur_step)))
            runner.meta.setdefault("hook_msgs", dict())
            runner.meta["hook_msgs"]["last_ckpt"] = str(self.out_dir / cur_ckpt_filename)
    def after_train_iter(self, runner):
        """Save checkpoints after a training iteration."""
        if self.by_epoch or not self.every_n_iters(runner, self.interval):
            return

        if hasattr(runner, "save_ckpt"):
            if runner.save_ckpt:
                runner.logger.info(f"Saving checkpoint at {runner.iter + 1} iterations")
                if self.sync_buffer:
                    allreduce_params(runner.model.buffers())
                # Mirror the epoch path: persist the best checkpoint and the latest one.
                self._save_best_checkpoint(runner)
                self._save_latest_checkpoint(runner)
            runner.save_ckpt = False
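
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only; the interval and max_keep_ckpts values below
# are hypothetical): the hook can be attached to an mmcv runner directly.
#
#     hook = CheckpointHookWithValResults(interval=1, by_epoch=True, max_keep_ckpts=3)
#     runner.register_hook(hook)
#
# Something else (e.g. an evaluation hook) is expected to set
# ``runner.save_ckpt = True`` when a new best score is found; at the end of the
# next epoch this hook then writes ``best_epoch_{N}.pth`` and records its path
# in ``runner.meta["hook_msgs"]["best_ckpt"]``.
# ---------------------------------------------------------------------------
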

@HOOKS.register_module()
class EnsureCorrectBestCheckpointHook(Hook):
    """EnsureCorrectBestCheckpointHook.

    This hook makes sure that the 'best_mAP' checkpoint points properly to the best model,
    even if the best model is created in the last epoch.
    """
    def after_run(self, runner: BaseRunner):
        """Re-run the after_train_epoch hooks once training ends so the best checkpoint is up to date."""
        runner.call_hook("after_train_epoch")

@HOOKS.register_module()
class SaveInitialWeightHook(Hook):
    """Save the initial weights before training."""

    def __init__(self, save_path, file_name: str = "weights.pth", **kwargs):
        self._save_path = save_path
        self._file_name = file_name
        self._args = kwargs

    def before_run(self, runner):
        """Save the initial weights before training starts."""
        runner.logger.info("Saving weight before training")
        runner.save_checkpoint(
            self._save_path, filename_tmpl=self._file_name, save_optimizer=False, create_symlink=False, **self._args
        )
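
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only; the paths and parameters are hypothetical):
# since these hooks are registered in mmcv's HOOKS registry, training pipelines
# that read ``custom_hooks`` from their config can enable them declaratively,
# e.g.
#
#     custom_hooks = [
#         dict(type="SaveInitialWeightHook", save_path="work_dirs/example"),
#         dict(type="EnsureCorrectBestCheckpointHook"),
#     ]
#
# SaveInitialWeightHook then writes ``weights.pth`` (the default ``file_name``)
# into ``save_path`` before the first training iteration.
# ---------------------------------------------------------------------------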