Source code for otx.algo.modules.activation
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) OpenMMLab. All rights reserved.
"""Custom activation implementation copied from mmcv.cnn.bricks.swish.py."""
from __future__ import annotations
from functools import partial
from typing import Callable
import torch
from torch import Tensor, nn
class Swish(nn.Module):
"""Swish Module.
This module applies the swish function:
.. math::
Swish(x) = x * Sigmoid(x)
Returns:
Tensor: The output tensor.
"""
def forward(self, x: Tensor) -> Tensor:
"""Forward function.
Args:
x (Tensor): The input tensor.
Returns:
Tensor: The output tensor.
"""
return x * torch.sigmoid(x)
AVAILABLE_ACTIVATION_LIST: list[nn.Module] = [
nn.ReLU,
nn.LeakyReLU,
nn.PReLU,
nn.RReLU,
nn.ReLU6,
nn.ELU,
nn.Sigmoid,
nn.Tanh,
nn.SiLU,
nn.GELU,
Swish,
]
ACTIVATION_LIST_NOT_SUPPORTING_INPLACE: list[nn.Module] = [
nn.Tanh,
nn.PReLU,
nn.Sigmoid,
Swish,
nn.GELU,
]
def _get_act_type(activation: Callable[..., nn.Module]) -> type:
"""Get class type or name of given activation callable.
Args:
activation (Callable[..., nn.Module]): Activation layer module.
Returns:
(type): Class type of given activation callable.
"""
return activation.func if isinstance(activation, partial) else activation # type: ignore[return-value]
[docs]
def build_activation_layer(
activation: Callable[..., nn.Module] | nn.Module | None,
inplace: bool = True,
) -> nn.Module | None:
"""Build activation layer.
Args:
activation (Callable[..., nn.Module] | nn.Module | None): Activation layer module.
If None or pre-instanstiated module is given, return it as is.
If callable is given, create the layer.
inplace (bool): Whether to use inplace mode for activation.
Default: True.
Returns:
nn.Module: Created activation layer.
"""
if activation is None or isinstance(activation, nn.Module):
return activation
if (layer_type := _get_act_type(activation)) not in AVAILABLE_ACTIVATION_LIST:
msg = f"Unsupported activation: {layer_type.__name__}."
raise ValueError(msg)
layer = activation()
# update inplace
if layer.__class__ not in ACTIVATION_LIST_NOT_SUPPORTING_INPLACE:
layer.inplace = inplace
return layer