# Source code for otx.algo.callbacks.ema_mean_teacher

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
"""Module for exponential moving average for SemiSL mean teacher algorithm."""

from __future__ import annotations

from typing import TYPE_CHECKING, Any

import torch
from lightning import Callback, LightningModule, Trainer

if TYPE_CHECKING:
    from lightning.pytorch.utilities.types import STEP_OUTPUT

class EMAMeanTeacher(Callback):
    """Callback for the SemiSL MeanTeacher algorithm.

    This callback keeps the teacher model's trainable parameters as an
    exponential moving average (EMA) of the student model's parameters:

        teacher = momentum * teacher + (1 - momentum) * student

    Args:
        momentum (float, optional): Upper bound of the EMA momentum. The
            effective momentum is ``min(1 - 1 / (global_step + 1), momentum)``,
            so it ramps up from 0 early in training. Defaults to 0.999.
        start_epoch (int, optional): EMA updates are skipped while
            ``trainer.current_epoch < start_epoch``. Defaults to 1.
    """

    def __init__(
        self,
        momentum: float = 0.999,
        start_epoch: int = 1,
    ) -> None:
        super().__init__()
        self.momentum = momentum
        self.start_epoch = start_epoch
        # Set once the teacher has been hard-synced to the student (see _ema_model).
        self.synced_models = False

    def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Set up source (student) and destination (teacher) model parameters.

        Raises:
            RuntimeError: If ``trainer.model.model`` has no ``student_model`` or
                no ``teacher_model`` attribute (or either is ``None``).
        """
        # call to nn.model
        # NOTE(review): assumes trainer.model wraps a module whose ``.model``
        # holds both student and teacher networks — confirm against the
        # MeanTeacher model implementation.
        model = trainer.model.model
        self.src_model = getattr(model, "student_model", None)
        self.dst_model = getattr(model, "teacher_model", None)
        if self.src_model is None or self.dst_model is None:
            msg = "student_model and teacher_model should be set for MeanTeacher algorithm"
            raise RuntimeError(msg)
        # keep_vars=True keeps the live Parameter objects, so ``requires_grad``
        # below reflects which entries are trainable (detached tensors would
        # all report requires_grad=False and nothing would ever be updated).
        self.src_params = self.src_model.state_dict(keep_vars=True)
        self.dst_params = self.dst_model.state_dict(keep_vars=True)

    def on_train_batch_end(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        outputs: STEP_OUTPUT,
        batch: Any,  # noqa: ANN401
        batch_idx: int,
    ) -> None:
        """Update EMA parameters every iteration once ``start_epoch`` is reached."""
        if trainer.current_epoch < self.start_epoch:
            return
        # EMA
        self._ema_model(trainer.global_step)

    def _copy_model(self) -> None:
        """Copy every trainable student parameter into the matching teacher entry."""
        with torch.no_grad():
            for name, src_param in self.src_params.items():
                if src_param.requires_grad:
                    dst_param = self.dst_params[name]
                    dst_param.data.copy_(src_param.data)

    def _ema_model(self, global_step: int) -> None:
        """Blend trainable student parameters into the teacher in place.

        Args:
            global_step (int): Current trainer global step, used to ramp the
                momentum up from 0 towards ``self.momentum``.
        """
        # When EMA starts after epoch 0, global_step is already large and the
        # momentum near its cap, so the teacher would barely move towards the
        # student; hard-sync once first. (With start_epoch == 0 the first step
        # has momentum 0, which is already a full copy.)
        if self.start_epoch != 0 and not self.synced_models:
            self._copy_model()
            self.synced_models = True

        momentum = min(1 - 1 / (global_step + 1), self.momentum)
        with torch.no_grad():
            for name, src_param in self.src_params.items():
                if src_param.requires_grad:
                    dst_param = self.dst_params[name]
                    # teacher = momentum * teacher + (1 - momentum) * student
                    dst_param.data.mul_(momentum).add_(src_param.data, alpha=1 - momentum)