Source code for otx.algorithms.segmentation.adapters.mmseg.models.schedulers.step
"""Step scheduler."""
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
from typing import List
import numpy as np
from otx.algorithms.segmentation.adapters.mmseg.utils.builder import SCALAR_SCHEDULERS
from .base import BaseScalarScheduler
[docs]
@SCALAR_SCHEDULERS.register_module()
class StepScalarScheduler(BaseScalarScheduler):
"""Step learning rate scheduler.
Example:
>>> scheduler = StepScalarScheduler(scales=[1.0, 0.1, 0.01], num_iters=[100, 200])
This means that the learning rate will be 1.0 for the first 100 iterations,
0.1 for the next 200 iterations, and 0.01 for the rest of the iterations.
Args:
scales (List[int]): List of learning rate scales.
num_iters (List[int]): A list specifying the count of iterations at each scale.
by_epoch (bool): Whether to use epoch as the unit of iteration.
"""
def __init__(self, scales: List[float], num_iters: List[int], by_epoch: bool = False):
super().__init__()
self.by_epoch = by_epoch
assert len(scales) == len(num_iters) + 1
assert len(scales) > 0
self._scales = list(scales)
self._iter_ranges = list(num_iters) + [np.iinfo(np.int32).max]
def _get_value(self, step, epoch_size) -> float:
if step is None:
return float(self._scales[-1])
out_scale_idx = 0
for iter_range in self._iter_ranges:
if self.by_epoch:
iter_threshold = epoch_size * iter_range
else:
iter_threshold = iter_range
if step < iter_threshold:
break
out_scale_idx += 1
return float(self._scales[out_scale_idx])