Source code for otx.algorithms.common.adapters.mmcv.runner

"""Runner with cancel for common OTX algorithms."""

# Copyright (C) 2022 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.

# Based on
# * https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/epoch_based_runner.py
# * https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/iter_based_runner.py

import time
import warnings
from typing import List, Optional, Sequence

import mmcv
import torch.distributed as dist
from mmcv.runner import (
    RUNNERS,
    EpochBasedRunner,
    IterBasedRunner,
    IterLoader,
    get_dist_info,
)
from mmcv.runner.utils import get_host_info
from torch.utils.data.dataloader import DataLoader


# pylint: disable=too-many-instance-attributes, attribute-defined-outside-init
@RUNNERS.register_module()
class EpochRunnerWithCancel(EpochBasedRunner):
    """Epoch-based runner whose training can be cancelled in the middle of an epoch.

    A stopping hook requests cancellation by setting ``runner.should_stop`` to
    True; the runner polls that flag through :meth:`stop` after every
    training iteration and once more at the end of each epoch.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to ``EpochBasedRunner`` and set up cancel state."""
        super().__init__(*args, **kwargs)
        self.should_stop = False
        # Only the world size (second item of the dist info) is needed here.
        world_size = get_dist_info()[1]
        self.distributed = world_size > 1
        self.save_ema_model = False

    def stop(self) -> bool:
        """Tell the training loop whether it has to be interrupted.

        The decision is taken from ``should_stop`` on rank 0. In distributed
        runs it is broadcast from rank 0 so that every rank receives the same
        answer. When cancellation is confirmed, ``_max_epochs`` is set to the
        current epoch.

        :return: True if training is cancelled, False otherwise
        """
        # A one-element list, because broadcast_object_list fills the list in place.
        cancel_flag = [bool(self.rank == 0 and self.should_stop)]
        if self.distributed:
            dist.broadcast_object_list(cancel_flag, src=0)

        cancelled = cancel_flag[0]
        if cancelled:
            self._max_epochs = self.epoch
        return cancelled

    def train(self, data_loader: DataLoader, **kwargs):
        """Run one training epoch over ``data_loader``, honouring cancel requests.

        The epoch and iteration hooks are called around the loop and around
        every batch. The loop is left as soon as :meth:`stop` returns True; in
        that case the iteration counter is not advanced for the batch that
        triggered the cancellation.

        :param data_loader: loader providing the training batches
        :param kwargs: extra keyword arguments forwarded to ``run_iter``
        """
        self.model.train()
        self.mode = "train"
        self.data_loader = data_loader
        self._max_iters = self._max_epochs * len(self.data_loader)

        self.call_hook("before_train_epoch")
        if self.distributed:
            # Prevent possible multi-gpu deadlock during epoch transition
            time.sleep(2)

        for batch_idx, batch in enumerate(self.data_loader):
            self._inner_iter = batch_idx
            self.call_hook("before_train_iter")
            self.run_iter(batch, train_mode=True, **kwargs)
            self.call_hook("after_train_iter")
            if self.stop():
                break
            self._iter += 1

        # revert ema status before new iter
        self.save_ema_model = False
        self.call_hook("after_train_epoch")
        # Poll once more so that a request raised by an after-epoch hook is applied.
        self.stop()
        self._epoch += 1
@RUNNERS.register_module()
class IterBasedRunnerWithCancel(IterBasedRunner):
    """Iteration-based runner that supports cancelling (early-stopping) the training.

    A cancel-training hook requests the stop by setting ``runner.should_stop``
    to True; the main loop checks the flag after every iteration.

    # TODO: Implement cancelling of training via keyboard interrupt signal, instead of should_stop
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to ``IterBasedRunner`` and clear the cancel flag."""
        super().__init__(*args, **kwargs)
        self.should_stop = False

    def main_loop(self, workflow: List[tuple], iter_loaders: Sequence[IterLoader], **kwargs):
        """Cycle through the workflow until ``_max_iters`` is reached or a stop is requested.

        :param workflow: list of ``(mode, iters)`` tuples; ``mode`` names a
            runner method and ``iters`` is how many times to call it per cycle
        :param iter_loaders: one loader per workflow entry, matched by position
        :param kwargs: extra keyword arguments forwarded to the mode method
        :raises ValueError: if ``mode`` is not a string naming a runner attribute
        """
        while self.iter < self._max_iters:
            for loader_idx, step in enumerate(workflow):
                self._inner_iter = 0
                mode, iters = step
                if not (isinstance(mode, str) and hasattr(self, mode)):
                    raise ValueError(f'runner has no method named "{mode}" to run a workflow')

                run_one_iter = getattr(self, mode)
                is_train_mode = mode == "train"
                for _ in range(iters):
                    # Only training steps are cut short by the iteration budget.
                    if is_train_mode and self.iter >= self._max_iters:
                        break
                    run_one_iter(iter_loaders[loader_idx], **kwargs)
                    if self.should_stop:
                        return

    def run(
        self,
        data_loaders: Sequence[DataLoader],
        workflow: List[tuple],
        max_iters: Optional[int] = None,
        **kwargs,
    ):
        """Validate the inputs, then run the workflow between the run/epoch hooks.

        :param data_loaders: list of data loaders, one per workflow entry
        :param workflow: list of ``(mode, iters)`` tuples
        :param max_iters: deprecated override for the iteration budget; when
            given, a ``DeprecationWarning`` is emitted and ``_max_iters`` is replaced
        :param kwargs: extra keyword arguments forwarded to :meth:`main_loop`
        """
        assert isinstance(data_loaders, list)
        assert mmcv.is_list_of(workflow, tuple)
        assert len(data_loaders) == len(workflow)

        if max_iters is not None:
            warnings.warn(
                "setting max_iters in run is deprecated, please set max_iters in runner_config",
                DeprecationWarning,
            )
            self._max_iters = max_iters
        assert self._max_iters is not None, "max_iters must be specified during instantiation"

        work_dir = "NONE" if self.work_dir is None else self.work_dir
        self.logger.info("Start running, host: %s, work_dir: %s", get_host_info(), work_dir)
        self.logger.info("workflow: %s, max: %d iters", workflow, self._max_iters)
        self.call_hook("before_run")

        iter_loaders = list(map(IterLoader, data_loaders))

        self.call_hook("before_epoch")

        # The cancel flag is cleared both before and after the loop.
        self.should_stop = False
        self.main_loop(workflow, iter_loaders, **kwargs)
        self.should_stop = False

        # A `time.sleep(1)` that would wait for hooks such as loggers to finish is disabled here.
        self.call_hook("after_epoch")
        self.call_hook("after_run")