Source code for otx.algo.plugins.xpu_precision

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
"""Plugin for mixed-precision training on XPU."""

from __future__ import annotations

from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Callable, Generator

import torch
from lightning.pytorch.plugins.precision.precision import Precision
from lightning.pytorch.utilities import GradClipAlgorithmType
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from torch import Tensor
from torch.optim import LBFGS, Optimizer

if TYPE_CHECKING:
    import lightning.pytorch as pl
    from lightning_fabric.utilities.types import Optimizable


class MixedPrecisionXPUPlugin(Precision):
    """Plugin for Automatic Mixed Precision (AMP) training with ``torch.xpu.autocast``.

    Args:
        scaler: An optional :class:`torch.cuda.amp.GradScaler` to use.
    """

    def __init__(self, scaler: torch.cuda.amp.GradScaler | None = None) -> None:
        self.scaler = scaler

    def pre_backward(self, tensor: Tensor, module: pl.LightningModule) -> Tensor:
        """Apply grad scaler before backward."""
        if self.scaler is not None:
            tensor = self.scaler.scale(tensor)
        return super().pre_backward(tensor, module)

    def optimizer_step(  # type: ignore[override]
        self,
        optimizer: Optimizable,
        model: pl.LightningModule,
        closure: Callable,
        **kwargs: dict,
    ) -> None | dict:
        """Make an optimizer step using scaler if it was passed."""
        if self.scaler is None:
            # skip scaler logic, as bfloat16 does not require scaler
            return super().optimizer_step(
                optimizer,
                model=model,
                closure=closure,
                **kwargs,
            )
        if isinstance(optimizer, LBFGS):
            msg = "Native AMP and the LBFGS optimizer are not compatible."
            raise MisconfigurationException(msg)
        closure_result = closure()

        if not _optimizer_handles_unscaling(optimizer):
            # Unscaling needs to be performed here in case we are going to apply gradient clipping.
            # Optimizers that perform unscaling in their `.step()` method are not supported (e.g., fused Adam).
            # Note: `unscale` happens after the closure is executed, but before the `on_before_optimizer_step` hook.
            self.scaler.unscale_(optimizer)

        self._after_closure(model, optimizer)
        skipped_backward = closure_result is None
        # in manual optimization, the closure does not return a value
        if not model.automatic_optimization or not skipped_backward:
            # note: the scaler will skip the `optimizer.step` if nonfinite gradients are found
            step_output = self.scaler.step(optimizer, **kwargs)
            self.scaler.update()
            return step_output
        return closure_result

    def clip_gradients(
        self,
        optimizer: Optimizer,
        clip_val: int | float = 0.0,
        gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
    ) -> None:
        """Handle grad clipping with scaler."""
        if clip_val > 0 and _optimizer_handles_unscaling(optimizer):
            msg = (
                f"The current optimizer, {type(optimizer).__qualname__}, does not allow for gradient clipping"
                " because it performs unscaling of gradients internally. HINT: Are you using a 'fused' optimizer?"
            )
            raise RuntimeError(msg)
        super().clip_gradients(
            optimizer=optimizer,
            clip_val=clip_val,
            gradient_clip_algorithm=gradient_clip_algorithm,
        )

    @contextmanager
    def forward_context(self) -> Generator[None, None, None]:
        """Enable autocast context."""
        with torch.xpu.autocast(True):
            yield

    def state_dict(self) -> dict[str, Any]:
        """Returns state dict of the plugin."""
        if self.scaler is not None:
            return self.scaler.state_dict()
        return {}

    def load_state_dict(self, state_dict: dict[str, torch.Tensor]) -> None:
        """Loads state dict to the plugin."""
        if self.scaler is not None:
            self.scaler.load_state_dict(state_dict)


def _optimizer_handles_unscaling(optimizer: torch.optim.Optimizer) -> bool:
    """Determine whether a PyTorch optimizer handles gradient unscaling in its ``step`` method rather than through the scaler.

    Since the current implementation of this function checks a PyTorch-internal variable on the optimizer,
    the return value is only reliable for built-in PyTorch optimizers.
    """
    return getattr(optimizer, "_step_supports_amp_scaling", False)
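

For context, a minimal usage sketch follows (it is not part of the module above). The plugin is handed to a Lightning ``Trainer`` through its ``plugins`` argument, with an optional ``GradScaler`` for fp16-style loss scaling; omitting the scaler corresponds to bfloat16 autocast, where no loss scaling is needed. ``MyLitModel`` and the particular scaler class are placeholders, and accelerator/device selection for XPU is assumed to be configured elsewhere.

# Usage sketch only: assumes an XPU-enabled PyTorch build and a
# hypothetical LightningModule called ``MyLitModel``.
import lightning.pytorch as pl
import torch

from otx.algo.plugins.xpu_precision import MixedPrecisionXPUPlugin

# Optional loss scaler for fp16-style mixed precision; pass ``scaler=None``
# (or nothing) to rely on bfloat16 autocast without loss scaling.
scaler = torch.cuda.amp.GradScaler()

trainer = pl.Trainer(
    max_epochs=1,
    plugins=[MixedPrecisionXPUPlugin(scaler=scaler)],  # overrides the default precision plugin
)
trainer.fit(MyLitModel())  # ``MyLitModel`` is a placeholder LightningModule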