diff --git a/metnet/layers/RelativeSelfAttention.py b/metnet/layers/RelativeSelfAttention.py
new file mode 100644
index 0000000..284e1d7
--- /dev/null
+++ b/metnet/layers/RelativeSelfAttention.py
@@ -0,0 +1,105 @@
+from typing import Tuple
+
+import torch
+import torch.nn as nn
+from timm.models.layers import trunc_normal_
+
+
+class RelativeSelfAttention(nn.Module):
+    """Relative Self-Attention similar to Swin V1. Implementation taken from ChristophReich1996's MaxViT implementation."""
+
+    def __init__(
+        self,
+        in_channels: int,
+        num_heads: int = 32,
+        grid_window_size: Tuple[int, int] = (7, 7),
+        attn_drop: float = 0.0,
+        drop: float = 0.0,
+    ) -> None:
+        super().__init__()
+
+        self.in_channels: int = in_channels
+        self.num_heads: int = num_heads
+        self.grid_window_size: Tuple[int, int] = grid_window_size
+
+        self.scale: float = num_heads**-0.5
+        self.attn_area: int = grid_window_size[0] * grid_window_size[1]
+
+        # Init layers
+        self.qkv_mapping = nn.Linear(
+            in_features=in_channels, out_features=3 * in_channels, bias=True
+        )
+        self.attn_drop = nn.Dropout(p=attn_drop)
+        self.proj = nn.Linear(in_features=in_channels, out_features=in_channels, bias=True)
+        self.proj_drop = nn.Dropout(p=drop)
+        self.softmax = nn.Softmax(dim=-1)
+
+        # Define a parameter table of relative position bias, shape: 2*Wh-1 * 2*Ww-1, nH
+        self.relative_position_bias_table = nn.Parameter(
+            torch.zeros((2 * grid_window_size[0] - 1) * (2 * grid_window_size[1] - 1), num_heads)
+        )
+
+        # Get pair-wise relative position index for each token inside the window
+        self.register_buffer(
+            "relative_position_index",
+            self.get_relative_position_index(grid_window_size[0], grid_window_size[1]),
+        )
+        # Init relative positional bias
+        trunc_normal_(self.relative_position_bias_table, std=0.02)
+
+    def get_relative_position_index(self, win_h: int, win_w: int) -> torch.Tensor:
+        """Generate pair-wise relative position index for each token inside the window.
+        Taken from timm's Swin V1 implementation.
+
+        Args:
+            win_h (int): Window/Grid height.
+            win_w (int): Window/Grid width.
+
+        Returns:
+            relative_coords (torch.Tensor): Pair-wise relative position indexes [height * width, height * width].
+        """
+        coords = torch.stack(torch.meshgrid([torch.arange(win_h), torch.arange(win_w)]))
+        coords_flatten = torch.flatten(coords, 1)
+        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
+        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
+        relative_coords[:, :, 0] += win_h - 1
+        relative_coords[:, :, 1] += win_w - 1
+        relative_coords[:, :, 0] *= 2 * win_w - 1
+        return relative_coords.sum(-1)
+
+    def _get_relative_positional_bias(self) -> torch.Tensor:
+        """Return the relative positional bias.
+
+        Returns:
+            relative_position_bias (torch.Tensor): Relative positional bias.
+        """
+        relative_position_bias = self.relative_position_bias_table[
+            self.relative_position_index.view(-1)
+        ].view(self.attn_area, self.attn_area, -1)
+        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
+        return relative_position_bias.unsqueeze(0)
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        """Forward pass.
+
+        Args:
+            input (torch.Tensor): Input tensor of the shape [B_, N, C].
+
+        Returns:
+            output (torch.Tensor): Output tensor of the shape [B_, N, C].
+        """
+        # Get shape of input
+        B_, N, C = input.shape
+        # Perform query key value mapping
+        qkv = self.qkv_mapping(input).reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+        q, k, v = qkv.unbind(0)
+        # Scale query
+        q = q * self.scale
+        # Compute attention maps
+        attn = self.softmax(q @ k.transpose(-2, -1) + self._get_relative_positional_bias())
+        # Map value with attention maps
+        output = (attn @ v).transpose(1, 2).reshape(B_, N, -1)
+        # Perform final projection and dropout
+        output = self.proj(output)
+        output = self.proj_drop(output)
+        return output