Source code for tensormonk.layers.attention

""" TensorMONK :: layers :: attention's """

__all__ = ["SelfAttention", "LocalAttention",
           "Attention", "ResidualAttention"]

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Union
from .convolution import Convolution
from .utils import compute_flops


class SelfAttention(nn.Module):
    r"""Self-Attention (`"Self-Attention Generative Adversarial Networks"
    <https://arxiv.org/pdf/1805.08318.pdf>`_).

    Args:
        tensor_size (tuple, required): Input tensor shape in BCHW
            (None/any integer >0, channels, height, width).
        shrink (int, optional): Used to compute output channels of key and
            query, i.e, int(tensor_size[1] / shrink)
            (default = :obj:`8`).
        scale_factor (float, optional): Scale at which attention is computed
            (use scale_factor < 1 for speed). When scale_factor != 1, the
            input is scaled using nearest neighbor interpolation
            (default = :obj:`1`).
        return_attention (bool, optional): When True, returns a tuple of
            (output, attention) (default = :obj:`False`).

    :rtype: :class:`torch.Tensor`
    """

    def __init__(self,
                 tensor_size: tuple,
                 shrink: int = 8,
                 scale_factor: float = 1.,
                 return_attention: bool = False,
                 **kwargs):
        super(SelfAttention, self).__init__()

        if not isinstance(tensor_size, (list, tuple)):
            raise TypeError("SelfAttention: tensor_size must be tuple/list: "
                            "{}".format(type(tensor_size).__name__))
        tensor_size = tuple(tensor_size)
        if not len(tensor_size) == 4:
            raise ValueError("SelfAttention: tensor_size must be of length 4"
                             ": {}".format(len(tensor_size)))
        if not isinstance(shrink, int):
            raise TypeError("SelfAttention: shrink must be int: "
                            "{}".format(type(shrink).__name__))
        if not (tensor_size[1] >= shrink >= 1):
            raise ValueError("SelfAttention: shrink must satisfy "
                             "tensor_size[1] >= shrink >= 1: "
                             "{}".format(shrink))
        self.shrink = shrink
        if not isinstance(scale_factor, float):
            raise TypeError("SelfAttention: scale_factor must be float: "
                            "{}".format(type(scale_factor).__name__))
        self.scale_factor = scale_factor
        if not isinstance(return_attention, bool):
            raise TypeError("SelfAttention: return_attention must be bool: "
                            "{}".format(type(return_attention).__name__))
        self.return_attention = return_attention

        self.oc = int(tensor_size[1] / shrink)
        self.key = Convolution(tensor_size, 1, self.oc, 1, True, None)
        self.query = Convolution(tensor_size, 1, self.oc, 1, True, None)
        self.value = Convolution(tensor_size, 1, tensor_size[1], 1, True,
                                 None)
        self.gamma = nn.Parameter(torch.zeros(1))
        self.tensor_size = tensor_size

    def forward(self, tensor: torch.Tensor):
        if self.scale_factor != 1:
            # compute attention at a lower resolution for speed
            _tensor = tensor.clone()
            tensor = F.interpolate(tensor, scale_factor=self.scale_factor)

        n, c, h, w = tensor.shape
        key = self.key(tensor).view(n, -1, h*w)
        query = self.query(tensor).view(n, -1, h*w)
        value = self.value(tensor).view(n, -1, h*w)

        attention = F.softmax(torch.bmm(query.permute(0, 2, 1), key), dim=2)
        o = torch.bmm(value, attention.permute(0, 2, 1)).view(n, c, h, w)

        if self.scale_factor != 1:
            # upsample back to the original resolution
            o = F.interpolate(o, size=_tensor.shape[2:])
            tensor = _tensor

        if self.return_attention:
            return tensor + o * self.gamma, attention
        return tensor + o * self.gamma

    def flops(self):
        flops = 0
        c, h, w = self.tensor_size[1:]
        if self.scale_factor != 1:  # assuming nearest
            nh, nw = int(h*self.scale_factor), int(w*self.scale_factor)
            flops += (c*h*w + c*nh*nw) * 2
        # attention - bmm
        flops += ((2 * self.oc * self.oc) - 1) * ((h * w)**2)
        # attention - softmax
        flops += (h * w) * (h * w * 3)
        # o - bmm
        flops += c * ((2 * h * w) - 1) * h * w
        # tensor + o*gamma
        flops += c * h * w * 2
        return compute_flops(self) + flops
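
# Usage sketch for SelfAttention (not part of the original file; the input
# size, shrink and shapes below are illustrative assumptions):
# x = torch.rand(2, 4, 32, 32)
# test = SelfAttention((1, 4, 32, 32), shrink=2, return_attention=True)
# o, attention = test(x)
# o.shape          # torch.Size([2, 4, 32, 32]) -- same shape as the input
# attention.shape  # torch.Size([2, 1024, 1024]) -- (h*w) x (h*w) map
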
class LocalAttention(nn.Module):
    r"""LocalAttention (`"Stand-Alone Self-Attention in Vision Models"
    <https://arxiv.org/pdf/1906.05909.pdf>`_).

    Args:
        tensor_size (tuple, required): Input tensor shape in BCHW
            (None/any integer >0, channels, height, width).
        filter_size (int/tuple, required): size of kernel, integer or
            list/tuple of length 2.
        out_channels (int, required): output tensor.size(1)
        strides (int/tuple, optional): convolution stride
            (default = :obj:`1`).
        groups (int, optional): enables grouped convolution
            (default = :obj:`4`).
        bias (bool, optional): When True, key, query & value 1x1 convolutions
            have bias (default = :obj:`False`).
        replicate_paper (bool, optional): When False, the relative attention
            logic differs from that of the paper (default = :obj:`True`).
        normalize_offset (bool, optional): When True (and
            replicate_paper = :obj:`False`), normalizes the row and column
            offsets (default = :obj:`False`).

    :rtype: :class:`torch.Tensor`
    """

    def __init__(self,
                 tensor_size: tuple,
                 filter_size: Union[int, tuple],
                 out_channels: int,
                 strides: int = 1,
                 groups: int = 4,
                 bias: bool = False,
                 replicate_paper: bool = True,
                 normalize_offset: bool = False,
                 **kwargs):
        super(LocalAttention, self).__init__()

        if not isinstance(tensor_size, (list, tuple)):
            raise TypeError("LocalAttention: tensor_size must be tuple/list: "
                            "{}".format(type(tensor_size).__name__))
        tensor_size = tuple(tensor_size)
        if not len(tensor_size) == 4:
            raise ValueError("LocalAttention: tensor_size must be of length 4"
                             ": {}".format(len(tensor_size)))
        if not isinstance(filter_size, (int, list, tuple)):
            raise TypeError("LocalAttention: filter_size must be int/tuple/"
                            "list: {}".format(type(filter_size).__name__))
        if isinstance(filter_size, int):
            filter_size = (filter_size, filter_size)
        filter_size = tuple(filter_size)
        if not len(filter_size) == 2:
            raise ValueError("LocalAttention: filter_size must be of length 2"
                             ": {}".format(len(filter_size)))
        if not isinstance(out_channels, int):
            raise TypeError("LocalAttention: out_channels must be int: "
                            "{}".format(type(out_channels).__name__))
        if not out_channels >= 1:
            raise ValueError("LocalAttention: out_channels must be >= 1: "
                             "{}".format(out_channels))
        if not isinstance(strides, (int, list, tuple)):
            raise TypeError("LocalAttention: strides must be int/tuple/list: "
                            "{}".format(type(strides).__name__))
        if isinstance(strides, int):
            strides = (strides, strides)
        strides = tuple(strides)
        if not len(strides) == 2:
            raise ValueError("LocalAttention: strides must be of length 2: "
                             "{}".format(len(strides)))
        if not isinstance(groups, int):
            raise TypeError("LocalAttention: groups must be int: "
                            "{}".format(type(groups).__name__))
        if out_channels % groups or groups < 1:
            raise ValueError("LocalAttention: out_channels must be divisible "
                             "by groups and groups must be >= 1: "
                             "{}".format(groups))
        if not isinstance(bias, bool):
            raise TypeError("LocalAttention: bias must be bool: "
                            "{}".format(type(bias).__name__))
        if not isinstance(replicate_paper, bool):
            raise TypeError("LocalAttention: replicate_paper must be bool: "
                            "{}".format(type(replicate_paper).__name__))
        if not isinstance(normalize_offset, bool):
            raise TypeError("LocalAttention: normalize_offset must be bool: "
                            "{}".format(type(normalize_offset).__name__))

        self.fs = filter_size
        self.st = strides
        self.gs = groups
        self.replicate_paper = replicate_paper
        ic = tensor_size[1]

        # 1x1 convolutions for spatial-relative attention
        self.query = nn.Conv2d(ic, out_channels, 1, self.st, bias=bias,
                               groups=groups)
        self.key = nn.Conv2d(ic, out_channels, 1, bias=bias, groups=groups)
        self.value = nn.Conv2d(ic, out_channels, 1, bias=bias, groups=groups)
        torch.nn.init.kaiming_normal_(self.query.weight)
        torch.nn.init.kaiming_normal_(self.key.weight)
        torch.nn.init.kaiming_normal_(self.value.weight)

        fh, fw = self.fs
        self.pad = (fw // 2 - int(fw % 2 == 0), fw // 2,
                    fh // 2 - int(fh % 2 == 0), fh // 2)

        # relative attention
        offset = torch.arange(fh).view(fh, 1).repeat(1, fw) - self.pad[2]
        self.register_buffer("row_offset", offset.view(-1).float())
        offset = torch.arange(fw).view(1, fw).repeat(fh, 1) - self.pad[0]
        self.register_buffer("col_offset", offset.view(-1).float())
        if normalize_offset and not replicate_paper:
            self.row_offset.data.div_(self.row_offset.abs().max())
            self.col_offset.data.div_(self.col_offset.abs().max())
        if replicate_paper:
            # as per paper
            self.row_w = nn.Parameter(torch.rand(fh*fw, out_channels//2))
            self.col_w = nn.Parameter(
                torch.rand(fh*fw, out_channels - out_channels//2))
        else:
            # made more logical sense
            self.row_w = nn.Parameter(torch.rand(fh*fw, out_channels))
            self.col_w = nn.Parameter(torch.rand(fh*fw, out_channels))
        torch.nn.init.normal_(self.row_w, 0, 1)
        torch.nn.init.normal_(self.col_w, 0, 1)

        self.in_size = tensor_size
        self.tensor_size = (
            None, out_channels,
            (tensor_size[2] + self.pad[2] + self.pad[3]) / self.st[0],
            (tensor_size[3] + self.pad[0] + self.pad[1]) / self.st[1])

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        n, c, h, w = tensor.shape
        fh, fw = self.fs

        # key, query and value
        k, q, v = self.key(tensor), self.query(tensor), self.value(tensor)
        oc, nh, nw = q.shape[1:]
        q = F.unfold(q, 1, padding=0, stride=1)
        q = q.view(n, oc, 1, nh, nw).contiguous()
        k = F.unfold(F.pad(k, self.pad), self.fs, stride=self.st)
        k = k.view(n, oc, fh * fw, nh, nw).contiguous()
        v = F.unfold(F.pad(v, self.pad), self.fs, stride=self.st)
        v = v.view(n, oc, fh * fw, nh, nw).contiguous()

        # encoding offsets
        if self.replicate_paper:
            # as per paper
            r_ai_bi = torch.cat((self.row_offset @ self.row_w,
                                 self.col_offset @ self.col_w))
        else:
            # made more logical sense
            r_ai_bi = ((self.row_offset @ self.row_w) +
                       (self.col_offset @ self.col_w))
        r_ai_bi = r_ai_bi.view(1, oc, 1, 1, 1).contiguous()

        # equation 3 - spatial-relative attention
        attention = (F.softmax(q * k + q * r_ai_bi, dim=2) * v).sum(dim=2)
        return attention

    def flops(self):
        flops = 0
        # key and value
        c, h, w = self.in_size[1:]
        nc, nh, nw = self.tensor_size[1:]
        flops += nc * c / self.gs * h * w
        flops += nc * c / self.gs * h * w
        # query
        flops += nc * c / self.gs * nh * nw
        # encoding
        flops += (self.row_offset.numel() * 2) * self.row_w.shape[-1]
        flops += (self.col_offset.numel() * 2) * self.col_w.shape[-1]
        # attention
        flops += (c * h * w * self.fs[0] * self.fs[1]) * 6
        return int(flops)
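
# Usage sketch for LocalAttention (not part of the original file; the input
# size, filter_size and out_channels below are illustrative assumptions):
# x = torch.rand(2, 16, 32, 32)
# test = LocalAttention((1, 16, 32, 32), filter_size=3, out_channels=32,
#                       strides=1, groups=4)
# test(x).shape    # torch.Size([2, 32, 32, 32]) -- 3x3 local attention
#                  # computed per pixel with relative position encodings
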
class Attention(nn.Module):
    r"""Attention (`"Attention is all you need."
    <https://arxiv.org/pdf/1706.03762.pdf>`_).

    Args:
        features (int, required): Number of input features.
        heads (int, required): Number of heads.
        bias (bool, optional): When True, the key, query & value linear
            projections have bias (default = :obj:`False`).
        p (float, optional): Dropout probability (default = :obj:`0.1`).
        size_hw (tuple, optional): Enables positional encoding when a tuple
            of (height, width) is provided (default = :obj:`None`).
        pre_embedding (int, optional): Adds zeros for additional embeddings
            (default = :obj:`0`).

    :rtype: :class:`torch.Tensor`
    """

    def __init__(self, features: int, heads: int, bias: bool = False,
                 p: float = 0.1, size_hw: tuple = None,
                 pre_embedding: int = 0):
        super(Attention, self).__init__()
        # params
        self.features: int = features
        self.heads: int = heads
        self.p: float = p
        self.scale: float = (features // heads) ** 0.5
        self.size_hw: tuple = size_hw

        # positional information
        if size_hw is not None:
            h, w = size_hw
            h_grid, w_grid = torch.meshgrid(torch.linspace(-1, 1, h),
                                            torch.linspace(-1, 1, w))
            h_grid, w_grid = h_grid.reshape(-1), w_grid.reshape(-1)
            if pre_embedding > 0:
                h_grid = torch.cat((torch.zeros(pre_embedding), h_grid))
                w_grid = torch.cat((torch.zeros(pre_embedding), w_grid))
            self.register_buffer(
                "positions", torch.stack((h_grid, w_grid), -1)[None])

        # attention layer
        self.kqv = nn.Linear(features + (0 if size_hw is None else 2),
                             features * 3, bias=bias)

    def forward(self, tensor: torch.Tensor):
        (b, t, nf), heads = tensor.shape, self.heads
        if hasattr(self, "positions"):
            tensor = torch.cat(
                (tensor, self.positions.expand(b, -1, -1)), -1)

        # key, query and value
        k, q, v = map(lambda x: x.reshape(b, t, heads, -1).transpose(1, 2),
                      self.kqv(tensor).split(nf, dim=2))
        attn = (q @ k.transpose(-2, -1)) / self.scale
        # dropout
        if self.training and self.p > 0:
            attn = F.dropout(attn, p=self.p)
        attn_probs = attn.softmax(-1)
        # context
        context = (attn_probs @ v).transpose(1, 2).reshape(b, t, nf)
        return context

    def __repr__(self):
        msg = f"Attention: features={self.features} and "
        msg += f"heads={self.heads}"
        if hasattr(self, "positions"):
            msg += f" (positional={self.size_hw[0]}x{self.size_hw[1]})"
        return msg


class ResidualAttention(nn.Module):
    r"""Residual Attention.

    Args:
        Refer :obj:`Attention`.

    :rtype: :class:`torch.Tensor`
    """

    def __init__(self, features: int, heads: int, bias: bool = False,
                 p: float = 0.1, size_hw: tuple = None,
                 pre_embedding: int = 0):
        super(ResidualAttention, self).__init__()
        self.p: float = p
        self.attention = Attention(
            features, heads, bias, p, size_hw, pre_embedding)
        self.projection = nn.Linear(features, features, bias=bias)
        self.normalize = nn.LayerNorm(features)

    def forward(self, tensor: torch.Tensor):
        context = self.attention(tensor)
        projected = self.projection(context)
        if self.training and self.p > 0:
            projected = F.dropout(projected, p=self.p)
        return self.normalize(tensor + projected)


# from tensormonk.layers import Convolution
# from tensormonk.layers.utils import compute_flops
# tensor_size = (3, 16, 60, 60)
# x = torch.rand(*tensor_size)
# test = SelfAttention(tensor_size, 8, 1.)
# test(x)[1].shape
# %timeit test(x)[1].shape
# test = SelfAttention(tensor_size, 8, 0.25)
# test(x)[1].shape
# %timeit test(x)[1].shape
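
# Usage sketch for Attention / ResidualAttention (not part of the original
# file; sequence length and feature sizes below are illustrative
# assumptions):
# x = torch.rand(4, 196, 64)               # (batch, tokens, features)
# test = Attention(features=64, heads=8)
# test(x).shape                            # torch.Size([4, 196, 64])
# test = ResidualAttention(features=64, heads=8, size_hw=(14, 14))
# test(x).shape                            # torch.Size([4, 196, 64])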