Source code for otx.algo.strategies.xpu_single
"""Lightning strategy for single XPU device."""
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
from __future__ import annotations
from typing import TYPE_CHECKING
from lightning.pytorch.strategies import StrategyRegistry
from lightning.pytorch.strategies.single_device import SingleDeviceStrategy
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from otx.utils.device import is_xpu_available
if TYPE_CHECKING:
import lightning.pytorch as pl
from lightning.pytorch.plugins.precision import PrecisionPlugin
from lightning_fabric.plugins import CheckpointIO
from lightning_fabric.utilities.types import _DEVICE
[docs]
class SingleXPUStrategy(SingleDeviceStrategy):
    """Single-device Lightning strategy that targets one XPU.

    Construction fails immediately when ``is_xpu_available()`` reports that no
    XPU is present; otherwise all arguments are passed through unchanged to
    ``SingleDeviceStrategy``.
    """

    # Key under which this class is registered in ``StrategyRegistry``.
    strategy_name = "xpu_single"

    def __init__(
        self,
        device: _DEVICE = "xpu:0",
        accelerator: pl.accelerators.Accelerator | None = None,
        checkpoint_io: CheckpointIO | None = None,
        precision_plugin: PrecisionPlugin | None = None,
    ) -> None:
        """Check XPU availability, then initialise the parent strategy.

        Args:
            device: Device to run on; defaults to ``"xpu:0"``.
            accelerator: Optional accelerator forwarded to the parent class.
            checkpoint_io: Optional checkpoint I/O plugin forwarded to the parent class.
            precision_plugin: Optional precision plugin forwarded to the parent class.

        Raises:
            MisconfigurationException: If ``is_xpu_available()`` returns a falsy value.
        """
        xpu_present = is_xpu_available()
        if not xpu_present:
            # Fail before the parent constructor touches the device.
            error_text = "`SingleXPUStrategy` requires XPU devices to run"
            raise MisconfigurationException(error_text)

        # Keyword order mirrors this constructor's own signature.
        super().__init__(
            device=device,
            accelerator=accelerator,
            checkpoint_io=checkpoint_io,
            precision_plugin=precision_plugin,
        )
# Make the strategy selectable by its ``strategy_name`` key through
# Lightning's strategy registry.
StrategyRegistry.register(
    name=SingleXPUStrategy.strategy_name,
    strategy=SingleXPUStrategy,
    description="Strategy that enables training on single XPU",
)