Source code for otx.algorithms.common.adapters.mmcv.hooks.logger_hook

"""Logger hooks."""

# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from collections import defaultdict
from typing import Any, Dict, Optional

from mmcv.runner import BaseRunner
from mmcv.runner.dist_utils import master_only
from mmcv.runner.hooks import HOOKS, Hook, LoggerHook

from otx.utils.logger import get_logger

logger = get_logger()


[docs] @HOOKS.register_module() class OTXLoggerHook(LoggerHook): """OTXLoggerHook for Logging."""
[docs] class Curve: """Curve with x (epochs) & y (scores).""" def __init__(self): self.x = [] self.y = [] def __repr__(self): """Repr function.""" points = [] for x, y in zip(self.x, self.y): points.append(f"({x},{y})") return "curve[" + ",".join(points) + "]"
_TAGS_TO_SKIP = ( "accuracy_top-1", "current_iters", "decode.acc_seg", "decode.loss_ce_ignore", ) _TAGS_TO_RENAME = { "train/time": "train/time (sec/iter)", "train/data_time": "train/data_time (sec/iter)", "val/accuracy": "val/accuracy (%)", } def __init__( self, curves: Optional[Dict[Any, Curve]] = None, interval: int = 10, ignore_last: bool = True, reset_flag: bool = True, by_epoch: bool = True, ): super().__init__(interval, ignore_last, reset_flag, by_epoch) self.curves = curves if curves is not None else defaultdict(self.Curve)
[docs] @master_only def log(self, runner: BaseRunner): """Log function for OTXLoggerHook.""" tags = self.get_loggable_tags(runner, allow_text=False, tags_to_skip=self._TAGS_TO_SKIP) if runner.max_epochs is not None: normalized_iter = self.get_iter(runner) / runner.max_iters * runner.max_epochs else: normalized_iter = self.get_iter(runner) for tag, value in tags.items(): tag = self._TAGS_TO_RENAME.get(tag, tag) curve = self.curves[tag] # Remove duplicates. if len(curve.x) > 0 and curve.x[-1] == normalized_iter: curve.x.pop() curve.y.pop() curve.x.append(normalized_iter) curve.y.append(value)
[docs] def before_run(self, runner: BaseRunner): """Called before_run in OTXLoggerHook.""" super().before_run(runner) self.curves.clear()
[docs] def after_train_epoch(self, runner: BaseRunner): """Called after_train_epoch in OTXLoggerHook.""" # Iteration counter is increased right after the last iteration in the epoch, # temporarily decrease it back. runner._iter -= 1 super().after_train_epoch(runner) runner._iter += 1
@HOOKS.register_module() class LoggerReplaceHook(Hook): """replace logger in the runner to the OTX logger. DO NOT INCLUDE this hook to the recipe directly. OTX will add this hook to all recipe internally. """ def __init__(self, **kwargs): super().__init__(**kwargs) def before_run(self, runner): """Replace logger.""" runner.logger = logger logger.info("logger in the runner is replaced to the OTX logger")