# Source code for otx.algorithms.common.utils.dist_utils
"""Module for defining distributed training utils."""
# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
import os
from pathlib import Path
from typing import Union
import torch.distributed as dist
def get_dist_info():
    """Retrieve information about the current distributed training environment.

    Returns:
        tuple: ``(rank, world_size, distributed)``. When the distributed package
        is available and a default process group has been initialized, this is
        ``(dist.get_rank(), dist.get_world_size(), True)``. In every other case
        the single-process fallback ``(0, 1, False)`` is returned, so callers
        can always unpack three values.
    """
    # Checking is_initialized() up front means the "no process group" case no
    # longer depends on which exception type get_rank() happens to raise.
    if dist.is_available() and dist.is_initialized():
        # data distributed parallel
        try:
            return dist.get_rank(), dist.get_world_size(), True
        except RuntimeError:
            # Best-effort, as before: a failing query is reported as "not distributed".
            return 0, 1, False
    # Distributed package unavailable or process group not initialized.
    return 0, 1, False
[docs]
def append_dist_rank_suffix(file_name: Union[str, Path]) -> str:
    """Tag a file name with the local rank of the current process.

    Args:
        file_name: File name or path to decorate.

    Returns:
        ``file_name`` as a string. When the ``LOCAL_RANK`` environment variable
        is set, ``_proc<LOCAL_RANK>`` is inserted between the file stem and its
        final suffix; otherwise the name is returned unchanged.
    """
    local_rank = os.environ.get("LOCAL_RANK")
    if local_rank is None:
        # LOCAL_RANK is not set: leave the name untouched.
        return str(file_name)

    path = Path(file_name)
    rank_tag = f"_proc{local_rank}"
    ranked_name = path.stem + rank_tag + path.suffix
    return str(path.parent / ranked_name)