Adding Relative Self Attention
Raahul-Singh committed Oct 24, 2023
1 parent 5bec075 commit 1db1bab
Showing 1 changed file with 105 additions and 0 deletions.
105 changes: 105 additions & 0 deletions metnet/layers/RelativeSelfAttention.py
@@ -0,0 +1,105 @@
from typing import Tuple

import torch
import torch.nn as nn
from timm.models.layers import trunc_normal_


class RelativeSelfAttention(nn.Module):
"""Relative Self-Attention similar to Swin V1. Implementation taken from ChristophReich1996's MaxViT implementation."""

    def __init__(
        self,
        in_channels: int,
        num_heads: int = 32,
        grid_window_size: Tuple[int, int] = (7, 7),
        attn_drop: float = 0.0,
        drop: float = 0.0,
    ) -> None:
        super().__init__()

        self.in_channels: int = in_channels
        self.num_heads: int = num_heads
        self.grid_window_size: Tuple[int, int] = grid_window_size

        # Note: scales by num_heads ** -0.5, following the reference implementation;
        # the more common attention convention is head_dim ** -0.5.
        self.scale: float = num_heads**-0.5
        self.attn_area: int = grid_window_size[0] * grid_window_size[1]

        # Init layers
        self.qkv_mapping = nn.Linear(
            in_features=in_channels, out_features=3 * in_channels, bias=True
        )
        self.attn_drop = nn.Dropout(p=attn_drop)
        self.proj = nn.Linear(in_features=in_channels, out_features=in_channels, bias=True)
        self.proj_drop = nn.Dropout(p=drop)
        self.softmax = nn.Softmax(dim=-1)

        # Define a parameter table of relative position bias, shape: [(2*Wh-1) * (2*Ww-1), nH]
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * grid_window_size[0] - 1) * (2 * grid_window_size[1] - 1), num_heads)
        )
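        # For the default 7x7 window this table has (2*7-1) * (2*7-1) = 169 rows:
        # one learnable bias per distinct 2D relative offset, per head.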

        # Get pair-wise relative position index for each token inside the window
        self.register_buffer(
            "relative_position_index",
            self.get_relative_position_index(grid_window_size[0], grid_window_size[1]),
        )
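        # relative_position_index is an [attn_area, attn_area] LongTensor mapping each
        # (query, key) position pair to a row of the bias table above; as a buffer it
        # moves with the module across devices but is not trained.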
        # Init relative positional bias
        trunc_normal_(self.relative_position_bias_table, std=0.02)

    def get_relative_position_index(self, win_h: int, win_w: int) -> torch.Tensor:
        """Generate the pair-wise relative position index for each token inside the window.

        Taken from timm's Swin V1 implementation.

        Args:
            win_h (int): Window/Grid height.
            win_w (int): Window/Grid width.

        Returns:
            relative_coords (torch.Tensor): Pair-wise relative position indexes [height * width, height * width].
        """
        # Explicit "ij" indexing matches the historical meshgrid default and avoids
        # the deprecation warning on PyTorch >= 1.10
        coords = torch.stack(
            torch.meshgrid(torch.arange(win_h), torch.arange(win_w), indexing="ij")
        )
        coords_flatten = torch.flatten(coords, 1)
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
        relative_coords[:, :, 0] += win_h - 1  # shift row offsets to start at 0
        relative_coords[:, :, 1] += win_w - 1  # shift column offsets to start at 0
        relative_coords[:, :, 0] *= 2 * win_w - 1  # row-major flattening of the 2D offsets
        return relative_coords.sum(-1)
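    # Worked example: for a 2x2 window the flattened coords are (0,0), (0,1),
    # (1,0), (1,1), so each per-axis offset lies in [-1, 1]. The shifts move the
    # offsets to [0, 2] and the row term is scaled by 2*win_w - 1 = 3, giving a
    # unique flat index in [0, 8] for each of the (2*2-1) * (2*2-1) = 9 distinct
    # relative offsets.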

    def _get_relative_positional_bias(self) -> torch.Tensor:
        """Return the relative positional bias.

        Returns:
            relative_position_bias (torch.Tensor): Relative positional bias of the shape
                [1, num_heads, attn_area, attn_area].
        """
        relative_position_bias = self.relative_position_bias_table[
            self.relative_position_index.view(-1)
        ].view(self.attn_area, self.attn_area, -1)
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
        return relative_position_bias.unsqueeze(0)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        Args:
            input (torch.Tensor): Input tensor of the shape [B_, N, C], where N must equal
                grid_window_size[0] * grid_window_size[1].

        Returns:
            output (torch.Tensor): Output tensor of the shape [B_, N, C].
        """
        # Get shape of input
        B_, N, C = input.shape
        # Perform query key value mapping: [3, B_, num_heads, N, C // num_heads]
        qkv = self.qkv_mapping(input).reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        # Scale query
        q = q * self.scale
        # Compute attention maps with the relative positional bias added: [B_, num_heads, N, N]
        attn = self.softmax(q @ k.transpose(-2, -1) + self._get_relative_positional_bias())
        # Apply attention dropout (a no-op with the default attn_drop=0.0)
        attn = self.attn_drop(attn)
        # Map values with the attention maps
        output = (attn @ v).transpose(1, 2).reshape(B_, N, -1)
        # Perform final projection and dropout
        output = self.proj(output)
        output = self.proj_drop(output)
        return output
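
For a quick smoke test, a minimal usage sketch (with illustrative values for in_channels and num_heads) might look like the following; note that in_channels must be divisible by num_heads, and N must match the window area, 49 for the default 7x7 grid:

import torch

from metnet.layers.RelativeSelfAttention import RelativeSelfAttention

# 64 channels split across 8 heads -> head_dim = 8; window area N = 7 * 7 = 49
attention = RelativeSelfAttention(in_channels=64, num_heads=8)
windows = torch.randn(2, 49, 64)  # [B_, N, C], B_ = batch size * number of windows
out = attention(windows)
print(out.shape)  # torch.Size([2, 49, 64])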
