# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) OpenMMLab. All rights reserved.
"""This implementation replaces the functionality of mmcv.cnn.bricks.norm.build_norm_layer."""
from __future__ import annotations
import inspect
from functools import partial
from typing import Any, Callable
import torch
from torch import nn
from torch.nn import SyncBatchNorm
from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.modules.instancenorm import _InstanceNorm
class FrozenBatchNorm2d(nn.Module):
"""Copy and modified from https://github.com/facebookresearch/detr/blob/master/models/backbone.py.
BatchNorm2d where the batch statistics and the affine parameters are fixed.
Copy-paste from torchvision.misc.ops with added eps before rqsrt,
without which any other models than torchvision.models.resnet[18,34,50,101]
produce nans.
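Example:
    A minimal usage sketch (the feature count and input shape are illustrative)::

        >>> layer = FrozenBatchNorm2d(num_features=8)
        >>> out = layer(torch.randn(2, 8, 4, 4))
        >>> out.shape
        torch.Size([2, 8, 4, 4])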
"""
def __init__(self, num_features: int, eps: float = 1e-5) -> None:
"""Initialize FrozenBatchNorm2d.
Args:
num_features (int): Number of input features.
eps (float): Epsilon for batch norm. Defaults to 1e-5.
"""
super().__init__()
n = num_features
self.register_buffer("weight", torch.ones(n))
self.register_buffer("bias", torch.zeros(n))
self.register_buffer("running_mean", torch.zeros(n))
self.register_buffer("running_var", torch.ones(n))
self.eps = eps
self.num_features = n
def _load_from_state_dict(
self,
state_dict: dict[str, torch.Tensor],
prefix: str,
local_metadata: dict,
strict: bool,
missing_keys: list[str],
unexpected_keys: list[str],
error_msgs: list[str],
) -> None:
num_batches_tracked_key = prefix + "num_batches_tracked"
if num_batches_tracked_key in state_dict:
del state_dict[num_batches_tracked_key]
super()._load_from_state_dict(
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward."""
# move reshapes to the beginning
# to make it fuser-friendly
w = self.weight.reshape(1, -1, 1, 1)
b = self.bias.reshape(1, -1, 1, 1)
rv = self.running_var.reshape(1, -1, 1, 1)
rm = self.running_mean.reshape(1, -1, 1, 1)
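# equivalent to (x - rm) / sqrt(rv + eps) * w + b, folded into a single scale and bias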
scale = w * (rv + self.eps).rsqrt()
bias = b - rm * scale
return x * scale + bias
AVAILABLE_NORMALIZATION_LIST = [
nn.BatchNorm1d,
nn.BatchNorm2d,
nn.BatchNorm3d,
SyncBatchNorm,
nn.GroupNorm,
nn.LayerNorm,
nn.InstanceNorm1d,
nn.InstanceNorm2d,
nn.InstanceNorm3d,
FrozenBatchNorm2d,
]
def infer_abbr(class_type: type) -> str:
"""Infer abbreviation from the class name.
When we build a norm layer with `build_norm_layer()`, we want to preserve
the norm type in variable names, e.g., self.bn1, self.gn. This method
infers the abbreviation for the given class type.
Rule 1: If the class has the property "_abbr_", return the property.
Rule 2: If the parent class is _BatchNorm, GroupNorm, LayerNorm or
InstanceNorm, the abbreviation of this layer will be "bn", "gn", "ln" and
"in" respectively.
Rule 3: If the class name contains "batch", "group", "layer" or "instance",
the abbreviation of this layer will be "bn", "gn", "ln" and "in"
respectively.
Rule 4: Otherwise, the abbreviation falls back to "norm".
Args:
class_type (type): The norm layer type.
Returns:
str: The inferred abbreviation.
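Example:
    Illustrative calls (classes chosen for demonstration)::

        >>> infer_abbr(nn.BatchNorm2d)
        'bn'
        >>> infer_abbr(nn.GroupNorm)
        'gn'
        >>> infer_abbr(FrozenBatchNorm2d)
        'bn'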
"""
if not inspect.isclass(class_type):
msg = f"class_type must be a type, but got {type(class_type)}"
raise TypeError(msg)
if hasattr(class_type, "_abbr_"):
return class_type._abbr_ # noqa: SLF001
if issubclass(class_type, _InstanceNorm): # IN is a subclass of BN
return "in"
if issubclass(class_type, _BatchNorm):
return "bn"
if issubclass(class_type, nn.GroupNorm):
return "gn"
if issubclass(class_type, nn.LayerNorm):
return "ln"
class_name = class_type.__name__.lower()
if "batch" in class_name:
return "bn"
if "group" in class_name:
return "gn"
if "layer" in class_name:
return "ln"
if "instance" in class_name:
return "in"
return "norm"
def _get_norm_type(normalization: Callable[..., nn.Module]) -> type:
"""Get class type or name of given normalization callable.
Args:
normalization (Callable[..., nn.Module]): Normalization layer module.
Returns:
(type): Class type of given normalization callable.
"""
return normalization.func if isinstance(normalization, partial) else normalization # type: ignore[return-value]
def build_norm_layer(
normalization: Callable[..., nn.Module] | tuple[str, nn.Module] | nn.Module | None,
num_features: int,
postfix: int | str = "",
layer_name: str | None = None,
requires_grad: bool = True,
eps: float = 1e-5,
**kwargs,
) -> tuple[str, nn.Module]:
"""Build normalization layer.
Args:
normalization (Callable[..., nn.Module] | tuple[str, nn.Module] | nn.Module | None):
Normalization layer module. If a tuple or None is given, it is returned as is.
If an nn.Module instance is given, it is returned with an empty name string. If a callable is given, the layer is created from it.
num_features (int): Number of input channels.
postfix (int | str): The postfix appended to the norm abbreviation
to form the layer name. Defaults to "".
layer_name (str | None): The name of the layer.
Defaults to None.
requires_grad (bool): Whether the layer parameters require gradients
(i.e., whether they are trainable). Defaults to True.
eps (float): A value added to the denominator for numerical stability.
Defaults to 1e-5.
Returns:
(tuple[str, nn.Module] | None): The first element is the layer name consisting
of abbreviation and postfix, e.g., bn1, gn. The second element is the
created norm layer.
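Examples:
    Minimal usage sketches (feature counts are illustrative)::

        >>> build_norm_layer(nn.BatchNorm2d, num_features=16, postfix=1)[0]
        'bn1'
        >>> name, layer = build_norm_layer(partial(nn.GroupNorm, num_groups=4), num_features=16)
        >>> name
        'gn'
        >>> # a partial of build_norm_layer itself is completed with the remaining arguments
        >>> name, layer = build_norm_layer(partial(build_norm_layer, nn.BatchNorm2d), num_features=8)
        >>> name
        'bn'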
"""
if normalization is None or isinstance(normalization, tuple):
# return tuple or None as is
return normalization # type: ignore[return-value]
if isinstance(normalization, nn.Module):
# return nn.Module with empty name string
return "", normalization
if not callable(normalization):
msg = f"normalization must be a callable, but got {type(normalization)}."
raise TypeError(msg)
if isinstance(normalization, partial) and normalization.func.__name__ == "build_norm_layer":
# `normalization` wraps build_norm_layer itself: fill in the arguments not already
# bound by the partial (e.g., num_features, eps) and delegate to it
signature = inspect.signature(normalization.func)
predefined_key_list: list[str] = ["kwargs"]
# find keys according to normalization.args except for normalization type (index=0)
predefined_key_list.extend(list(signature.parameters.keys())[: len(normalization.args)])
# find keys according to normalization.keywords
predefined_key_list.extend(list(normalization.keywords.keys()))
# set the remaining parameters previously undefined in normalization
_locals = locals()
fn_kwargs: dict[str, Any] = {
k: _locals.get(k, None) for k in signature.parameters if k not in predefined_key_list
}
# manually update kwargs
fn_kwargs.update(_locals.get("kwargs", {}))
return normalization(**fn_kwargs)
if (layer_type := _get_norm_type(normalization)) not in AVAILABLE_NORMALIZATION_LIST:
msg = f"Unsupported normalization: {layer_type.__name__}."
raise ValueError(msg)
# set norm name
abbr = layer_name or infer_abbr(layer_type)
name = abbr + str(postfix)
# set norm layer
if layer_type is not nn.GroupNorm:
layer = normalization(num_features, eps=eps, **kwargs)
if layer_type == SyncBatchNorm and hasattr(layer, "_specify_ddp_gpu_num"):
layer._specify_ddp_gpu_num(1) # noqa: SLF001
else:
layer = normalization(num_channels=num_features, eps=eps, **kwargs)
for param in layer.parameters():
param.requires_grad = requires_grad
return name, layer
def is_norm(layer: nn.Module) -> bool:
"""Check if a layer is a normalization layer.
Args:
layer (nn.Module): The layer to be checked.
Returns:
bool: Whether the layer is a norm layer.
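Example:
    Illustrative checks (layer arguments chosen for demonstration)::

        >>> is_norm(nn.BatchNorm2d(8))
        True
        >>> is_norm(nn.Conv2d(3, 8, kernel_size=3))
        False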
"""
all_norm_bases = (_BatchNorm, _InstanceNorm, nn.GroupNorm, nn.LayerNorm)
return isinstance(layer, all_norm_bases)