# Source code for otx.algorithms.common.utils.callback

"""Collection of callback utils to run common OTX algorithms."""

# Copyright (C) 2022 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.

import time

from otx.api.usecases.reporting.time_monitor_callback import TimeMonitorCallback


class TrainingProgressCallback(TimeMonitorCallback):
    """Time-monitoring callback that reports training progress.

    Progress is pushed to ``update_progress_callback`` after every training
    batch and at the end of every epoch.
    """

    def __init__(self, update_progress_callback, **kwargs):
        super().__init__(update_progress_callback=update_progress_callback, **kwargs)

    def on_train_batch_end(self, batch, logs=None):
        """Run the base-class batch bookkeeping, then report the current progress."""
        super().on_train_batch_end(batch, logs)
        current_progress = self.get_progress()
        self.update_progress_callback(current_progress)

    def on_epoch_end(self, epoch, logs=None):
        """Record the epoch duration and report epoch-based progress.

        Args:
            epoch: Zero-based index of the epoch that just finished.
            logs: Optional metrics; only consulted when it is a ``dict``.
        """
        elapsed = time.time() - self.start_epoch_time
        self.past_epoch_duration.append(elapsed)

        # Progress is derived from the epoch index alone, as a percentage of total_epochs.
        progress = ((epoch + 1) / self.total_epochs) * 100
        self._calculate_average_epoch()

        # A score is forwarded only when the progress callback advertises a
        # ``metric`` attribute and the logs are a dict that can be queried for it.
        score = None
        reporter = self.update_progress_callback
        if hasattr(reporter, "metric") and isinstance(logs, dict):
            score = logs.get(reporter.metric, None)

        self.update_progress_callback(progress, score=score)
class InferenceProgressCallback(TimeMonitorCallback):
    """Time-monitoring callback that reports inference (test) progress."""

    def __init__(self, num_test_steps, update_progress_callback, **kwargs):
        # Only the test-step count is supplied by the caller; the epoch, train
        # and validation counts are fixed to zero.
        super().__init__(
            update_progress_callback=update_progress_callback,
            num_test_steps=num_test_steps,
            num_val_steps=0,
            num_train_steps=0,
            num_epoch=0,
            **kwargs,
        )

    def on_test_batch_end(self, batch=None, logs=None):
        """Run the base-class test-batch bookkeeping, then report progress as an int."""
        super().on_test_batch_end(batch, logs)
        whole_percent = int(self.get_progress())
        self.update_progress_callback(whole_percent)
class OptimizationProgressCallback(TrainingProgressCallback):
    """Progress callback used for optimization using NNCF.

    The progress bar is split into three stages. With the default percentages:

    - 5 %: model is loaded
    - 10 %: compressed model is initialized
    - 10-100 %: compressed model is being fine-tuned
    """

    def __init__(
        self,
        update_progress_callback,
        loading_stage_progress_percentage: int = 5,
        initialization_stage_progress_percentage: int = 5,
        **kwargs,
    ):
        super().__init__(update_progress_callback=update_progress_callback, **kwargs)

        # The two pre-training stages must leave a non-zero share for fine-tuning.
        reserved_percentage = loading_stage_progress_percentage + initialization_stage_progress_percentage
        if reserved_percentage >= 100:
            raise RuntimeError("Total optimization progress percentage is more than 100%")

        self.loading_stage_progress_percentage = loading_stage_progress_percentage
        self.initialization_stage_progress_percentage = initialization_stage_progress_percentage

        # The model is already loaded at this point, so the loading stage is
        # reported as complete right away (when a callback is set).
        if self.update_progress_callback:
            self.update_progress_callback(loading_stage_progress_percentage)

    def on_train_begin(self, logs=None):
        """Offset the step counters by the pre-training stages and report progress."""
        super().on_train_begin(logs)

        # Callback initialization takes place here after OTXProgressHook.before_run() is called
        loading_pct = self.loading_stage_progress_percentage
        init_pct = self.initialization_stage_progress_percentage
        train_percentage = 100 - loading_pct - init_pct

        # Express each pre-training stage as an equivalent number of steps,
        # scaled relative to the share of the bar left for training.
        loading_stage_steps = self.total_steps * loading_pct / train_percentage
        initialization_stage_steps = self.total_steps * init_pct / train_percentage
        pre_training_steps = loading_stage_steps + initialization_stage_steps

        # Extend the step total by the pre-training steps and start counting
        # from them, so the first training step continues from that offset.
        self.total_steps += pre_training_steps
        self.current_step = pre_training_steps

        self.update_progress_callback(self.get_progress())

    def on_train_end(self, logs=None):
        """Run the base-class end-of-training handling, then report final progress.

        ``logs`` is forwarded unchanged as the ``score`` keyword of the progress callback.
        """
        super().on_train_end(logs)
        final_progress = self.get_progress()
        self.update_progress_callback(final_progress, score=logs)

    def on_initialization_end(self):
        """on_initialization_end callback for optimization using NNCF.

        Reports the combined loading and initialization percentage as the current progress.
        """
        stages_done = self.loading_stage_progress_percentage + self.initialization_stage_progress_percentage
        self.update_progress_callback(stages_done)