Source code for otx.algo.callbacks.unlabeled_loss_warmup
"""Module for defining hook for semi-supervised learning for classification task."""
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
import math
from typing import Any
from lightning import Callback, LightningModule, Trainer
[docs]
class UnlabeledLossWarmUpCallback(Callback):
    """Hook for SemiSL for classification.

    This hook ramps the unlabeled loss coefficient from 0.0 up to 1.0 with a
    cosine schedule and writes it to ``trainer.model.model.unlabeled_coef``
    before every training batch:

        unlabeled_coef = 0.5 * (1 - cos(min(pi, 2 * pi * k / K)))

    k: number of batches processed so far during warm-up,
    K: ``int(max_epochs * len(dataloader) * warmup_steps_ratio)``.

    Because the angle is clamped at pi, the coefficient reaches 1.0 once
    ``k >= K / 2``; from then on the hook does nothing.

    Args:
        warmup_steps_ratio (float): Ratio of warm-up steps to total steps (default: 0.2).
    """

    def __init__(self, warmup_steps_ratio: float = 0.2):
        super().__init__()
        self.warmup_steps_ratio = warmup_steps_ratio
        # Resolved lazily on the first training batch, when the dataloader length is known.
        self.total_steps = 0
        self.current_step, self.unlabeled_coef = 0, 0.0
        self.num_pseudo_label = 0

    def on_train_batch_start(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        batch: Any,  # noqa: ANN401
        batch_idx: int,
    ) -> None:
        """Calculate the unlabeled warm-up loss coefficient before training iteration.

        Args:
            trainer (Trainer): Trainer whose ``max_epochs``, ``train_dataloader`` and
                ``model`` are read.
            pl_module (LightningModule): Unused; part of the callback signature.
            batch (Any): Unused; part of the callback signature.
            batch_idx (int): Unused; part of the callback signature.

        Raises:
            ValueError: If ``trainer.model`` is ``None``.
        """
        if self.unlabeled_coef >= 1.0:
            # Warm-up already finished; the model keeps the last value that was set.
            return

        if self.total_steps == 0:
            dataloader = (
                trainer.train_dataloader["labeled"]
                if isinstance(trainer.train_dataloader, dict)
                else trainer.train_dataloader
            )
            self.total_steps = int(trainer.max_epochs * len(dataloader) * self.warmup_steps_ratio)

        if self.total_steps > 0:
            angle = min(math.pi, (2 * math.pi * self.current_step) / self.total_steps)
            self.unlabeled_coef = 0.5 * (1 - math.cos(angle))
        else:
            # A zero ratio, a product that truncates to zero, or a negative
            # ``max_epochs`` leaves no valid warm-up length. Use the full
            # coefficient immediately instead of dividing by zero.
            self.unlabeled_coef = 1.0

        if trainer.model is None:
            msg = "Model is not found in the trainer."
            raise ValueError(msg)
        trainer.model.model.unlabeled_coef = self.unlabeled_coef
        self.current_step += 1