Refactoring attention
Raahul-Singh committed Oct 31, 2023
1 parent 3fd9b91 commit 0dec0c7
Showing 3 changed files with 157 additions and 106 deletions.
78 changes: 39 additions & 39 deletions metnet/layers/PartitionAttention.py
@@ -136,12 +136,12 @@ def forward(self, X: torch.Tensor) -> torch.Tensor:
Parameters
----------
X : torch.Tensor
Input tensor of the shape [B, C_in, H, W].
Input tensor of the shape [N, C_in, H, W].
Returns:
-------
torch.Tensor
Output tensor of the shape [B, C_out, H (// 2), W (// 2)].
Output tensor of the shape [N, C_out, H (// 2), W (// 2)].
"""
# Save original shape
_, C, H, W = X.shape
@@ -227,31 +227,31 @@ def partition_function(self, input: Tensor) -> torch.Tensor:
Parameters
----------
input : Tensor
input (torch.Tensor): Input tensor of the shape [B, C, H, W].
input (torch.Tensor): Input tensor of the shape [N, C, H, W].
Returns:
-------
torch.Tensor
blocks (torch.Tensor): Unfolded input tensor of the shape
[B * blocks, partition_size[0], partition_size[1], C].
[N * blocks, partition_size[0], partition_size[1], C].
"""
# Get size of input
B, C, H, W = input.shape
N, C, H, W = input.shape
# Unfold input
blocks = input.view(
B,
N,
C,
H // self.attn_grid_window_size[0],
self.partition_window_size[0],
W // self.partition_window_size[1],
self.partition_window_size[1],
self.attn_grid_window_size[0],
W // self.attn_grid_window_size[1],
self.attn_grid_window_size[1],
)
# Permute and reshape to
# [B * blocks, self.partition_window_size[0], self.partition_window_size[1], channels]
# [N * blocks, self.attn_grid_window_size[0], self.attn_grid_window_size[1], channels]
blocks = (
blocks.permute(0, 2, 4, 3, 5, 1)
.contiguous()
.view(-1, self.partition_window_size[0], self.partition_window_size[1], C)
.view(-1, self.attn_grid_window_size[0], self.attn_grid_window_size[1], C)
)
return blocks

@@ -267,33 +267,33 @@ def reverse_function(
----------
partitioned_input : torch.Tensor
Block tensor of the shape
[B * partitioned_input, partition_size[0], partition_size[1], C].
[N * partitioned_input, partition_size[0], partition_size[1], C].
original_size : Tuple[int, int]
Original shape.
Returns:
-------
torch.Tensor
output (torch.Tensor): Folded output tensor of the shape
[B, C, original_size[0], original_size[1]].
[N, C, original_size[0], original_size[1]].
"""
# Get height and width
H, W = original_size
# Compute original batch size
B = int(
N = int(
partitioned_input.shape[0]
/ (H * W / self.partition_window_size[0] / self.partition_window_size[1])
/ (H * W / self.attn_grid_window_size[0] / self.attn_grid_window_size[1])
)
# Fold grid tensor
output = partitioned_input.view(
B,
H // self.partition_window_size[0],
W // self.partition_window_size[1],
self.partition_window_size[0],
self.partition_window_size[1],
N,
H // self.attn_grid_window_size[0],
W // self.attn_grid_window_size[1],
self.attn_grid_window_size[0],
self.attn_grid_window_size[1],
-1,
)
output = output.permute(0, 5, 1, 3, 2, 4).contiguous().view(B, -1, H, W)
output = output.permute(0, 5, 1, 3, 2, 4).contiguous().view(N, -1, H, W)
return output
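As a reference for the hunk above: the block partition unfolds an [N, C, H, W] map into non-overlapping windows, and reverse_function folds them back. Below is a minimal standalone sketch of that round trip, using hypothetical helper names and illustrative shapes (this is not code from the commit):

import torch
from typing import Tuple


def block_partition(x: torch.Tensor, window: Tuple[int, int]) -> torch.Tensor:
    """Unfold [N, C, H, W] into [N * num_blocks, window[0], window[1], C]."""
    N, C, H, W = x.shape
    blocks = x.view(N, C, H // window[0], window[0], W // window[1], window[1])
    return blocks.permute(0, 2, 4, 3, 5, 1).contiguous().view(-1, window[0], window[1], C)


def block_reverse(blocks: torch.Tensor, window: Tuple[int, int], size: Tuple[int, int]) -> torch.Tensor:
    """Fold [N * num_blocks, window[0], window[1], C] back into [N, C, H, W]."""
    H, W = size
    N = int(blocks.shape[0] / (H * W / window[0] / window[1]))
    out = blocks.view(N, H // window[0], W // window[1], window[0], window[1], -1)
    return out.permute(0, 5, 1, 3, 2, 4).contiguous().view(N, -1, H, W)


# The two operations are exact inverses of each other.
x = torch.randn(2, 16, 32, 32)  # illustrative shape
assert torch.equal(block_reverse(block_partition(x, (8, 8)), (8, 8), (32, 32)), x)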


@@ -361,30 +361,30 @@ def partition_function(self, input: Tensor) -> torch.Tensor:
Parameters
----------
input : Tensor
Input tensor of the shape [B, C, H, W].
Input tensor of the shape [N, C, H, W].
Returns:
-------
torch.Tensor
Unfolded input tensor of the shape
[B * grids, grid_size[0], grid_size[1], C].
[N * grids, grid_size[0], grid_size[1], C].
"""
# Get size of input
B, C, H, W = input.shape
N, C, H, W = input.shape
# Unfold input
grid = input.view(
B,
N,
C,
self.partition_window_size[0],
H // self.partition_window_size[0],
self.partition_window_size[1],
W // self.partition_window_size[1],
self.attn_grid_window_size[0],
H // self.attn_grid_window_size[0],
self.attn_grid_window_size[1],
W // self.attn_grid_window_size[1],
)
# Permute and reshape [B * (H // self.partition_window_size[0]) * (W // self.partition_window_size[1]), self.partition_window_size[0], window_size[1], C] # noqa
# Permute and reshape [N * (H // self.attn_grid_window_size[0]) * (W // self.attn_grid_window_size[1]), self.attn_grid_window_size[0], window_size[1], C] # noqa
grid = (
grid.permute(0, 3, 5, 2, 4, 1)
.contiguous()
.view(-1, self.partition_window_size[0], self.partition_window_size[1], C)
.view(-1, self.attn_grid_window_size[0], self.attn_grid_window_size[1], C)
)
return grid

@@ -400,30 +400,30 @@ def reverse_function(
----------
partitioned_input : torch.Tensor
Grid tensor of the shape
[B * partitioned_input, partition_size[0], partition_size[1], C].
[N * partitioned_input, partition_size[0], partition_size[1], C].
original_size : Tuple[int, int]
Original shape.
Returns:
-------
torch.Tensor
Folded output tensor of the shape [B, C, original_size[0], original_size[1]].
Folded output tensor of the shape [N, C, original_size[0], original_size[1]].
"""
# Get height, width, and channels
(H, W), C = original_size, partitioned_input.shape[-1]
# Compute original batch size
B = int(
N = int(
partitioned_input.shape[0]
/ (H * W / self.partition_window_size[0] / self.partition_window_size[1])
/ (H * W / self.attn_grid_window_size[0] / self.attn_grid_window_size[1])
)
# Fold partitioned_input tensor
output = partitioned_input.view(
B,
H // self.partition_window_size[0],
W // self.partition_window_size[1],
N,
H // self.attn_grid_window_size[0],
W // self.attn_grid_window_size[1],
self.partition_window_size[0],
self.partition_window_size[1],
C,
)
output = output.permute(0, 5, 3, 1, 4, 2).contiguous().view(B, C, H, W)
output = output.permute(0, 5, 3, 1, 4, 2).contiguous().view(N, C, H, W)
return output
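The grid variant in this file differs from the block variant only in how the six-dimensional view is ordered: blocks keep spatially contiguous window-by-window patches together, while the grid groups strided positions that are H // grid_size apart (MaxViT-style dilated attention). A small illustrative sketch of that difference, assuming a toy 4x4 input and a 2x2 partition size (not code from the commit):

import torch

# Toy single-channel 4x4 map; values are just the flattened pixel index.
x = torch.arange(16.0).view(1, 1, 4, 4)
g0 = g1 = 2

# Block partition: contiguous 2x2 patches (window attention).
blocks = (
    x.view(1, 1, 4 // g0, g0, 4 // g1, g1)
    .permute(0, 2, 4, 3, 5, 1)
    .contiguous()
    .view(-1, g0, g1, 1)
)
print(blocks[0, ..., 0])  # [[0, 1], [4, 5]] -- one contiguous patch

# Grid partition: strided positions (grid / dilated attention).
grid = (
    x.view(1, 1, g0, 4 // g0, g1, 4 // g1)
    .permute(0, 3, 5, 2, 4, 1)
    .contiguous()
    .view(-1, g0, g1, 1)
)
print(grid[0, ..., 0])  # [[0, 2], [8, 10]] -- every other pixel in each direction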
161 changes: 94 additions & 67 deletions metnet/layers/RelativeSelfAttention.py
@@ -1,74 +1,27 @@
"""
Relative Self Attention Implementation
"""
from typing import Tuple
from typing import Tuple, Type

import torch
import torch.nn as nn


class RelativeSelfAttention(nn.Module):
"""
Relative Self-Attention similar to Swin V1.
Implementation inspired from ChristophReich1996's MaxViT implementation.
"""

def __init__(
self,
in_channels: int,
num_heads: int = 32,
attn_grid_window_size: Tuple[int, int] = (8, 8),
attn_drop: float = 0.0,
proj_drop: float = 0.0,
use_normalised_qk: bool = True,
) -> None:
"""
Constructor Method
Parameters
----------
in_channels : int
Number of input channels
num_heads : int, optional
Number of attention heads, by default 32
attn_grid_window_size : Tuple[int, int], optional
attention grid window size, by default (8, 8)
attn_drop : float, optional
attention dropout rate, by default 0.0
proj_drop : float, optional
post attention projection dropout rate, by default 0.0
use_normalised_qk : bool, by default True
Normalise queries and keys, (as in Metnet 3)
"""
class RelativePositionBias(nn.Module):
def __init__(self, attn_size: Tuple[int, int], num_heads: int) -> None:
super().__init__()

self.in_channels: int = in_channels
self.num_heads: int = num_heads
self.attn_grid_window_size: Tuple[int, int] = attn_grid_window_size
self.scale: float = num_heads**-0.5
self.attn_area: int = attn_grid_window_size[0] * attn_grid_window_size[1]

self.qkv_mapping = nn.Linear(
in_features=in_channels, out_features=3 * in_channels, bias=True
)
self.use_normalised_qk = use_normalised_qk
self.attn_drop = nn.Dropout(p=attn_drop)
self.proj = nn.Linear(in_features=in_channels, out_features=in_channels, bias=True)
self.proj_drop = nn.Dropout(p=proj_drop)
self.softmax = nn.Softmax(dim=-1)
self.attn_size = attn_size
self.num_heads = num_heads

# Define a parameter table of relative position bias, shape: 2*Wh-1 * 2*Ww-1, nH
self.relative_position_bias_table = nn.Parameter(
torch.zeros(
(2 * attn_grid_window_size[0] - 1) * (2 * attn_grid_window_size[1] - 1), num_heads
)
torch.zeros((2 * self.attn_size[0] - 1) * (2 * self.attn_size[1] - 1), self.num_heads)
)

# Get pair-wise relative position index for each token inside the window
self.register_buffer(
"relative_position_index",
self.get_relative_position_index(attn_grid_window_size[0], attn_grid_window_size[1]),
self.get_relative_position_index(self.attn_size[0], self.attn_size[1]),
)

def get_relative_position_index(self, win_h: int, win_w: int) -> torch.Tensor:
@@ -99,7 +52,7 @@ def get_relative_position_index(self, win_h: int, win_w: int) -> torch.Tensor:
relative_coords[:, :, 0] *= 2 * win_w - 1
return relative_coords.sum(-1)

def _get_relative_positional_bias(self) -> torch.Tensor:
def _get_relative_positional_bias(self, attn_area: int) -> torch.Tensor:
"""
Returns the relative positional bias.
@@ -110,39 +63,113 @@ def _get_relative_positional_bias(self) -> torch.Tensor:
"""
relative_position_bias = self.relative_position_bias_table[
self.relative_position_index.view(-1)
].view(self.attn_area, self.attn_area, -1)
].view(attn_area, attn_area, -1)
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
return relative_position_bias.unsqueeze(0)

def forward(self, x: torch.Tensor) -> torch.Tensor:
def forward(self, attn_area: int) -> torch.Tensor:
return self._get_relative_positional_bias(attn_area=attn_area)
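For context, the RelativePositionBias module factored out above follows the Swin-style scheme: a learnable table with (2*Wh - 1) * (2*Ww - 1) entries per head is gathered through a precomputed pairwise index into a [1, num_heads, attn_area, attn_area] tensor that is added to the attention logits. A standalone sketch of that lookup, mirroring get_relative_position_index, with illustrative window size and head count (not code from the commit):

import torch

win_h, win_w, num_heads = 8, 8, 4  # illustrative values
area = win_h * win_w

# Learnable table: one bias per relative offset per head (zero-initialised here).
bias_table = torch.zeros((2 * win_h - 1) * (2 * win_w - 1), num_heads)

# Pairwise relative position index over all token pairs in the window.
coords = torch.stack(torch.meshgrid(torch.arange(win_h), torch.arange(win_w), indexing="ij"))
coords_flat = torch.flatten(coords, 1)                   # [2, area]
rel = coords_flat[:, :, None] - coords_flat[:, None, :]  # [2, area, area]
rel = rel.permute(1, 2, 0).contiguous()                  # [area, area, 2]
rel[:, :, 0] += win_h - 1                                # shift offsets to start at 0
rel[:, :, 1] += win_w - 1
rel[:, :, 0] *= 2 * win_w - 1
index = rel.sum(-1)                                      # [area, area]

# Gather and reshape into the bias added onto the attention logits.
bias = bias_table[index.view(-1)].view(area, area, -1)   # [area, area, num_heads]
bias = bias.permute(2, 0, 1).contiguous().unsqueeze(0)   # [1, num_heads, area, area]
print(bias.shape)  # torch.Size([1, 4, 64, 64])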


class RelativeSelfAttention(nn.Module):
"""
Relative Self-Attention similar to Swin V1.
Implementation inspired from timm's MaxViT implementation.
"""

def __init__(
self,
in_channels: int,
attention_head_dim: int = 512,
num_heads: int = 32,
head_first: bool = False,
attn_grid_window_size: Tuple[int, int] = (8, 8),
attn_drop: float = 0.0,
proj_drop: float = 0.0,
use_normalised_qk: bool = True,
rel_attn_bias: Type[nn.Module] = None,
) -> None:
"""
Constructor Method
Parameters
----------
in_channels : int
Number of input channels
num_heads : int, optional
Number of attention heads, by default 32
attn_grid_window_size : Tuple[int, int], optional
attention grid window size, by default (8, 8)
attn_drop : float, optional
attention dropout rate, by default 0.0
proj_drop : float, optional
post attention projection dropout rate, by default 0.0
use_normalised_qk : bool, by default True
Normalise queries and keys, (as in Metnet 3)
"""
super().__init__()

self.in_channels: int = in_channels
self.num_heads: int = num_heads
self.attention_head_dim = attention_head_dim
self.head_first = head_first
self.attn_grid_window_size: Tuple[int, int] = attn_grid_window_size
self.scale: float = num_heads**-0.5
self.attn_area: int = attn_grid_window_size[0] * attn_grid_window_size[1]
self.rel_attn_bias = rel_attn_bias

self.qkv = nn.Conv2d(self.in_channels, self.attention_head_dim * 3, 1)

self.use_normalised_qk = use_normalised_qk
self.attn_drop = nn.Dropout(p=attn_drop)
self.proj = nn.Conv2d(self.attention_head_dim, self.in_channels, 1)
self.proj_drop = nn.Dropout(p=proj_drop)
self.softmax = nn.Softmax(dim=-1)

def forward(self, X: torch.Tensor) -> torch.Tensor:
"""
Forward pass.
Parameters
----------
x : torch.Tensor
input tensor of the shape [B_, N, C].
X : torch.Tensor
input tensor of the shape [N, C, H, W].
Returns:
-------
torch.Tensor
Output tensor of the shape [B_, N, C].
Output tensor of the shape [N, C, H, W].
"""
# Get shape of x
B_, N, _ = x.shape
# Perform query key value mapping
qkv = self.qkv_mapping(x).reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0)
# Get shape of X
B, C, H, W = X.shape

if self.head_first:
q, k, v = (
self.qkv(X).view(B, self.num_heads, self.attention_head_dim * 3, -1).chunk(3, dim=2)
)
else:
q, k, v = (
self.qkv(X).reshape(B, 3, self.num_heads, self.attention_head_dim, -1).unbind(1)
)

if self.use_normalised_qk:
q = torch.nn.functional.normalize(q, dim=1) # TODO: verify dim
k = torch.nn.functional.normalize(k, dim=1) # TODO: verify dim

# q = q * self.scale # TODO: verify if this should be applied after norm

# Compute attention maps
attn = self.softmax(q @ k.transpose(-2, -1) + self._get_relative_positional_bias())
attn = q.transpose(-2, -1) @ k
if self.rel_attn_bias is not None:
attn = attn + self.rel_attn_bias()

attn = self.softmax(attn)

attn = self.attn_drop(attn)

# Map value with attention maps
output = (attn @ v).transpose(1, 2).reshape(B_, N, -1)
output = (v @ attn.transpose(-2, -1)).view(B, -1, H, W)
# Perform final projection and dropout
output = self.proj(output) # TODO: Check if this is needed
output = self.proj_drop(output)
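The refactored forward now works directly on channel-first windows: qkv comes from a 1x1 convolution, queries and keys are optionally L2-normalised (the MetNet-3-style use_normalised_qk path, with the exact dim left as a TODO in the commit), and value aggregation is expressed as v @ attn.T. Below is a simplified, shape-explicit sketch of that computation under illustrative sizes; it is not the class itself, omits the relative bias and dropout, and normalises over the per-head feature dimension as one reasonable choice:

import torch
import torch.nn.functional as F

# Illustrative sizes: one 8x8 attention window, C split across heads.
B, C, H, W = 4, 32, 8, 8
num_heads, head_dim = 4, 8  # C == num_heads * head_dim
x = torch.randn(B, C, H, W)

qkv = torch.nn.Conv2d(C, 3 * C, kernel_size=1)(x)                  # [B, 3C, H, W]
q, k, v = qkv.reshape(B, 3, num_heads, head_dim, H * W).unbind(1)  # each [B, heads, head_dim, HW]

# Cosine-style logits: normalise q and k over the per-head feature dimension.
q = F.normalize(q, dim=-2)
k = F.normalize(k, dim=-2)

attn = q.transpose(-2, -1) @ k    # [B, heads, HW, HW]; the relative bias would be added here
attn = attn.softmax(dim=-1)
out = v @ attn.transpose(-2, -1)  # [B, heads, head_dim, HW]
out = out.reshape(B, C, H, W)     # back to a channel-first feature map
print(out.shape)                  # torch.Size([4, 32, 8, 8])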
