Commit
Fixes maxvit block
Raahul-Singh committed Nov 2, 2023
1 parent 8f6aec4 commit 575703b
Showing 3 changed files with 43 additions and 15 deletions.
34 changes: 27 additions & 7 deletions metnet/layers/MaxViT.py
@@ -24,6 +24,8 @@ class MaxViTDataClass:
         MBConv: Flag to denote downscaling in the conv branch, by default False
     mb_conv_act_layer : Type[nn.Module], optional
         MBConv: activation layer, by default nn.GELU
+    mb_conv_norm_layer : Type[nn.Module], optional
+        MBConv: norm layer, by default nn.BatchNorm2d
     mb_conv_drop_path : float, optional
         MBConv: Stochastic Depth ratio, by default 0.0
     mb_conv_kernel_size : int, optional
@@ -32,6 +34,9 @@ class MaxViTDataClass:
         MBConv: Squeeze Excite reduction ratio, by default 0.25
     block_attention_num_heads : int, optional
         BlockAttention: Number of attention heads, by default 32
+    block_attention_channels : int
+        BlockAttention: Number of channels used for attention computations.
+        It should be divisible by num_heads, by default 64
     block_attention_attn_grid_window_size : Tuple[int, int], optional
         BlockAttention: Grid/Window size for attention, by default (8, 8)
     block_attention_attn_drop : float, optional
@@ -50,6 +55,9 @@ class MaxViTDataClass:
         BlockAttention: Normalise queries and keys as done in Metnet 3, by default True
     grid_attention_num_heads : int, optional
         GridAttention: Number of attention heads, by default 32
+    grid_attention_channels : int
+        GridAttention: Number of channels used for attention computations.
+        It should be divisible by num_heads, by default 64
     grid_attention_attn_grid_window_size : Tuple[int, int], optional
         GridAttention: Grid/Window size for attention, by default (8, 8)
     grid_attention_attn_drop : float, optional
@@ -71,10 +79,12 @@ class MaxViTDataClass:
     mb_conv_expansion_rate: int = 4
     mb_conv_downscale: bool = False
     mb_conv_act_layer: Type[nn.Module] = nn.GELU
+    mb_conv_norm_layer: Type[nn.Module] = nn.BatchNorm2d
     mb_conv_drop_path: float = 0.0
     mb_conv_kernel_size: int = 3
     mb_conv_se_bottleneck_ratio: float = 0.25
     block_attention_num_heads: int = 32
+    block_attention_channels: int = 64
     block_attention_attn_grid_window_size: Tuple[int, int] = (8, 8)
     block_attention_attn_drop: float = 0
     block_attention_proj_drop: float = 0
@@ -84,6 +94,7 @@ class MaxViTDataClass:
     block_attention_mlp: Type[nn.Module] = None
     block_attention_use_normalised_qk: bool = True
     grid_attention_num_heads: int = 32
+    grid_attention_channels: int = 64
     grid_attention_attn_grid_window_size: Tuple[int, int] = (8, 8)
     grid_attention_attn_drop: float = 0
     grid_attention_proj_drop: float = 0
@@ -111,13 +122,15 @@ def __init__(self, in_channels: int, maxvit_config: Type[MaxViTDataClass]) -> None:
            MaxVit Config
        """
        super().__init__()
+        self.in_channels = in_channels
        self.config = maxvit_config
-        mb_conv_out_channels = in_channels * self.config.mb_conv_expansion_rate

        self.mb_conv = MBConv(
-            in_channels=self.config.in_channels,
+            in_channels=self.in_channels,
            expansion_rate=self.config.mb_conv_expansion_rate,
            downscale=self.config.mb_conv_downscale,
            act_layer=self.config.mb_conv_act_layer,
+            norm_layer=self.config.mb_conv_norm_layer,
            drop_path=self.config.mb_conv_drop_path,
            kernel_size=self.config.mb_conv_kernel_size,
            se_bottleneck_ratio=self.config.mb_conv_se_bottleneck_ratio,
@@ -126,8 +139,9 @@ def __init__(self, in_channels: int, maxvit_config: Type[MaxViTDataClass]) -> None:
        # Init Block and Grid Attention

        self.block_attention = BlockAttention(
-            in_channels=mb_conv_out_channels,
+            in_channels=self.in_channels,
            num_heads=self.config.block_attention_num_heads,
+            attention_channels=self.config.block_attention_channels,
            attn_grid_window_size=self.config.block_attention_attn_grid_window_size,
            attn_drop=self.config.block_attention_attn_drop,
            proj_drop=self.config.block_attention_proj_drop,
@@ -139,8 +153,9 @@ def __init__(self, in_channels: int, maxvit_config: Type[MaxViTDataClass]) -> None:
        )

        self.grid_attention = GridAttention(
-            in_channels=mb_conv_out_channels,
+            in_channels=self.in_channels,
            num_heads=self.config.grid_attention_num_heads,
+            attention_channels=self.config.grid_attention_channels,
            attn_grid_window_size=self.config.grid_attention_attn_grid_window_size,
            attn_drop=self.config.grid_attention_attn_drop,
            proj_drop=self.config.grid_attention_proj_drop,
@@ -158,14 +173,19 @@ def forward(self, X: torch.Tensor) -> torch.Tensor:
        Parameters
        ----------
        X : torch.Tensor
-            Input tensor of the shape [B, C_in, H, W]
+            Input tensor of the shape [N, C, H, W]
        Returns:
        -------
        torch.Tensor
-            Output tensor of the shape [B, C_out, H // 2, W // 2] (downscaling is optional)
+            MaxViT block output tensor of the shape [N, C, H, W]
        """
-        output = self.grid_attention(self.block_attention(self.mb_conv(X)))
-
+        output = self.mb_conv(X)
+        output = self.block_attention(output)
+        output = self.grid_attention(output)
+
+        # output = self.grid_attention(self.block_attention(self.mb_conv(X)))
        return output


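Net effect of the MaxViT.py changes: the block now stores in_channels itself instead of reading it off the config, and both attention layers are constructed with the block's own in_channels rather than the MBConv expansion width, so (per the updated docstring and the new test below) the block preserves its input shape end to end. A minimal usage sketch, assuming the package layout shown in this diff; the config values here are just the dataclass defaults, spelled out for illustration:

import torch
from metnet.layers.MaxViT import MaxViTBlock, MaxViTDataClass

config = MaxViTDataClass(
    block_attention_num_heads=32,
    block_attention_channels=64,  # attention channels must be divisible by num_heads
    grid_attention_num_heads=32,
    grid_attention_channels=64,
)
block = MaxViTBlock(in_channels=3, maxvit_config=config)
x = torch.rand(1, 3, 16, 16)  # [N, C, H, W]; H and W compatible with the (8, 8) attention window
assert block(x).shape == x.shape  # shape-preserving, per the updated docstring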
15 changes: 7 additions & 8 deletions metnet/layers/MultiheadSelfAttention2D.py
@@ -70,13 +70,13 @@ def __init__(
        self.attn_drop = nn.Dropout(p=attn_drop)
        self.proj_drop = nn.Dropout(p=proj_drop)

-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for the multi-head self-attention mechanism.

        Parameters
        ----------
-        x : torch.Tensor
+        X : torch.Tensor
            Input tensor of shape (N, C, H, W).
        Returns:
@@ -85,12 +85,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
            Output tensor after multi-head self-attention of shape (N, C, H, W).
        """

-        N, C, H, W = x.size()
-
+        N, C, H, W = X.size()
        # Compute Q, K, V
-        Q = self.query(x)
-        K = self.key(x)
-        V = self.value(x)
+        Q = self.query(X)
+        K = self.key(X)
+        V = self.value(X)

        Q = Q.view(N, self.num_heads, self.attention_head_size, H * W)
        K = K.view(N, self.num_heads, self.attention_head_size, H * W)
@@ -130,4 +129,4 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Out shape [N, attention_channels, H, W]
        out = self.out_proj(out)  # Out shape [N, C, H, W]
        out = self.proj_drop(out)
-        return out + x
+        return out + X
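The x-to-X rename in MultiheadSelfAttention2D.py is cosmetic, but the view calls above are the core of the layer: attention_channels is split across heads, so each head attends over attention_channels // num_heads channels flattened across the H * W grid. A standalone sketch of that reshape arithmetic, using the config defaults (num_heads=32, attention_channels=64); the random tensor merely stands in for the output of the layer's query projection:

import torch

N, C, H, W = 1, 3, 16, 16
num_heads, attention_channels = 32, 64
attention_head_size = attention_channels // num_heads  # 2 channels per head

Q = torch.rand(N, attention_channels, H, W)  # stand-in for self.query(X)
Q = Q.view(N, num_heads, attention_head_size, H * W)  # split heads, flatten spatial dims
print(Q.shape)  # torch.Size([1, 32, 2, 256])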
9 changes: 9 additions & 0 deletions tests/test_layers.py
@@ -1,3 +1,4 @@
+from metnet.layers.MaxViT import MaxViTBlock, MaxViTDataClass
 from metnet.layers.StochasticDepth import StochasticDepth
 from metnet.layers.SqueezeExcitation import SqueezeExcite
 from metnet.layers.MBConv import MBConv
@@ -53,3 +54,11 @@ def test_grid_attention():
    grid_attention = GridAttention(c)

    assert test_tensor.shape == grid_attention(test_tensor).shape
+
+
+def test_maxvitblock():
+    n, c, h, w = 1, 3, 16, 16
+    test_tensor = torch.rand(n, c, h, w)
+
+    maxvit_block = MaxViTBlock(in_channels=c, maxvit_config=MaxViTDataClass())
+    assert test_tensor.shape == maxvit_block(test_tensor).shape
