# Source code for otx.algorithms.common.adapters.torch.utils.utils

"""Collections of util functions related to torch."""

# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import torch
from torch.nn import Module

try:
    from timm.models.layers import convert_sync_batchnorm as timm_cvt_sycnbn
except ImportError:
    timm_cvt_sycnbn = None


def model_from_timm(model: Module) -> bool:
    """Check a model comes from timm module.

    The model counts as a timm model when the model itself, or any module
    nested inside it, is defined in a Python module whose dotted path has a
    ``timm`` component (e.g. ``timm.models.resnet``).

    Args:
        model (Module): model to check it comes from timm module.

    Returns:
        bool : whether model comes from timm or not.
    """
    # ``Module.modules()`` already walks the whole tree (the model itself
    # first), so a single linear pass suffices. Recursing into every yielded
    # sub-module would revisit each descendant once per ancestor path, which
    # grows exponentially with nesting depth.
    return any(
        "timm" in sub_module.__module__.split(".")
        for sub_module in model.modules()
    )
def convert_sync_batchnorm(model: Module) -> Module:
    """Convert BatchNorm layers to SyncBatchNorm layers.

    When timm is installed and ``model`` contains timm modules, timm's own
    converter is used; otherwise ``torch.nn.SyncBatchNorm.convert_sync_batchnorm``
    is used.

    Both converters swap child layers in place but must build a new module when
    ``model`` itself is a BatchNorm layer. The converted module is therefore
    returned, and callers should use it. Ignoring it still works for models
    whose BatchNorm layers are all children.

    Args:
        model (Module): model containing batchnorm layers.

    Returns:
        Module: the converted model (``model`` itself unless ``model`` is a
        BatchNorm layer).
    """
    if timm_cvt_sycnbn is not None and model_from_timm(model):
        return timm_cvt_sycnbn(model)
    return torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
def sync_batchnorm_2_batchnorm(module, dim=2):
    """Convert SyncBatchNorm layers in a model back to regular BatchNorm layers.

    Every ``torch.nn.SyncBatchNorm`` found in ``module`` (including ``module``
    itself) is replaced by a ``BatchNorm{dim}d`` layer. The new layer keeps
    the same hyper-parameters, affine parameters, running statistics,
    train/eval mode and ``qconfig``. Other modules are kept, and their
    children are replaced in place.

    Args:
        module (torch.nn.Module): model or layer to convert.
        dim (int): selects the target class: 1 -> ``BatchNorm1d``,
            2 -> ``BatchNorm2d``, 3 -> ``BatchNorm3d``.

    Returns:
        torch.nn.Module: the converted module. This is ``module`` itself unless
        ``module`` is a SyncBatchNorm layer, in which case a new layer is returned.

    Raises:
        NotImplementedError: if ``dim`` is not 1, 2 or 3.
    """
    if dim == 1:
        bn = torch.nn.BatchNorm1d
    elif dim == 2:
        bn = torch.nn.BatchNorm2d
    elif dim == 3:
        bn = torch.nn.BatchNorm3d
    else:
        raise NotImplementedError(f"dim must be 1, 2 or 3, got {dim}")

    module_output = module
    if isinstance(module, torch.nn.SyncBatchNorm):
        module_output = bn(
            module.num_features,
            module.eps,
            module.momentum,
            module.affine,
            module.track_running_stats,
        )
        if module.affine:
            # Copy (not share) the affine values into the new layer's own
            # Parameters, keeping the trainable/frozen flags.
            module_output.weight.data = module.weight.data.clone().detach()
            module_output.bias.data = module.bias.data.clone().detach()
            module_output.weight.requires_grad = module.weight.requires_grad
            module_output.bias.requires_grad = module.bias.requires_grad
        # Running statistics are shared with the original layer, not copied.
        module_output.running_mean = module.running_mean
        module_output.running_var = module.running_var
        module_output.num_batches_tracked = module.num_batches_tracked
        # A freshly constructed layer starts in training mode; preserve the
        # original's mode so an eval-mode model stays in eval mode.
        module_output.training = module.training
        if hasattr(module, "qconfig"):
            module_output.qconfig = module.qconfig
    for name, child in module.named_children():
        module_output.add_module(name, sync_batchnorm_2_batchnorm(child, dim))
    return module_output