add the maxvit block for metnet
Raahul-Singh committed Oct 25, 2023
1 parent 644c830 commit 1385a2b
Showing 1 changed file with 76 additions and 0 deletions: metnet/layers/MaxViT.py
@@ -88,3 +88,79 @@ def forward(self, X: torch.Tensor) -> torch.Tensor:
"""
output = self.grid_attention(self.block_attention(self.mb_conv(X)))
return output


def _met_net_maxvit_config() -> dict:
    """Return the MaxViT block keyword arguments specified in the MetNet-3 paper."""
    return {
"mb_conv_expansion_rate": 4,
"mb_conv_downscale": False,
"mb_conv_act_layer": nn.GELU,
"mb_conv_drop_path": 0.0,
"mb_conv_kernel_size": 3,
"mb_conv_se_bottleneck_ratio": 0.25,
"block_attention_num_heads": 32,
"block_attention_attn_grid_window_size": (8, 8),
"block_attention_attn_drop": 0,
"block_attention_proj_drop": 0,
"block_attention_drop_path": 0,
"block_attention_pre_norm_layer": nn.LayerNorm,
"block_attention_post_norm_layer": nn.LayerNorm,
"block_attention_mlp": None,
"block_attention_use_normalised_qk": True,
"grid_attention_num_heads": 32,
"grid_attention_attn_grid_window_size": (8, 8),
"grid_attention_attn_drop": 0,
"grid_attention_proj_drop": 0,
"grid_attention_drop_path": 0,
"grid_attention_pre_norm_layer": nn.LayerNorm,
"grid_attention_post_norm_layer": nn.LayerNorm,
"grid_attention_mlp": None,
"grid_attention_use_normalised_qk": True,
}


class MetNetMaxVit(nn.Module):
    """A stack of MaxViT blocks with a summed residual connection, as used in MetNet-3."""

def __init__(
self,
in_channels: int = 512,
out_channels: int = 512,
num_blocks: int = 12,
use_metnet_paper_conf: bool = True,
) -> None:
super().__init__()
self.in_channels = in_channels
self.num_blocks = num_blocks
if not use_metnet_paper_conf:
# TODO: Make this configurable. Should I make a data class for these?
raise NotImplementedError(
"Currently only Metnet3 paper specified configs are supported"
)
maxvit_conf = _met_net_maxvit_config()
self.maxvit_blocks = nn.ModuleList()
for _ in range(self.num_blocks):
self.maxvit_blocks.append(MaxViTBlock(in_channels=self.in_channels, **maxvit_conf))

self.linear_transform = nn.Linear(in_features=out_channels, out_features=in_channels)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """
        Forward method

        Parameters
        ----------
        X : torch.Tensor
            Input tensor

        Returns
        -------
        torch.Tensor
            Output of the stack of MaxViT blocks.
        """
        # Chain the blocks, keeping every block's output for the residual sum below.
        block_outputs = []
        hidden = X
        for block in self.maxvit_blocks:
            hidden = block(hidden)
            block_outputs.append(hidden)

        # Residual connection: add the sum of all block outputs back onto the input.
        output = X + torch.stack(block_outputs).sum(dim=0)  # TODO: verify dim
        output = self.linear_transform(output)
        return output
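
For context, here is a minimal, self-contained sketch of the block-chaining and summed-residual pattern implemented in MetNetMaxVit.forward above. Plain convolutions stand in for the MaxViT blocks, and all names and shapes below are illustrative assumptions, not part of this commit.

# Sketch only: simple convolutions stand in for MaxViT blocks; names and shapes
# are illustrative assumptions, not part of the committed code.
import torch
from torch import nn


class StackedBlocksWithSummedResidual(nn.Module):
    def __init__(self, channels: int = 512, num_blocks: int = 12) -> None:
        super().__init__()
        # Stand-in blocks that preserve the (B, C, H, W) shape, as MaxViT blocks do.
        self.blocks = nn.ModuleList(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1) for _ in range(num_blocks)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        outputs = []
        hidden = x
        for block in self.blocks:
            hidden = block(hidden)  # each block consumes the previous block's output
            outputs.append(hidden)
        # Sum every block's output over the stacking dimension and add the result
        # back onto the input as a residual connection.
        return x + torch.stack(outputs).sum(dim=0)


if __name__ == "__main__":
    model = StackedBlocksWithSummedResidual(channels=8, num_blocks=3)
    x = torch.randn(2, 8, 16, 16)
    print(model(x).shape)  # torch.Size([2, 8, 16, 16])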
