"""MBConv: mobile inverted bottleneck convolution block (MaxViT-style)."""
from typing import Type

from torch import nn
from timm.layers import DropPath
from timm.models.efficientnet_blocks import SqueezeExcite


class MBConv(nn.Module):
    """Mobile inverted bottleneck block with squeeze-excitation and stochastic depth.

    Structure (pre-norm):
        Norm -> Conv1x1 (expand) -> Norm -> Act -> DWConv kxk -> Norm -> Act
             -> SqueezeExcite -> Conv1x1 (project) -> DropPath -> + residual

    The block is channel-preserving: the projection conv maps the expanded
    channels back to ``in_channels`` so the residual addition is valid.
    """

    def __init__(
        self,
        in_channels: int,
        expansion_rate: int = 4,
        downscale: bool = False,  # TODO: spatial downscaling not implemented yet
        act_layer: Type[nn.Module] = nn.GELU,
        drop_path: float = 0.0,
        kernel_size: int = 3,
        se_bottleneck_ratio: float = 0.25,
    ):
        """Build the MBConv block.

        Args:
            in_channels: Number of input (and output) channels.
            expansion_rate: Channel expansion factor for the inverted bottleneck.
            downscale: Reserved; spatial downscaling is not implemented (TODO).
            act_layer: Activation module class used after the inner norms.
            drop_path: Stochastic-depth drop probability for the residual branch.
            kernel_size: Depthwise convolution kernel size (odd values keep
                spatial dims unchanged via symmetric padding).
            se_bottleneck_ratio: Reduction ratio for the squeeze-excitation module.
        """
        super().__init__()
        self.in_channels = in_channels
        self.drop_path_rate = drop_path
        self.expansion_rate = expansion_rate
        expanded_channels = self.in_channels * self.expansion_rate

        # NOTE: BatchNorm2d (not LayerNorm) is used because the tensors are
        # NCHW conv feature maps; nn.LayerNorm(C) would normalize the width
        # dimension and fail for W != C.
        self.conv_se_branch = nn.Sequential(
            nn.BatchNorm2d(in_channels),  # Pre-norm on the residual branch
            nn.Conv2d(  # 1x1 expansion conv
                in_channels=self.in_channels,
                out_channels=expanded_channels,
                kernel_size=1,
            ),
            nn.BatchNorm2d(expanded_channels),
            act_layer(),
            nn.Conv2d(  # Depthwise kxk conv; padding keeps H, W unchanged
                expanded_channels,
                expanded_channels,
                kernel_size,
                stride=1,
                padding=kernel_size // 2,
                groups=expanded_channels,
            ),
            nn.BatchNorm2d(expanded_channels),
            act_layer(),
            SqueezeExcite(in_chs=expanded_channels, rd_ratio=se_bottleneck_ratio),
            nn.Conv2d(  # 1x1 projection back to in_channels for the residual add
                in_channels=expanded_channels,
                out_channels=self.in_channels,
                kernel_size=1,
            ),
            # No norm after the last conv in this block (handled by the next block's pre-norm)
        )

        # Stochastic depth: randomly drops the conv branch during training
        self.stochastic_depth = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.skip_path = nn.Identity()

    def forward(self, X):
        """Apply the MBConv branch with a residual connection.

        Args:
            X: Input tensor of shape (N, in_channels, H, W).

        Returns:
            Tensor of shape (N, in_channels, H, W).
        """
        conv_se_output = self.conv_se_branch(X)
        conv_se_output = self.stochastic_depth(conv_se_output)
        return conv_se_output + self.skip_path(X)