Source code for otx.algorithms.anomaly.adapters.anomalib.plugins.xpu_precision

"""Plugin for mixed-precision training on XPU."""
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0


from contextlib import contextmanager
from typing import Any, Callable, Dict, Generator, Optional, Union

import pytorch_lightning as pl
import torch
from lightning_fabric.utilities.types import Optimizable
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
from pytorch_lightning.utilities import GradClipAlgorithmType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch import Tensor
from torch.optim import LBFGS, Optimizer


class MixedPrecisionXPUPlugin(PrecisionPlugin):
    """Plugin for Automatic Mixed Precision (AMP) training with ``torch.xpu.autocast``.

    Args:
        scaler: An optional :class:`torch.cuda.amp.GradScaler` to use.
    """

    def __init__(self, scaler: Optional[Any] = None) -> None:
        self.scaler = scaler

    def pre_backward(self, tensor: Tensor, module: "pl.LightningModule") -> Tensor:
        """Apply grad scaler before backward."""
        if self.scaler is not None:
            tensor = self.scaler.scale(tensor)
        return super().pre_backward(tensor, module)

    def optimizer_step(  # type: ignore[override]
        self,
        optimizer: Optimizable,
        model: "pl.LightningModule",
        optimizer_idx: int,
        closure: Callable[[], Any],
        **kwargs: Any,
    ) -> Any:
        """Make an optimizer step using scaler if it was passed."""
        if self.scaler is None:
            # skip scaler logic, as bfloat16 does not require scaler
            return super().optimizer_step(
                optimizer, model=model, optimizer_idx=optimizer_idx, closure=closure, **kwargs
            )
        if isinstance(optimizer, LBFGS):
            raise MisconfigurationException(
                f"Native AMP and the LBFGS optimizer are not compatible (optimizer {optimizer_idx})."
            )
        closure_result = closure()

        if not _optimizer_handles_unscaling(optimizer):
            # Unscaling needs to be performed here in case we are going to apply gradient clipping.
            # Optimizers that perform unscaling in their `.step()` method are not supported (e.g., fused Adam).
            # Note: `unscale` happens after the closure is executed, but before the `on_before_optimizer_step` hook.
            self.scaler.unscale_(optimizer)

        self._after_closure(model, optimizer, optimizer_idx)
        skipped_backward = closure_result is None
        # in manual optimization, the closure does not return a value
        if not model.automatic_optimization or not skipped_backward:
            # note: the scaler will skip the `optimizer.step` if nonfinite gradients are found
            step_output = self.scaler.step(optimizer, **kwargs)
            self.scaler.update()
            return step_output
        return closure_result

    def clip_gradients(
        self,
        optimizer: Optimizer,
        clip_val: Union[int, float] = 0.0,
        gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
    ) -> None:
        """Handle grad clipping with scaler."""
        if clip_val > 0 and _optimizer_handles_unscaling(optimizer):
            raise RuntimeError(
                f"The current optimizer, {type(optimizer).__qualname__}, does not allow for gradient clipping"
                " because it performs unscaling of gradients internally. HINT: Are you using a 'fused' optimizer?"
            )
        super().clip_gradients(optimizer=optimizer, clip_val=clip_val, gradient_clip_algorithm=gradient_clip_algorithm)

    @contextmanager
    def forward_context(self) -> Generator[None, None, None]:
        """Enable autocast context."""
        with torch.xpu.autocast(True):
            yield

    def state_dict(self) -> Dict[str, Any]:
        """Returns state dict of the plugin."""
        if self.scaler is not None:
            return self.scaler.state_dict()
        return {}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Loads state dict to the plugin."""
        if self.scaler is not None:
            self.scaler.load_state_dict(state_dict)
def _optimizer_handles_unscaling(optimizer: Any) -> bool:
    """Determines if a PyTorch optimizer handles unscaling gradients in the step method rather than through the scaler.

    Since the current implementation of this function checks a PyTorch internal variable on the optimizer, the return
    value will only be reliable for built-in PyTorch optimizers.
    """
    return getattr(optimizer, "_step_supports_amp_scaling", False)
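
A minimal usage sketch (not part of this module): it assumes an XPU-capable PyTorch build that exposes ``torch.xpu.amp.GradScaler`` (e.g., via Intel Extension for PyTorch) and an XPU accelerator registered with PyTorch Lightning. Passing a scaler enables fp16-style loss scaling; omitting it makes the plugin skip all scaling logic, as for bfloat16::

    # Hypothetical usage sketch; names below that are not defined in this
    # module (scaler class, accelerator string, model/datamodule) are assumptions.
    import pytorch_lightning as pl
    import torch

    from otx.algorithms.anomaly.adapters.anomalib.plugins.xpu_precision import MixedPrecisionXPUPlugin

    # fp16 mixed precision: provide a gradient scaler so losses are scaled
    # before backward and gradients are unscaled before the optimizer step.
    scaler = torch.xpu.amp.GradScaler()  # assumes IPEX provides this class
    plugin = MixedPrecisionXPUPlugin(scaler=scaler)

    # bf16 mixed precision: omit the scaler; the plugin then bypasses all
    # scaling logic and only applies the ``torch.xpu.autocast`` context.
    # plugin = MixedPrecisionXPUPlugin()

    trainer = pl.Trainer(
        accelerator="xpu",  # assumes an XPU accelerator is registered
        devices=1,
        plugins=[plugin],
    )
    # trainer.fit(model, datamodule=datamodule)  # model/datamodule are user-defined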