Source code for otx.algorithms.anomaly.adapters.anomalib.strategies.xpu_single
"""Lightning strategy for single XPU device."""
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
from typing import Optional

import pytorch_lightning as pl
import torch
from lightning_fabric.plugins import CheckpointIO
from lightning_fabric.utilities.types import _DEVICE
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.strategies import StrategyRegistry
from pytorch_lightning.strategies.single_device import SingleDeviceStrategy
from pytorch_lightning.utilities.exceptions import MisconfigurationException

from otx.algorithms.common.utils.utils import is_xpu_available
class SingleXPUStrategy(SingleDeviceStrategy):
    """Strategy for training on single XPU device."""

    strategy_name = "xpu_single"

    def __init__(
        self,
        device: _DEVICE = "xpu:0",
        accelerator: Optional["pl.accelerators.Accelerator"] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[PrecisionPlugin] = None,
    ):
        if not is_xpu_available():
            raise MisconfigurationException("`SingleXPUStrategy` requires XPU devices to run")

        super().__init__(
            accelerator=accelerator,
            device=device,
            checkpoint_io=checkpoint_io,
            precision_plugin=precision_plugin,
        )

    @property
    def is_distributed(self) -> bool:
        """Returns true if the strategy supports distributed training."""
        return False
    def setup_optimizers(self, trainer: "pl.Trainer") -> None:
        """Sets up optimizers and applies IPEX optimization to the model."""
        super().setup_optimizers(trainer)
        # Exactly one optimizer is expected, since the XPU optimization below is
        # applied to a single (model, optimizer) pair.
        if len(self.optimizers) != 1:  # type: ignore
            raise RuntimeError("XPU strategy doesn't support multiple optimizers")
        # Let IPEX optimize the model and optimizer for XPU execution, then hand the
        # optimized objects back to the trainer.
        model, optimizer = torch.xpu.optimize(trainer.model, optimizer=self.optimizers[0])  # type: ignore
        self.optimizers = [optimizer]
        trainer.model = model
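
# Note: torch.xpu.optimize used in setup_optimizers is provided by Intel Extension
# for PyTorch (IPEX) on XPU builds; it returns a (model, optimizer) pair tuned for
# XPU execution, so IPEX with XPU support must be installed for this strategy to
# work. is_xpu_available() guards against constructing the strategy without it.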


StrategyRegistry.register(
    SingleXPUStrategy.strategy_name,
    SingleXPUStrategy,
    description="Strategy that enables training on a single XPU device",
)
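
# Usage sketch (illustrative; assumes the model and datamodule are defined elsewhere
# and that a compatible XPU accelerator is configured alongside the strategy):
#
#     import pytorch_lightning as pl
#
#     from otx.algorithms.anomaly.adapters.anomalib.strategies.xpu_single import (
#         SingleXPUStrategy,
#     )
#
#     trainer = pl.Trainer(strategy=SingleXPUStrategy(device="xpu:0"), max_epochs=1)
#     trainer.fit(model, datamodule=datamodule)
#
# Once this module has been imported, the strategy can also be selected through the
# registry by its registered name:
#
#     trainer = pl.Trainer(strategy="xpu_single", max_epochs=1)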