Source code for otx.algorithms.anomaly.adapters.anomalib.accelerators.xpu

"""Lightning accelerator for XPU device."""
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#

from typing import Any, Dict, Union

import torch
from pytorch_lightning.accelerators import AcceleratorRegistry
from pytorch_lightning.accelerators.accelerator import Accelerator

from otx.algorithms.common.utils.utils import is_xpu_available


class XPUAccelerator(Accelerator):
    """Support for an XPU, optimized for large-scale machine learning."""

    accelerator_name = "xpu"

    def setup_device(self, device: torch.device) -> None:
        """Sets up the specified device."""
        if device.type != "xpu":
            raise RuntimeError(f"Device should be xpu, got {device} instead")
        torch.xpu.set_device(device)

    @staticmethod
    def parse_devices(devices: Any) -> Any:
        """Parses devices for multi-device training."""
        if isinstance(devices, list):
            return devices
        return [devices]

    @staticmethod
    def get_parallel_devices(devices: Any) -> Any:
        """Generates a list of parallel devices."""
        return [torch.device("xpu", idx) for idx in devices]

    @staticmethod
    def auto_device_count() -> int:
        """Returns the number of available XPU devices."""
        return torch.xpu.device_count()

    @staticmethod
    def is_available() -> bool:
        """Checks if an XPU is available."""
        return is_xpu_available()

    def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
        """Returns XPU device stats."""
        return {}

    def teardown(self) -> None:
        """Cleans up XPU-related resources."""
        pass


AcceleratorRegistry.register(
    XPUAccelerator.accelerator_name, XPUAccelerator, description="Accelerator supports XPU devices"
)
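
# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal sketch, assuming a standard pytorch_lightning Trainer and an available
# XPU device: once registered above, the accelerator can be selected by its registry
# name. The Trainer arguments shown are assumptions, not the OTX training entry point.
#
# from pytorch_lightning import Trainer
#
# if XPUAccelerator.is_available():
#     # "xpu" resolves to XPUAccelerator through the AcceleratorRegistry entry above.
#     trainer = Trainer(accelerator=XPUAccelerator.accelerator_name, devices=1)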