# Source code for otx.algo.callbacks.gpu_mem_monitor

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
"""Monitor GPU memory hook."""
from __future__ import annotations

from typing import TYPE_CHECKING, Any

from lightning.pytorch.callbacks.callback import Callback

if TYPE_CHECKING:
    from lightning import LightningModule, Trainer


class GPUMemMonitor(Callback):
    """Monitor GPU memory hook.

    Logs the accelerator's current memory usage (in GiB) as the ``gpu_mem``
    metric at the start of every train and validation batch.
    """

    def _get_and_log_device_stats(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
    ) -> None:
        """Get and log current GPU memory usage.

        Args:
            trainer (Trainer): pl trainer.
            pl_module (LightningModule): pl module.
        """
        device = trainer.strategy.root_device
        # CPU and XPU devices do not expose the CUDA-style byte counters
        # queried below, so there is nothing to log for them.
        if device.type in ["cpu", "xpu"]:
            return

        device_stats = trainer.accelerator.get_device_stats(device)
        allocated = device_stats["allocated_bytes.all.current"]
        reserved = device_stats["reserved_bytes.all.current"]

        # NOTE(review): in CUDA's caching allocator, "reserved" already
        # includes "allocated", so summing the two may double-count. Kept
        # as-is to preserve the existing reported metric — confirm intent.
        used_memory = (allocated + reserved) / 1024**3  # convert to GiB
        used_memory = round(used_memory, 2)
        pl_module.log(
            name="gpu_mem",
            value=used_memory,
            prog_bar=True,
            on_step=True,
            on_epoch=False,
        )

    def on_train_batch_start(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        batch: Any,  # noqa: ANN401
        batch_idx: int,
    ) -> None:
        """Log GPU memory usage at the start of every train batch.

        Args:
            trainer (Trainer): pl trainer.
            pl_module (LightningModule): pl module.
            batch (Any): current batch.
            batch_idx (int): current batch index.
        """
        self._get_and_log_device_stats(
            trainer,
            pl_module,
        )

    def on_validation_batch_start(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        batch: Any,  # noqa: ANN401
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Log GPU memory usage at the start of every validation batch.

        Args:
            trainer (Trainer): pl trainer.
            pl_module (LightningModule): pl module.
            batch (Any): current batch.
            batch_idx (int): current batch index.
            dataloader_idx (int, optional): dataloader index. Defaults to 0.
        """
        self._get_and_log_device_stats(
            trainer,
            pl_module,
        )