
Commit

adding maxvit block
Raahul-Singh committed Oct 25, 2023
1 parent 1939dd8 commit 644c830
Showing 1 changed file with 90 additions and 0 deletions.
90 changes: 90 additions & 0 deletions metnet/layers/MaxViT.py
@@ -0,0 +1,90 @@
from typing import Optional, Tuple, Type

import torch
from torch import nn

from metnet.layers.MBConv import MBConv
from metnet.layers.PartitionAttention import BlockAttention, GridAttention


class MaxViTBlock(nn.Module):
    """MaxViT block: an MBConv stage followed by block and grid attention."""

def __init__(
self,
        in_channels: int,
        mb_conv_expansion_rate: int = 4,
        mb_conv_downscale: bool = False,
        mb_conv_act_layer: Type[nn.Module] = nn.GELU,
        mb_conv_drop_path: float = 0.0,
        mb_conv_kernel_size: int = 3,
        mb_conv_se_bottleneck_ratio: float = 0.25,
        block_attention_num_heads: int = 32,
        block_attention_attn_grid_window_size: Tuple[int, int] = (8, 8),
        block_attention_attn_drop: float = 0.0,
        block_attention_proj_drop: float = 0.0,
        block_attention_drop_path: float = 0.0,
        block_attention_pre_norm_layer: Type[nn.Module] = nn.LayerNorm,
        block_attention_post_norm_layer: Type[nn.Module] = nn.LayerNorm,
        block_attention_mlp: Optional[Type[nn.Module]] = None,
        block_attention_use_normalised_qk: bool = True,
        grid_attention_num_heads: int = 32,
        grid_attention_attn_grid_window_size: Tuple[int, int] = (8, 8),
        grid_attention_attn_drop: float = 0.0,
        grid_attention_proj_drop: float = 0.0,
        grid_attention_drop_path: float = 0.0,
        grid_attention_pre_norm_layer: Type[nn.Module] = nn.LayerNorm,
        grid_attention_post_norm_layer: Type[nn.Module] = nn.LayerNorm,
        grid_attention_mlp: Optional[Type[nn.Module]] = None,
        grid_attention_use_normalised_qk: bool = True,
    ) -> None:
        """Constructor method.

        Builds an MBConv stage followed by block attention and grid attention,
        which are applied in sequence in the forward pass.
        """
        super().__init__()

        # Channel count after the MBConv stage: the input channels expanded by
        # the MBConv expansion rate. Both attention stages operate on this width.
        mb_conv_out_channels = in_channels * mb_conv_expansion_rate
        self.mb_conv = MBConv(
in_channels=in_channels,
expansion_rate=mb_conv_expansion_rate,
downscale=mb_conv_downscale,
act_layer=mb_conv_act_layer,
drop_path=mb_conv_drop_path,
kernel_size=mb_conv_kernel_size,
se_bottleneck_ratio=mb_conv_se_bottleneck_ratio,
)

# Init Block and Grid Attention

self.block_attention = BlockAttention(
in_channels=mb_conv_out_channels,
num_heads=block_attention_num_heads,
attn_grid_window_size=block_attention_attn_grid_window_size,
attn_drop=block_attention_attn_drop,
proj_drop=block_attention_proj_drop,
drop_path=block_attention_drop_path,
pre_norm_layer=block_attention_pre_norm_layer,
post_norm_layer=block_attention_post_norm_layer,
mlp=block_attention_mlp,
use_normalised_qk=block_attention_use_normalised_qk,
)

self.grid_attention = GridAttention(
in_channels=mb_conv_out_channels,
num_heads=grid_attention_num_heads,
attn_grid_window_size=grid_attention_attn_grid_window_size,
attn_drop=grid_attention_attn_drop,
proj_drop=grid_attention_proj_drop,
drop_path=grid_attention_drop_path,
pre_norm_layer=grid_attention_pre_norm_layer,
post_norm_layer=grid_attention_post_norm_layer,
mlp=grid_attention_mlp,
use_normalised_qk=grid_attention_use_normalised_qk,
)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        Args:
            X (torch.Tensor): Input tensor of shape [B, C_in, H, W].

        Returns:
            torch.Tensor: Output tensor of shape [B, C_out, H, W], or
                [B, C_out, H // 2, W // 2] if downscaling is enabled in the
                MBConv stage.
        """
output = self.grid_attention(self.block_attention(self.mb_conv(X)))
return output
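
For reference, a minimal usage sketch with hypothetical values; it assumes the spatial size is divisible by the default 8 x 8 attention window and that the MBConv stage expands the channels by the expansion rate, as the wiring above suggests:

import torch

from metnet.layers.MaxViT import MaxViTBlock

# Hypothetical configuration: 16 input channels, 64 x 64 spatial grid.
block = MaxViTBlock(in_channels=16)
x = torch.randn(1, 16, 64, 64)
out = block(x)  # MBConv -> block attention -> grid attention
# With the default expansion rate of 4, the channel count becomes 16 * 4 = 64;
# the spatial size is unchanged when mb_conv_downscale is False (the default).
print(out.shape)  # expected: torch.Size([1, 64, 64, 64])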
