Adds Relative position bias
Raahul-Singh committed Nov 2, 2023
1 parent f49113d commit 8f6aec4
Showing 3 changed files with 102 additions and 3 deletions.
2 changes: 1 addition & 1 deletion metnet/layers/MultiheadSelfAttention2D.py
@@ -107,7 +107,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         attention_weights = Q @ K  # Attn shape [N, self.num_heads, H*W, H*W]

         if self.rel_attn_bias is not None:
-            self.rel_attn_bias(attention_weights)
+            attention_weights = attention_weights + self.rel_attn_bias()

         attention_weights = attention_weights.softmax(
             dim=-1
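The removed line called the bias module on the attention logits and discarded the return value, so the bias never reached the softmax; the replacement adds the returned bias to the logits. The bias has shape [1, num_heads, H*W, H*W] and broadcasts over the batch dimension. A minimal shape check, a sketch with illustrative sizes rather than the repository's defaults:

import torch

N, num_heads, H, W = 2, 4, 8, 8  # hypothetical batch, heads, and window size
attention_weights = torch.randn(N, num_heads, H * W, H * W)
rel_bias = torch.zeros(1, num_heads, H * W, H * W)  # shape returned by rel_attn_bias()

# Broadcasting adds the same per-head bias to every item in the batch.
attention_weights = attention_weights + rel_bias
assert attention_weights.shape == (N, num_heads, H * W, H * W)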
4 changes: 2 additions & 2 deletions metnet/layers/PartitionAttention.py
@@ -74,7 +74,7 @@ def __init__(
         # Save parameters
         self.attn_grid_window_size: Tuple[int, int] = attn_grid_window_size

-        RelativePositionBias(attn_size=attn_grid_window_size, num_heads=num_heads)
+        rel_attn_bias = RelativePositionBias(attn_size=attn_grid_window_size, num_heads=num_heads)
         # Init layers
         self.attention = MultiheadSelfAttention2D(
             in_channels=in_channels,
@@ -83,7 +83,7 @@ def __init__(
             attn_drop=attn_drop,
             proj_drop=proj_drop,
             use_normalised_qk=use_normalised_qk,
-            rel_attn_bias=None,
+            rel_attn_bias=rel_attn_bias,
         )
         self.pre_norm_layer = pre_norm_layer(attn_grid_window_size)  # Norm along windows
         self.post_norm_layer = post_norm_layer(attn_grid_window_size)
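Previously the RelativePositionBias instance was constructed but never bound to a name, so the attention layer received rel_attn_bias=None and the bias was silently dropped. With the fix, the module is built once per PartitionAttention block and shared with its MultiheadSelfAttention2D. A hedged usage sketch of the bias module on its own, with illustrative values:

from metnet.layers.RelativePositionBias import RelativePositionBias

attn_grid_window_size, num_heads = (8, 8), 4  # illustrative window size and head count
rel_attn_bias = RelativePositionBias(attn_size=attn_grid_window_size, num_heads=num_heads)

# forward() takes no input: the bias depends only on the window geometry.
bias = rel_attn_bias()
assert bias.shape == (1, num_heads, 8 * 8, 8 * 8)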
99 changes: 99 additions & 0 deletions metnet/layers/RelativePositionBias.py
@@ -0,0 +1,99 @@
"""
Implementation of Relative Position Bias
"""
from typing import Tuple

import torch
import torch.nn as nn


class RelativePositionBias(nn.Module):
"""
Relative Postition Bias
Inspired by timm's maxxvit implementation
"""

def __init__(self, attn_size: Tuple[int, int], num_heads: int) -> None:
"""
Constructor Method
Parameters
----------
attn_size : Tuple[int, int]
Size of the attention window
num_heads : int
Number of heads in the multiheaded attention
"""
super().__init__()
self.attn_size = attn_size
self.attn_area = self.attn_size[0] * self.attn_size[1]
self.num_heads = num_heads

# Define a parameter table of relative position bias, shape: 2*Wh-1 * 2*Ww-1, nH
self.relative_position_bias_table = nn.Parameter(
torch.zeros(
(2 * self.attn_size[0] - 1) * (2 * self.attn_size[1] - 1),
self.num_heads,
)
)

# Get pair-wise relative position index for each token inside the window
self.register_buffer(
"relative_position_index",
self.get_relative_position_index(self.attn_size[0], self.attn_size[1]),
)

def get_relative_position_index(self, win_h: int, win_w: int) -> torch.Tensor:
"""
Function to generate pair-wise relative position index for each token inside the window.
Taken from Timms Swin V1 implementation.
Parameters
----------
win_h : int
Window/Grid height.
win_w : int
Window/Grid width.
Returns:
-------
torch.Tensor
relative_coords (torch.Tensor): Pair-wise relative position indexes
[height * width, height * width].
"""
coords = torch.stack(torch.meshgrid([torch.arange(win_h), torch.arange(win_w)]))
coords_flatten = torch.flatten(coords, 1)
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
relative_coords = relative_coords.permute(1, 2, 0).contiguous()
relative_coords[:, :, 0] += win_h - 1
relative_coords[:, :, 1] += win_w - 1
relative_coords[:, :, 0] *= 2 * win_w - 1
return relative_coords.sum(-1)

def _get_relative_positional_bias(self) -> torch.Tensor:
"""
Returns the relative positional bias.
Returns:
-------
torch.Tensor
relative_position_bias (torch.Tensor): Relative positional bias.
"""
relative_position_bias = self.relative_position_bias_table[
self.relative_position_index.view(-1)
].view(self.attn_area, self.attn_area, -1)
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
return relative_position_bias.unsqueeze(0)

def forward(self) -> torch.Tensor:
"""
Forward Method
Returns:
-------
torch.Tensor
Pairwise relative position bias
"""
return self._get_relative_positional_bias()
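The shift-and-scale in get_relative_position_index linearises each 2-D relative offset into a unique table row: per-axis offsets lie in [-(win-1), win-1], adding win-1 maps them onto [0, 2*win-2], and multiplying the row offset by 2*win_w-1 before summing acts like row-major indexing into a (2*win_h-1) x (2*win_w-1) grid. A small self-check, a sketch using a deliberately tiny window:

import torch
from metnet.layers.RelativePositionBias import RelativePositionBias

m = RelativePositionBias(attn_size=(2, 2), num_heads=1)  # 4 tokens in the window

# 4 tokens give a 4x4 index matrix; every entry indexes the 3 * 3 = 9-row table.
idx = m.relative_position_index
assert idx.shape == (4, 4)
assert idx.min() >= 0 and idx.max() < 9

# Token pairs at the same relative offset share a table entry: the diagonal
# (offset (0, 0)) maps to the single index 1 * 3 + 1 = 4.
assert bool((torch.diagonal(idx) == 4).all())

# The forward pass reshapes the table lookup to [1, num_heads, area, area].
assert m().shape == (1, 1, 4, 4)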
