Source code for otx.algo.accelerators.xpu

"""Lightning accelerator for XPU device."""
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
from __future__ import annotations

from typing import Any

import torch
from lightning.pytorch.accelerators import AcceleratorRegistry
from lightning.pytorch.accelerators.accelerator import Accelerator

from otx.utils.utils import is_xpu_available


class XPUAccelerator(Accelerator):
    """Lightning ``Accelerator`` implementation targeting XPU devices."""

    # Name under which this accelerator is registered with Lightning.
    accelerator_name = "xpu"

    def setup_device(self, device: torch.device) -> None:
        """Make ``device`` the current XPU device.

        Args:
            device: Device to activate; its ``type`` must be ``"xpu"``.

        Raises:
            RuntimeError: If ``device`` is not an XPU device.
        """
        if device.type == "xpu":
            torch.xpu.set_device(device)
            return

        msg = f"Device should be xpu, got {device} instead"
        raise RuntimeError(msg)

    @staticmethod
    def parse_devices(devices: str | list | torch.device) -> list:
        """Normalize the ``devices`` argument into a list.

        Args:
            devices: Either a list (returned as-is, same object) or a single
                value, which is wrapped into a one-element list.

        Returns:
            A list of device specifiers.
        """
        return devices if isinstance(devices, list) else [devices]

    @staticmethod
    def get_parallel_devices(devices: list) -> list[torch.device]:
        """Build one ``torch.device("xpu", idx)`` per entry of ``devices``.

        Args:
            devices: Device indices, in the order they should be used.

        Returns:
            The corresponding XPU ``torch.device`` objects, in the same order.
        """
        parallel_devices = []
        for index in devices:
            parallel_devices.append(torch.device("xpu", index))
        return parallel_devices

    @staticmethod
    def auto_device_count() -> int:
        """Return the number of XPU devices reported by ``torch.xpu.device_count()``."""
        return torch.xpu.device_count()

    @staticmethod
    def is_available() -> bool:
        """Report whether an XPU is available, as determined by ``is_xpu_available``."""
        return is_xpu_available()

    def get_device_stats(self, device: str | torch.device) -> dict[str, Any]:
        """Return statistics for ``device``.

        No statistics are collected by this implementation, so the result is
        always an empty dictionary and ``device`` is not inspected.
        """
        return {}

    def teardown(self) -> None:
        """Clean up accelerator state; this implementation has nothing to release."""
# Register the class with Lightning's accelerator registry under its
# ``accelerator_name`` ("xpu"). This runs as a side effect of importing the module.
AcceleratorRegistry.register(
    XPUAccelerator.accelerator_name,
    XPUAccelerator,
    description="Accelerator supports XPU devices",
)