Source code for otx.algorithms.common.adapters.mmcv.hooks.progress_hook

"""Collections of hooks for common OTX algorithms."""

# Copyright (C) 2021-2023 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.

import math

from mmcv.runner import BaseRunner
from mmcv.runner.hooks import HOOKS, Hook

from otx.api.usecases.reporting.time_monitor_callback import TimeMonitorCallback
from otx.utils.logger import get_logger

logger = get_logger()


[docs] @HOOKS.register_module() class OTXProgressHook(Hook): """OTXProgressHook for getting progress.""" def __init__(self, time_monitor: TimeMonitorCallback, verbose: bool = False): super().__init__() self.time_monitor = time_monitor self.verbose = verbose self.print_threshold = 1
[docs] def before_run(self, runner: BaseRunner): """Called before_run in OTXProgressHook.""" total_epochs = runner.max_epochs if runner.max_epochs is not None else 1 self.time_monitor.total_epochs = total_epochs self.time_monitor.train_steps = runner.max_iters // total_epochs if total_epochs else 1 self.time_monitor.steps_per_epoch = self.time_monitor.train_steps + self.time_monitor.val_steps self.time_monitor.total_steps = max(math.ceil(self.time_monitor.steps_per_epoch * total_epochs), 1) self.time_monitor.current_step = 0 self.time_monitor.current_epoch = 0 self.time_monitor.on_train_begin()
[docs] def before_epoch(self, runner: BaseRunner): """Called before_epoch in OTXProgressHook.""" self.time_monitor.on_epoch_begin(runner.epoch)
[docs] def after_epoch(self, runner: BaseRunner): """Called after_epoch in OTXProgressHook.""" # put some runner's training status to use on the other hooks runner.log_buffer.output["current_iters"] = runner.iter self.time_monitor.on_epoch_end(runner.epoch, runner.log_buffer.output)
[docs] def before_iter(self, runner: BaseRunner): """Called before_iter in OTXProgressHook.""" self.time_monitor.on_train_batch_begin(1)
[docs] def after_iter(self, runner: BaseRunner): """Called after_iter in OTXProgressHook.""" # put some runner's training status to use on the other hooks runner.log_buffer.output["current_iters"] = runner.iter self.time_monitor.on_train_batch_end(1) if self.verbose: progress = self.progress if progress >= self.print_threshold: logger.info(f"training progress {progress:.0f}%") self.print_threshold = (progress + 10) // 10 * 10
[docs] def before_val_iter(self, runner: BaseRunner): """Called before_val_iter in OTXProgressHook.""" self.time_monitor.on_test_batch_begin(1, logger)
[docs] def after_val_iter(self, runner: BaseRunner): """Called after_val_iter in OTXProgressHook.""" self.time_monitor.on_test_batch_end(1, logger)
[docs] def after_run(self, runner: BaseRunner): """Called after_run in OTXProgressHook.""" self.time_monitor.on_train_end(1) if self.time_monitor.update_progress_callback: self.time_monitor.update_progress_callback(int(self.time_monitor.get_progress()))
@property def progress(self): """Getting Progress from time monitor.""" return self.time_monitor.get_progress()