diff --git a/metnet/layers/MultiheadSelfAttention2D.py b/metnet/layers/MultiheadSelfAttention2D.py
index 6f3a100..66d9024 100644
--- a/metnet/layers/MultiheadSelfAttention2D.py
+++ b/metnet/layers/MultiheadSelfAttention2D.py
@@ -107,7 +107,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         attention_weights = Q @ K  # Attn shape [N, self.num_heads, H*W, H*W]
 
         if self.rel_attn_bias is not None:
-            self.rel_attn_bias(attention_weights)
+            attention_weights = attention_weights + self.rel_attn_bias()
 
         attention_weights = attention_weights.softmax(
             dim=-1
diff --git a/metnet/layers/PartitionAttention.py b/metnet/layers/PartitionAttention.py
index f7c1727..1e707c3 100644
--- a/metnet/layers/PartitionAttention.py
+++ b/metnet/layers/PartitionAttention.py
@@ -74,7 +74,7 @@ def __init__(
         # Save parameters
         self.attn_grid_window_size: Tuple[int, int] = attn_grid_window_size
-        RelativePositionBias(attn_size=attn_grid_window_size, num_heads=num_heads)
+        rel_attn_bias = RelativePositionBias(attn_size=attn_grid_window_size, num_heads=num_heads)
         # Init layers
         self.attention = MultiheadSelfAttention2D(
             in_channels=in_channels,
@@ -83,7 +83,7 @@
             attn_drop=attn_drop,
             proj_drop=proj_drop,
             use_normalised_qk=use_normalised_qk,
-            rel_attn_bias=None,
+            rel_attn_bias=rel_attn_bias,
         )
         self.pre_norm_layer = pre_norm_layer(attn_grid_window_size)  # Norm along windows
         self.post_norm_layer = post_norm_layer(attn_grid_window_size)
diff --git a/metnet/layers/RelativePositionBias.py b/metnet/layers/RelativePositionBias.py
new file mode 100644
index 0000000..2ad537b
--- /dev/null
+++ b/metnet/layers/RelativePositionBias.py
@@ -0,0 +1,99 @@
+"""
+Implementation of Relative Position Bias
+"""
+from typing import Tuple
+
+import torch
+import torch.nn as nn
+
+
+class RelativePositionBias(nn.Module):
+    """
+    Relative Position Bias
+
+    Inspired by timm's maxxvit implementation
+    """
+
+    def __init__(self, attn_size: Tuple[int, int], num_heads: int) -> None:
+        """
+        Constructor Method
+
+        Parameters
+        ----------
+        attn_size : Tuple[int, int]
+            Size of the attention window
+        num_heads : int
+            Number of heads in the multiheaded attention
+        """
+        super().__init__()
+        self.attn_size = attn_size
+        self.attn_area = self.attn_size[0] * self.attn_size[1]
+        self.num_heads = num_heads
+
+        # Define a parameter table of relative position bias, shape: 2*Wh-1 * 2*Ww-1, nH
+        self.relative_position_bias_table = nn.Parameter(
+            torch.zeros(
+                (2 * self.attn_size[0] - 1) * (2 * self.attn_size[1] - 1),
+                self.num_heads,
+            )
+        )
+
+        # Get pair-wise relative position index for each token inside the window
+        self.register_buffer(
+            "relative_position_index",
+            self.get_relative_position_index(self.attn_size[0], self.attn_size[1]),
+        )
+
+    def get_relative_position_index(self, win_h: int, win_w: int) -> torch.Tensor:
+        """
+        Function to generate pair-wise relative position index for each token inside the window.
+
+        Taken from timm's Swin V1 implementation.
+
+        Parameters
+        ----------
+        win_h : int
+            Window/Grid height.
+        win_w : int
+            Window/Grid width.
+
+        Returns
+        -------
+        torch.Tensor
+            relative_coords (torch.Tensor): Pair-wise relative position indexes
+            [height * width, height * width].
+        """
+        coords = torch.stack(torch.meshgrid([torch.arange(win_h), torch.arange(win_w)]))
+        coords_flatten = torch.flatten(coords, 1)
+        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
+        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
+        relative_coords[:, :, 0] += win_h - 1
+        relative_coords[:, :, 1] += win_w - 1
+        relative_coords[:, :, 0] *= 2 * win_w - 1
+        return relative_coords.sum(-1)
+
+    def _get_relative_positional_bias(self) -> torch.Tensor:
+        """
+        Returns the relative positional bias.
+
+        Returns
+        -------
+        torch.Tensor
+            relative_position_bias (torch.Tensor): Relative positional bias.
+        """
+        relative_position_bias = self.relative_position_bias_table[
+            self.relative_position_index.view(-1)
+        ].view(self.attn_area, self.attn_area, -1)
+        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
+        return relative_position_bias.unsqueeze(0)
+
+    def forward(self) -> torch.Tensor:
+        """
+        Forward Method
+
+        Returns
+        -------
+        torch.Tensor
+            Pairwise relative position bias
+        """
+        return self._get_relative_positional_bias()