# Source code for otx.algo.strategies.xpu_single

"""Lightning strategy for single XPU device."""

# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

from typing import TYPE_CHECKING

import torch
from lightning.pytorch.strategies import StrategyRegistry
from lightning.pytorch.strategies.single_device import SingleDeviceStrategy
from lightning.pytorch.utilities.exceptions import MisconfigurationException

from otx.utils.utils import is_xpu_available

if TYPE_CHECKING:
    import lightning.pytorch as pl
    from lightning.pytorch.plugins.precision import PrecisionPlugin
    from lightning_fabric.plugins import CheckpointIO
    from lightning_fabric.utilities.types import _DEVICE

class SingleXPUStrategy(SingleDeviceStrategy):
    """Single-device Lightning strategy targeting one XPU device.

    The strategy is published under the name stored in ``strategy_name``.
    """

    strategy_name = "xpu_single"

    def __init__(
        self,
        device: _DEVICE = "xpu:0",
        accelerator: pl.accelerators.Accelerator | None = None,
        checkpoint_io: CheckpointIO | None = None,
        precision_plugin: PrecisionPlugin | None = None,
    ):
        """Validate XPU availability, then initialise the parent strategy.

        Args:
            device: Device to run on; defaults to ``"xpu:0"``.
            accelerator: Forwarded unchanged to ``SingleDeviceStrategy``.
            checkpoint_io: Forwarded unchanged to ``SingleDeviceStrategy``.
            precision_plugin: Forwarded unchanged to ``SingleDeviceStrategy``.

        Raises:
            MisconfigurationException: If ``is_xpu_available()`` reports that
                no XPU device can be used.
        """
        xpu_ready = is_xpu_available()
        if not xpu_ready:
            reason = "`SingleXPUStrategy` requires XPU devices to run"
            raise MisconfigurationException(reason)

        parent_kwargs = {
            "accelerator": accelerator,
            "device": device,
            "checkpoint_io": checkpoint_io,
            "precision_plugin": precision_plugin,
        }
        super().__init__(**parent_kwargs)

    @property
    def is_distributed(self) -> bool:
        """Whether this strategy is distributed; always ``False``."""
        return False

    def setup_optimizers(self, trainer: pl.Trainer) -> None:
        """Set up optimizers, then pass the model through ``torch.xpu.optimize``.

        The parent implementation runs first. Afterwards:

        * more than one optimizer is rejected;
        * nothing further happens when ``trainer.task`` equals
          ``"SEMANTIC_SEGMENTATION"``;
        * with exactly one optimizer, model and optimizer are optimized together;
        * with no optimizer, the model is switched to eval mode and optimized
          on its own.

        Args:
            trainer: Trainer whose ``model`` and ``task`` attributes are read.

        Raises:
            RuntimeError: If more than one optimizer is configured.
        """
        super().setup_optimizers(trainer)

        optimizer_count = len(self.optimizers)  # type: ignore[has-type]
        if optimizer_count > 1:
            reason = "XPU strategy doesn't support multiple optimizers"
            raise RuntimeError(reason)

        # NOTE(review): why segmentation is excluded is not visible here -- confirm with the task owners.
        if trainer.task == "SEMANTIC_SEGMENTATION":
            return

        if optimizer_count == 1:
            # `torch.xpu.optimize` is presumably provided by Intel Extension for PyTorch -- TODO confirm.
            optimized = torch.xpu.optimize(trainer.model, optimizer=self.optimizers[0])  # type: ignore[has-type]
            xpu_model, xpu_optimizer = optimized
            self.optimizers = [xpu_optimizer]
            self.model = xpu_model
            return

        # No optimizer configured: treated as the inference path.
        inference_model = trainer.model
        inference_model.eval()
        self.model = torch.xpu.optimize(inference_model)
# Make the strategy selectable by its `strategy_name` through Lightning's strategy registry.
StrategyRegistry.register(
    SingleXPUStrategy.strategy_name,
    SingleXPUStrategy,
    description="Strategy that enables training on single XPU",
)