adds data class and more docs
Raahul-Singh committed Oct 27, 2023
1 parent 1e3dc2b commit 8ff1116
Showing 2 changed files with 166 additions and 97 deletions.
261 changes: 165 additions & 96 deletions metnet/layers/MaxViT.py
@@ -1,80 +1,147 @@
from dataclasses import dataclass
from typing import List, Tuple, Type, Union

import torch
from torch import nn

from metnet.layers.MBConv import MBConv
from metnet.layers.PartitionAttention import BlockAttention, GridAttention


@dataclass
class MaxViTDataClass:
"""
DataClass for MaxViT
Parameters
----------
mb_conv_expansion_rate : int, optional
MBConv: Expansion rate for the output channels, by default 4
mb_conv_downscale : bool, optional
MBConv: Flag to denote downscaling in the conv branch, by default False
mb_conv_act_layer : Type[nn.Module], optional
MBConv: activation layer, by default nn.GELU
mb_conv_drop_path : float, optional
MBConv: Stochastic Depth ratio, by default 0.0
mb_conv_kernel_size : int, optional
MBConv: Conv kernel size, by default 3
mb_conv_se_bottleneck_ratio : float, optional
MBConv: Squeeze Excite reduction ratio, by default 0.25
block_attention_num_heads : int, optional
BlockAttention: Number of attention heads, by default 32
block_attention_attn_grid_window_size : Tuple[int, int], optional
BlockAttention: Grid/Window size for attention, by default (8, 8)
block_attention_attn_drop : float, optional
BlockAttention: Dropout ratio of attention weight, by default 0
block_attention_proj_drop : float, optional
BlockAttention: Dropout ratio of output, by default 0
block_attention_drop_path : float, optional
BlockAttention: Stochastic depth, by default 0
block_attention_pre_norm_layer : Type[nn.Module], optional
BlockAttention: Pre norm layer, by default nn.LayerNorm
block_attention_post_norm_layer : Type[nn.Module], optional
BlockAttention: Post norm layer, by default nn.LayerNorm
block_attention_mlp : Type[nn.Module], optional
BlockAttention: MLP to be used after the attention, by default None
block_attention_use_normalised_qk : bool, optional
BlockAttention: Normalise queries and keys as done in Metnet 3, by default True
grid_attention_num_heads : int, optional
GridAttention: Number of attention heads, by default 32
grid_attention_attn_grid_window_size : Tuple[int, int], optional
GridAttention: Grid/Window size for attention, by default (8, 8)
grid_attention_attn_drop : float, optional
GridAttention: Dropout ratio of attention weight, by default 0
grid_attention_proj_drop : float, optional
GridAttention: Dropout ratio of output, by default 0
grid_attention_drop_path : float, optional
GridAttention: Stochastic depth, by default 0
grid_attention_pre_norm_layer : Type[nn.Module], optional
GridAttention: Pre norm layer, by default nn.LayerNorm
grid_attention_post_norm_layer : Type[nn.Module], optional
GridAttention: Post norm layer, by default nn.LayerNorm
grid_attention_mlp : Type[nn.Module], optional
GridAttention: MLP to be used after the attention, by default None
grid_attention_use_normalised_qk : bool, optional
GridAttention: Normalise queries and keys as done in Metnet 3, by default True
"""

mb_conv_expansion_rate: int = 4
mb_conv_downscale: bool = False
mb_conv_act_layer: Type[nn.Module] = nn.GELU
mb_conv_drop_path: float = 0.0
mb_conv_kernel_size: int = 3
mb_conv_se_bottleneck_ratio: float = 0.25
block_attention_num_heads: int = 32
block_attention_attn_grid_window_size: Tuple[int, int] = (8, 8)
block_attention_attn_drop: float = 0
block_attention_proj_drop: float = 0
block_attention_drop_path: float = 0
block_attention_pre_norm_layer: Type[nn.Module] = nn.LayerNorm
block_attention_post_norm_layer: Type[nn.Module] = nn.LayerNorm
block_attention_mlp: Type[nn.Module] = None
block_attention_use_normalised_qk: bool = True
grid_attention_num_heads: int = 32
grid_attention_attn_grid_window_size: Tuple[int, int] = (8, 8)
grid_attention_attn_drop: float = 0
grid_attention_proj_drop: float = 0
grid_attention_drop_path: float = 0
grid_attention_pre_norm_layer: Type[nn.Module] = nn.LayerNorm
grid_attention_post_norm_layer: Type[nn.Module] = nn.LayerNorm
grid_attention_mlp: Type[nn.Module] = None
grid_attention_use_normalised_qk: bool = True


class MaxViTBlock(nn.Module):
-    def __init__(
-        self,
-        in_channels,
-        mb_conv_expansion_rate=4,
-        mb_conv_downscale=False,
-        mb_conv_act_layer=nn.GELU,
-        mb_conv_drop_path=0.0,
-        mb_conv_kernel_size=3,
-        mb_conv_se_bottleneck_ratio=0.25,
-        block_attention_num_heads=32,
-        block_attention_attn_grid_window_size=(8, 8),
-        block_attention_attn_drop=0,
-        block_attention_proj_drop=0,
-        block_attention_drop_path=0,
-        block_attention_pre_norm_layer=nn.LayerNorm,
-        block_attention_post_norm_layer=nn.LayerNorm,
-        block_attention_mlp=None,
-        block_attention_use_normalised_qk=True,
-        grid_attention_num_heads=32,
-        grid_attention_attn_grid_window_size=(8, 8),
-        grid_attention_attn_drop=0,
-        grid_attention_proj_drop=0,
-        grid_attention_drop_path=0,
-        grid_attention_pre_norm_layer=nn.LayerNorm,
-        grid_attention_post_norm_layer=nn.LayerNorm,
-        grid_attention_mlp=None,
-        grid_attention_use_normalised_qk=True,
-    ) -> None:
-        """Constructor method"""
-        # Call super constructor
-        super().__init__()
-        mb_conv_out_channels = in_channels * mb_conv_expansion_rate
+    def __init__(self, in_channels: int, maxvit_config: MaxViTDataClass) -> None:
+        """
+        MaxViT block
+        Parameters
+        ----------
+        in_channels : int
+            Number of input channels
+        maxvit_config : MaxViTDataClass
+            MaxViT config
+        """
+        super().__init__()
+        self.config = maxvit_config
+        mb_conv_out_channels = in_channels * self.config.mb_conv_expansion_rate
        self.mb_conv = MBConv(
            in_channels=in_channels,
-            expansion_rate=mb_conv_expansion_rate,
-            downscale=mb_conv_downscale,
-            act_layer=mb_conv_act_layer,
-            drop_path=mb_conv_drop_path,
-            kernel_size=mb_conv_kernel_size,
-            se_bottleneck_ratio=mb_conv_se_bottleneck_ratio,
+            expansion_rate=self.config.mb_conv_expansion_rate,
+            downscale=self.config.mb_conv_downscale,
+            act_layer=self.config.mb_conv_act_layer,
+            drop_path=self.config.mb_conv_drop_path,
+            kernel_size=self.config.mb_conv_kernel_size,
+            se_bottleneck_ratio=self.config.mb_conv_se_bottleneck_ratio,
        )

# Init Block and Grid Attention

self.block_attention = BlockAttention(
in_channels=mb_conv_out_channels,
-            num_heads=block_attention_num_heads,
-            attn_grid_window_size=block_attention_attn_grid_window_size,
-            attn_drop=block_attention_attn_drop,
-            proj_drop=block_attention_proj_drop,
-            drop_path=block_attention_drop_path,
-            pre_norm_layer=block_attention_pre_norm_layer,
-            post_norm_layer=block_attention_post_norm_layer,
-            mlp=block_attention_mlp,
-            use_normalised_qk=block_attention_use_normalised_qk,
+            num_heads=self.config.block_attention_num_heads,
+            attn_grid_window_size=self.config.block_attention_attn_grid_window_size,
+            attn_drop=self.config.block_attention_attn_drop,
+            proj_drop=self.config.block_attention_proj_drop,
+            drop_path=self.config.block_attention_drop_path,
+            pre_norm_layer=self.config.block_attention_pre_norm_layer,
+            post_norm_layer=self.config.block_attention_post_norm_layer,
+            mlp=self.config.block_attention_mlp,
+            use_normalised_qk=self.config.block_attention_use_normalised_qk,
)

self.grid_attention = GridAttention(
in_channels=mb_conv_out_channels,
-            num_heads=grid_attention_num_heads,
-            attn_grid_window_size=grid_attention_attn_grid_window_size,
-            attn_drop=grid_attention_attn_drop,
-            proj_drop=grid_attention_proj_drop,
-            drop_path=grid_attention_drop_path,
-            pre_norm_layer=grid_attention_pre_norm_layer,
-            post_norm_layer=grid_attention_post_norm_layer,
-            mlp=grid_attention_mlp,
-            use_normalised_qk=grid_attention_use_normalised_qk,
+            num_heads=self.config.grid_attention_num_heads,
+            attn_grid_window_size=self.config.grid_attention_attn_grid_window_size,
+            attn_drop=self.config.grid_attention_attn_drop,
+            proj_drop=self.config.grid_attention_proj_drop,
+            drop_path=self.config.grid_attention_drop_path,
+            pre_norm_layer=self.config.grid_attention_pre_norm_layer,
+            post_norm_layer=self.config.grid_attention_post_norm_layer,
+            mlp=self.config.grid_attention_mlp,
+            use_normalised_qk=self.config.grid_attention_use_normalised_qk,
)

def forward(self, X: torch.Tensor) -> torch.Tensor:
@@ -90,57 +157,59 @@ def forward(self, X: torch.Tensor) -> torch.Tensor:
return output


-def _met_net_maxvit_config():
-    return {
-        "mb_conv_expansion_rate": 4,
-        "mb_conv_downscale": False,
-        "mb_conv_act_layer": nn.GELU,
-        "mb_conv_drop_path": 0.0,
-        "mb_conv_kernel_size": 3,
-        "mb_conv_se_bottleneck_ratio": 0.25,
-        "block_attention_num_heads": 32,
-        "block_attention_attn_grid_window_size": (8, 8),
-        "block_attention_attn_drop": 0,
-        "block_attention_proj_drop": 0,
-        "block_attention_drop_path": 0,
-        "block_attention_pre_norm_layer": nn.LayerNorm,
-        "block_attention_post_norm_layer": nn.LayerNorm,
-        "block_attention_mlp": None,
-        "block_attention_use_normalised_qk": True,
-        "grid_attention_num_heads": 32,
-        "grid_attention_attn_grid_window_size": (8, 8),
-        "grid_attention_attn_drop": 0,
-        "grid_attention_proj_drop": 0,
-        "grid_attention_drop_path": 0,
-        "grid_attention_pre_norm_layer": nn.LayerNorm,
-        "grid_attention_post_norm_layer": nn.LayerNorm,
-        "grid_attention_mlp": None,
-        "grid_attention_use_normalised_qk": True,
-    }


class MetNetMaxVit(nn.Module):
def __init__(
self,
in_channels: int = 512,
out_channels: int = 512,
num_blocks: int = 12,
-        use_metnet_paper_conf: bool = True,
+        maxvit_conf: Union[MaxViTDataClass, List[MaxViTDataClass]] = MaxViTDataClass(),
+        set_linear_stocastic_depth: bool = True,
    ) -> None:
+        """
+        MetNet3 MaxViT block
+        Parameters
+        ----------
+        in_channels : int, optional
+            Input channels, by default 512
+        out_channels : int, optional
+            Output channels, by default 512
+        num_blocks : int, optional
+            Number of MaxViT blocks, by default 12
+        maxvit_conf : Union[MaxViTDataClass, List[MaxViTDataClass]], optional
+            MaxViT config, by default MaxViTDataClass()
+        set_linear_stocastic_depth : bool, optional
+            Flag to set the stochastic depth linearly in each MaxViT sub-block, by default True
+        """
        super().__init__()
        self.in_channels = in_channels
        self.num_blocks = num_blocks
-        if not use_metnet_paper_conf:
-            # TODO: Make this configurable. Should I make a data class for these?
-            raise NotImplementedError(
-                "Currently only Metnet3 paper specified configs are supported"
-            )
-        maxvit_conf = _met_net_maxvit_config()
+        self.set_linear_stocastic_depth = set_linear_stocastic_depth

        self.maxvit_blocks = nn.ModuleList()
-        for _ in range(self.num_blocks):
-            self.maxvit_blocks.append(MaxViTBlock(in_channels=self.in_channels, **maxvit_conf))
-
-        self.linear_transform = nn.Linear(in_features=out_channels, out_features=in_channels)
+        if isinstance(maxvit_conf, List):
+            assert len(maxvit_conf) == num_blocks
+            self.maxvit_conf_list = maxvit_conf
+        else:
+            self.maxvit_conf_list = [maxvit_conf for _ in range(self.num_blocks)]
+
+        if self.set_linear_stocastic_depth:
+            # Linearly sets the stochastic depth of a given sub-module
+            # (i.e. MBConv, local (block) attention or gridded (grid) attention)
+            # from 0 to 0.2, as mentioned in the MetNet-3 paper
+            for conf in self.maxvit_conf_list:
+                conf.mb_conv_drop_path = 0
+                conf.block_attention_drop_path = 0.1
+                conf.grid_attention_drop_path = 0.2
+
+        for conf in self.maxvit_conf_list:
+            self.maxvit_blocks.append(MaxViTBlock(in_channels=self.in_channels, maxvit_config=conf))
+
+        self.linear_transform = nn.Linear(
+            in_features=out_channels, out_features=in_channels
+        )  # TODO: Test the shapes

def forward(self, X: torch.Tensor) -> torch.Tensor:
"""
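For orientation, here is a minimal usage sketch of the configuration API introduced in this commit. It is not part of the diff; only construction is shown (the tensor shapes handled by forward are not visible here), and the channel and head values are illustrative defaults taken from the dataclass:

    from metnet.layers.MaxViT import MaxViTBlock, MaxViTDataClass, MetNetMaxVit

    # MetNet-3 style settings; any field of the dataclass can be overridden here.
    config = MaxViTDataClass(mb_conv_expansion_rate=4, block_attention_num_heads=32)

    # A single MaxViT block, now driven entirely by the dataclass.
    block = MaxViTBlock(in_channels=512, maxvit_config=config)

    # The full stack; note that a single config instance is shared (not copied) across all blocks.
    model = MetNetMaxVit(in_channels=512, out_channels=512, num_blocks=12, maxvit_conf=config)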
2 changes: 1 addition & 1 deletion metnet/layers/PartitionAttention.py
@@ -176,7 +176,7 @@ def __init__(
num_heads : int, optional
Number of attention heads, by default 32
attn_grid_window_size : Tuple[int, int], optional
-            Grid/Window size to be utilized, by default (8, 8)
+            Grid/Window size for attention, by default (8, 8)
attn_drop : float, optional
Dropout ratio of attention weight, by default 0.0
proj_drop : float, optional
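As a further sketch (again an assumption, not code from the commit), the list form of maxvit_conf accepts one config per block, and the default set_linear_stocastic_depth=True then overwrites each config's drop-path rates with the 0 / 0.1 / 0.2 schedule noted in the diff above:

    from metnet.layers.MaxViT import MaxViTDataClass, MetNetMaxVit

    num_blocks = 12
    # Independent config objects, so the in-place drop_path updates do not alias across blocks.
    configs = [MaxViTDataClass() for _ in range(num_blocks)]

    model = MetNetMaxVit(in_channels=512, out_channels=512, num_blocks=num_blocks, maxvit_conf=configs)

    # After construction each config carries the linear stochastic-depth schedule:
    # 0.0 for MBConv, 0.1 for block attention, 0.2 for grid attention.
    assert configs[0].mb_conv_drop_path == 0
    assert configs[0].block_attention_drop_path == 0.1
    assert configs[0].grid_attention_drop_path == 0.2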
