Source code for otx.api.usecases.reporting.time_monitor_callback
"""Time monitor callback module."""
# Copyright (C) 2021-2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
# pylint: disable=too-many-instance-attributes,too-many-arguments
import math
import time
import logging
from copy import deepcopy
from typing import List
import dill # nosec B403 used dill.pickles only to pickle callback object creating internally
from otx.api.entities.train_parameters import (
UpdateProgressCallback,
default_progress_callback,
)
from otx.api.usecases.reporting.callback import Callback
logger = logging.getLogger(__name__)
[docs]
class TimeMonitorCallback(Callback):
"""A callback to monitor the progress of training.
Args:
num_epoch (int): Amount of epochs
num_train_steps (int): amount of training steps per epoch
num_val_steps (int): amount of validation steps per epoch
num_test_steps (int): amount of testing steps
epoch_history (int): Amount of previous epochs to calculate average epoch time over
step_history (int): Amount of previous steps to calculate average steps time over
update_progress_callback (UpdateProgressCallback): Callback to update progress
"""
def __init__(
self,
num_epoch: int = 0,
num_train_steps: int = 0,
num_val_steps: int = 0,
num_test_steps: int = 0,
epoch_history: int = 5,
step_history: int = 50,
update_progress_callback: UpdateProgressCallback = default_progress_callback,
):
self.total_epochs = num_epoch
self.train_steps = num_train_steps
self.val_steps = num_val_steps
self.test_steps = num_test_steps
self.steps_per_epoch = self.train_steps + self.val_steps
self.total_steps = math.ceil(self.steps_per_epoch * self.total_epochs + num_test_steps)
self.current_step = 0
self.current_epoch = 0
# Step time calculation
self.start_step_time = time.time()
self.past_step_duration: List[float] = []
self.average_step = 0
self.step_history = step_history
# epoch time calculation
self.start_epoch_time = time.time()
self.past_epoch_duration: List[float] = []
self.average_epoch = 0
self.epoch_history = epoch_history
# whether model is training flag
self.is_training = False
self.update_progress_callback = update_progress_callback
def __getstate__(self):
"""Return state values to be pickled."""
state = self.__dict__.copy()
# update_progress_callback is not always pickable object
# if it is not, replace it with default callback
if not dill.pickles(state["update_progress_callback"]):
state["update_progress_callback"] = default_progress_callback
return state
def __deepcopy__(self, memo):
"""Return deepcopy object."""
update_progress_callback = self.update_progress_callback
self.update_progress_callback = None
self.__dict__["__deepcopy__"] = None
result = deepcopy(self, memo)
self.__dict__.pop("__deepcopy__")
result.__dict__.pop("__deepcopy__")
result.update_progress_callback = update_progress_callback
self.update_progress_callback = update_progress_callback
memo[id(self)] = result
return result
[docs]
def on_train_batch_begin(self, batch, logs=None):
"""Set the value of current step and start the timer."""
self.current_step += 1
self.start_step_time = time.time()
[docs]
def on_train_batch_end(self, batch, logs=None):
"""Compute average time taken to complete a step."""
self.__calculate_average_step()
[docs]
def is_stalling(self) -> bool:
"""Returns True if the training is stalling.
Returns True if the current step has taken more than 30 seconds and
at least 20x more than the average step duration
"""
factor = 20
min_abs_threshold = 30 # seconds
if self.is_training and self.current_step > 2:
step_duration = time.time() - self.start_step_time
if step_duration > min_abs_threshold and step_duration > factor * self.average_step:
logger.error(
f"Step {self.current_step} has taken {step_duration}s which is "
f">{min_abs_threshold}s and {factor} times "
f"more than the expected {self.average_step}s"
)
return True
return False
def __calculate_average_step(self):
"""Compute average duration taken to complete a step."""
self.past_step_duration.append(time.time() - self.start_step_time)
if len(self.past_step_duration) > self.step_history:
self.past_step_duration.remove(self.past_step_duration[0])
self.average_step = sum(self.past_step_duration) / len(self.past_step_duration)
[docs]
def on_test_batch_begin(self, batch, logs):
"""Set the number of current epoch and start the timer."""
self.current_step += 1
self.start_step_time = time.time()
[docs]
def on_test_batch_end(self, batch, logs):
"""Compute average time taken to complete a step based on a running average of `step_history` steps."""
self.__calculate_average_step()
[docs]
def on_train_begin(self, logs=None):
"""Sets training to true."""
self.is_training = True
[docs]
def on_train_end(self, logs=None):
"""Handles early stopping when the total_steps is greater than the current_step."""
# To handle cases where early stopping stops the task the progress will still be accurate
self.current_step = self.total_steps - self.test_steps
self.current_epoch = self.total_epochs
self.is_training = False
[docs]
def on_epoch_begin(self, epoch, logs=None):
"""Set the number of current epoch and start the timer."""
self.current_epoch = epoch + 1
self.start_epoch_time = time.time()
[docs]
def on_epoch_end(self, epoch, logs=None):
"""Computes the average time taken to complete an epoch based on a running average of `epoch_history` epochs."""
self.past_epoch_duration.append(time.time() - self.start_epoch_time)
self._calculate_average_epoch()
self.update_progress_callback(self.get_progress())
def _calculate_average_epoch(self):
if len(self.past_epoch_duration) > self.epoch_history:
del self.past_epoch_duration[0]
self.average_epoch = sum(self.past_epoch_duration) / len(self.past_epoch_duration)
[docs]
def get_progress(self):
"""Returns current progress as a percentage."""
return (self.current_step / self.total_steps) * 100