Source code for otx.algo.plugins.xpu_precision
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
"""Plugin for mixed-precision training on XPU."""
from __future__ import annotations
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Callable, Generator
import torch
from lightning.pytorch.plugins.precision.precision import Precision
from lightning.pytorch.utilities import GradClipAlgorithmType
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from torch import Tensor
from torch.optim import LBFGS, Optimizer
if TYPE_CHECKING:
import lightning.pytorch as pl
from lightning_fabric.utilities.types import Optimizable
[docs]
class MixedPrecisionXPUPlugin(Precision):
    """Plugin for Automatic Mixed Precision (AMP) training with ``torch.xpu.autocast``.

    Args:
        scaler: An optional :class:`torch.cuda.amp.GradScaler` to use. When ``None``, all
            gradient-scaling logic is skipped (the optimizer step is delegated to the base
            ``Precision`` plugin), which is the intended mode for bfloat16 training.
    """

    def __init__(self, scaler: torch.cuda.amp.GradScaler | None = None) -> None:
        self.scaler = scaler

    def pre_backward(self, tensor: Tensor, module: pl.LightningModule) -> Tensor:
        """Apply the grad scaler to the loss before backward.

        Args:
            tensor: The tensor (typically the loss) about to be back-propagated.
            module: The Lightning module being trained.

        Returns:
            The result of the base-class hook applied to the scaled tensor (or the
            unmodified tensor when no scaler is configured).
        """
        if self.scaler is not None:
            tensor = self.scaler.scale(tensor)
        return super().pre_backward(tensor, module)

    def optimizer_step(  # type: ignore[override]
        self,
        optimizer: Optimizable,
        model: pl.LightningModule,
        closure: Callable,
        **kwargs: Any,
    ) -> None | dict:
        """Make an optimizer step, using the grad scaler if one was passed.

        Args:
            optimizer: The optimizer to step.
            model: The Lightning module being trained.
            closure: Callable that runs forward/backward; in manual optimization it
                returns ``None``.
            **kwargs: Extra keyword arguments forwarded to the optimizer step.

        Returns:
            Without a scaler, whatever the base ``Precision.optimizer_step`` returns.
            With a scaler, the output of ``scaler.step`` when a step is taken, otherwise
            the closure result (``None`` when backward was skipped).

        Raises:
            MisconfigurationException: If a scaler is set and the optimizer is ``LBFGS``.
        """
        if self.scaler is None:
            # skip scaler logic, as bfloat16 does not require scaler
            return super().optimizer_step(
                optimizer,
                model=model,
                closure=closure,
                **kwargs,
            )
        if isinstance(optimizer, LBFGS):
            msg = "Native AMP and the LBFGS optimizer are not compatible."
            raise MisconfigurationException(msg)
        closure_result = closure()

        if not _optimizer_handles_unscaling(optimizer):
            # Unscaling needs to be performed here in case we are going to apply gradient clipping.
            # Optimizers that perform unscaling in their `.step()` method are not supported (e.g., fused Adam).
            # Note: `unscale` happens after the closure is executed, but before the `on_before_optimizer_step` hook.
            self.scaler.unscale_(optimizer)

        self._after_closure(model, optimizer)
        skipped_backward = closure_result is None
        # in manual optimization, the closure does not return a value
        if not model.automatic_optimization or not skipped_backward:
            # note: the scaler will skip the `optimizer.step` if nonfinite gradients are found
            step_output = self.scaler.step(optimizer, **kwargs)
            self.scaler.update()
            return step_output
        return closure_result

    def clip_gradients(
        self,
        optimizer: Optimizer,
        clip_val: int | float = 0.0,
        gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
    ) -> None:
        """Handle grad clipping with scaler.

        Args:
            optimizer: The optimizer whose parameters' gradients are clipped.
            clip_val: Clipping threshold; clipping is requested when it is > 0.
            gradient_clip_algorithm: Clipping algorithm forwarded to the base class.

        Raises:
            RuntimeError: If clipping is requested but the optimizer unscales gradients
                inside its own ``step`` (e.g. a fused optimizer), since the gradients seen
                here would still be scaled.
        """
        if clip_val > 0 and _optimizer_handles_unscaling(optimizer):
            # Parenthesized so the two literals are concatenated into one message; without the
            # parentheses the second literal is a separate no-op statement and gets dropped.
            msg = (
                f"The current optimizer, {type(optimizer).__qualname__}, does not allow for gradient clipping"
                " because it performs unscaling of gradients internally. HINT: Are you using a 'fused' optimizer?"
            )
            raise RuntimeError(msg)
        super().clip_gradients(optimizer=optimizer, clip_val=clip_val, gradient_clip_algorithm=gradient_clip_algorithm)

    @contextmanager
    def forward_context(self) -> Generator[None, None, None]:
        """Enable the ``torch.xpu.autocast`` context for the duration of the forward pass."""
        with torch.xpu.autocast(True):
            yield

    def state_dict(self) -> dict[str, Any]:
        """Return the grad scaler's state dict, or an empty dict when no scaler is set."""
        if self.scaler is not None:
            return self.scaler.state_dict()
        return {}

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        """Load ``state_dict`` into the grad scaler; a no-op when no scaler is set."""
        if self.scaler is not None:
            self.scaler.load_state_dict(state_dict)
def _optimizer_handles_unscaling(optimizer: torch.optim.Optimizer) -> bool:
"""Determines if a PyTorch optimizer handles unscaling gradients in the step method ratherthan through the scaler.
Since, the current implementation of this function checks a PyTorch internal variable on the optimizer, the return
value will only be reliable for built-in PyTorch optimizers.
"""
return getattr(optimizer, "_step_supports_amp_scaling", False)