""" TensorMONK :: layers :: Activations """
__all__ = ["Activations"]
import torch
import torch.nn as nn
import torch.nn.functional as F


def maxout(tensor: torch.Tensor) -> torch.Tensor:
    r"""Splits the channel axis into two halves and returns their
    element-wise maximum, halving the channel count."""
    if tensor.size(1) % 2 != 0:
        raise ValueError("MaxOut: tensor.size(1) must be divisible by 2"
                         ": {}".format(tensor.size(1)))
    return torch.max(*tensor.split(tensor.size(1) // 2, 1))
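# A usage sketch (shapes only; tensor values are random):
#   x = torch.randn(2, 8, 4, 4)
#   maxout(x).shape  # torch.Size([2, 4, 4, 4]) -- channels halved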


class Activations(nn.Module):
r"""Activation functions. Additional activation functions (other than those
available in pytorch) are
:obj:`"hsigm"` & :obj:`"hswish"` (`"Searching for MobileNetV3"
<https://arxiv.org/pdf/1905.02244>`_),
:obj:`"maxo"` (`"Maxout Networks" <https://arxiv.org/pdf/1302.4389>`_),
:obj:`"mish"` (`"Mish: A Self Regularized Non-Monotonic Neural Activation
Function" <https://arxiv.org/pdf/1908.08681v1>`_),
:obj:`"squash"` (`"Dynamic Routing Between Capsules"
<https://arxiv.org/abs/1710.09829>`_) and
:obj:`"swish"` (`"SWISH: A Self-Gated Activation Function"
<https://arxiv.org/pdf/1710.05941v1>`_).
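
    For reference, these compute the following (restated from the
    implementations below; :obj:`"maxo"` simply takes the element-wise max
    over the two channel halves):

    .. math::

        \text{hsigm}(x) &= \min(\max(x + 3, 0), 6) / 6 \\
        \text{hswish}(x) &= x \cdot \text{hsigm}(x) \\
        \text{mish}(x) &= x \cdot \tanh(\ln(1 + e^x)) \\
        \text{squash}(v) &= \frac{\lVert v \rVert^2}{1 + \lVert v \rVert^2}
        \cdot \frac{v}{\lVert v \rVert} \\
        \text{swish}(x) &= x \cdot \sigma(x)
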
Args:
tensor_size (tuple, required): Input tensor shape in BCHW
(None/any integer >0, channels, height, width).
        activation (str, optional): One of
            :obj:`"elu"`, :obj:`"gelu"`, :obj:`"hsigm"`, :obj:`"hswish"`,
            :obj:`"lklu"`, :obj:`"maxo"`, :obj:`"mish"`, :obj:`"prelu"`,
            :obj:`"relu"`, :obj:`"relu6"`, :obj:`"rmxo"`, :obj:`"selu"`,
            :obj:`"sigm"`, :obj:`"squash"`, :obj:`"swish"`, :obj:`"tanh"`.
            (default: :obj:`"relu"`)
        elu_alpha (float, optional): alpha for :obj:`"elu"`
            (default: :obj:`1.0`)
        lklu_negslope (float, optional): negative slope for :obj:`"lklu"`
            (default: :obj:`0.01`)
.. code-block:: python
import torch
import tensormonk
print(tensormonk.activations.Activations.METHODS)
tensor_size = (None, 16, 4, 4)
activation = "maxo"
maxout = tensormonk.activations.Activations(tensor_size, activation)
maxout(torch.randn(1, *tensor_size[1:]))
tensor_size = (None, 16, 4)
activation = "squash"
squash = tensormonk.activations.Activations(tensor_size, activation)
squash(torch.randn(1, *tensor_size[1:]))
tensor_size = (None, 16)
activation = "swish"
swish = tensormonk.activations.Activations(tensor_size, activation)
swish(torch.randn(1, *tensor_size[1:]))
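
    Note: with :obj:`"maxo"` and :obj:`"rmxo"` the output has half the input
    channels, which is reflected in ``self.tensor_size``.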
"""
METHODS = ["elu", "gelu", "hsigm", "hswish", "lklu", "maxo", "mish",
"prelu", "relu", "relu6", "rmxo",
"selu", "sigm", "squash", "swish", "tanh"]
def __init__(self, tensor_size: tuple, activation: str = "relu", **kwargs):
super(Activations, self).__init__()
        if activation is not None:
            activation = activation.lower()
        self.t_size = tensor_size
        self.activation = activation
        self.function = None
        if activation is not None:
            if activation not in self.METHODS:
                raise ValueError("activation: must be one of " +
                                 "/".join(self.METHODS) +
                                 ", got: {}".format(activation))
            self.function = getattr(self, "_" + activation)
if activation == "prelu":
self.weight = nn.Parameter(torch.ones(1) * 0.1)
if activation == "lklu":
self.negslope = kwargs["lklu_negslope"] if "lklu_negslope" in \
kwargs.keys() else 0.01
if activation == "elu":
self.alpha = kwargs["elu_alpha"] if "elu_alpha" in \
kwargs.keys() else 1.0
self.tensor_size = tensor_size
if activation in ("maxo", "rmxo"):
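            # maxout keeps the element-wise max of the two channel halves,
            # so the output carries half the input channels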
t_size = list(tensor_size)
t_size[1] = t_size[1] // 2
self.tensor_size = tuple(t_size)
def forward(self, tensor: torch.Tensor) -> torch.Tensor:
if self.function is None:
return tensor
return self.function(tensor)
def _relu(self, tensor: torch.Tensor):
return F.relu(tensor)
def _relu6(self, tensor: torch.Tensor):
return F.relu6(tensor)
def _lklu(self, tensor: torch.Tensor):
return F.leaky_relu(tensor, self.negslope)
def _elu(self, tensor: torch.Tensor):
return F.elu(tensor, self.alpha)
def _gelu(self, tensor: torch.Tensor):
return F.gelu(tensor)
def _prelu(self, tensor: torch.Tensor):
return F.prelu(tensor, self.weight)
def _selu(self, tensor: torch.Tensor):
return F.selu(tensor)
def _tanh(self, tensor: torch.Tensor):
return torch.tanh(tensor)
def _sigm(self, tensor: torch.Tensor):
return torch.sigmoid(tensor)
def _maxo(self, tensor: torch.Tensor):
        if tensor.size(1) % 2 != 0:
            raise ValueError("MaxOut: tensor.size(1) must be divisible by 2"
                             ": {}".format(tensor.size(1)))
return torch.max(*tensor.split(tensor.size(1)//2, 1))
def _rmxo(self, tensor: torch.Tensor):
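        # maxout over ReLU-activated features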
return self._maxo(F.relu(tensor))
def _swish(self, tensor: torch.Tensor):
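        # swish(x) = x * sigmoid(x); equivalent to F.silu in recent PyTorch
        # releases (noted for reference, not relied upon here)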
return tensor * torch.sigmoid(tensor)
def _mish(self, tensor: torch.Tensor):
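        # mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x)))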
return tensor * F.softplus(tensor).tanh()
    def _squash(self, tensor: torch.Tensor):
        if tensor.dim() != 3:
            raise ValueError("Squash requires 3D tensors: {}".format(
                tensor.dim()))
        # squash(v) = (|v|^2 / (1 + |v|^2)) * (v / |v|), applied to each
        # capsule vector along the last dimension
        sum_squares = (tensor ** 2).sum(2, True)
        return (sum_squares/(1+sum_squares)) * tensor / sum_squares.pow(0.5)
def _hsigm(self, tensor: torch.Tensor):
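        # hard sigmoid: a piecewise-linear approximation of sigmoid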
return F.relu6(tensor + 3) / 6
def _hswish(self, tensor: torch.Tensor):
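        # hard swish: x * hard_sigmoid(x), a cheap approximation of swish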
return self._hsigm(tensor) * tensor
    def __repr__(self):
        return str(self.activation)
@staticmethod
def available() -> list:
return Activations.METHODS
def flops(self) -> int:
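        # rough elementwise-op count per forward pass, based on tensor_size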
import numpy as np
flops = 0
numel = np.prod(self.t_size[1:])
if self.activation == "elu":
# max(0, x) + min(0, alpha*(exp(x)-1))
flops = numel * 5
elif self.activation in ("lklu", "prelu", "sigm"):
flops = numel * 3
elif self.activation == "maxo":
# torch.max(*x.split(x.size(1)//2, 1))
flops = numel / 2
elif self.activation == "mish":
# x * tanh(ln(1 + e^x))
flops = numel * 5
elif self.activation == "relu":
# max(0, x)
flops = numel
elif self.activation == "relu6":
# min(6, max(0, x))
flops = numel * 2
elif self.activation == "rmxo":
# maxo(relu(x))
flops = int(numel * 1.5)
elif self.activation == "squash":
# sum_squares = (tensor**2).sum(2, True)
# (sum_squares/(1+sum_squares)) * tensor / sum_squares.pow(0.5)
flops = numel * 4 + self.t_size[1] * 2
elif self.activation == "swish":
# x * sigm(x)
flops = numel * 4
elif self.activation == "tanh":
# (exp(x) - exp(-x)) / (exp(x) + exp(-x))
flops = numel * 9
elif self.activation == "hsigm":
# min(6, max(0, x + 3)) / 6
flops = numel * 4
elif self.activation == "hswish":
# x * min(6, max(0, x + 3)) / 6
flops = numel * 8
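        # gelu and selu are not estimated here and fall through to 0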
return flops
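

# A minimal self-check sketch (hypothetical usage, mirroring the docstring
# example; run this file directly to execute it):
if __name__ == "__main__":
    act = Activations((None, 16, 4, 4), "hswish")
    out = act(torch.randn(1, 16, 4, 4))
    print(act, out.shape, act.flops())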