diff --git a/metnet/layers/MBConv.py b/metnet/layers/MBConv.py
index 3407ea7..97a801e 100644
--- a/metnet/layers/MBConv.py
+++ b/metnet/layers/MBConv.py
@@ -21,6 +21,7 @@ def __init__(
         expansion_rate: int = 4,
         downscale: bool = False,
         act_layer: Type[nn.Module] = nn.GELU,
+        norm_layer: Type[nn.Module] = nn.BatchNorm2d,
         drop_path: float = 0.0,
         kernel_size: int = 3,
         se_bottleneck_ratio: float = 0.25,
@@ -36,8 +37,14 @@ def __init__(
             Expansion rate for the output channels, by default 4
         downscale : bool, optional
             Flag to denote downscaling in the conv branch, by default False
+            Currently not implemented, as not specified in Metnet 3
         act_layer : Type[nn.Module], optional
             activation layer, by default nn.GELU
+        norm_layer : Type[nn.Module], optional
+            normalisation layer, by default nn.BatchNorm2d
+            TODO: Verify if Layer Norm is to be used inside MBConv
+            NOTE: Most implementations use nn.BatchNorm2d. If LayerNorm is to be
+            used, the intermediate h and w would need to be computed.
         drop_path : float, optional
             Stochastic Depth ratio, by default 0.0
         kernel_size : int, optional
@@ -51,30 +58,45 @@ def __init__(
         self.drop_path_rate = drop_path
         self.expansion_rate = expansion_rate
         self.downscale = downscale
+        out_channels = self.in_channels * self.expansion_rate
+        conv_dw_padding = ((kernel_size - 1) // 2, (kernel_size - 1) // 2)
+
+        self.pre_norm = norm_layer(in_channels)
+        self.norm_layer_1 = norm_layer(out_channels)
+        self.norm_layer_2 = norm_layer(out_channels)
+
+        if self.downscale:
+            # TODO: Check if downscaling is needed at all. May impact layer normalisation.
+            raise NotImplementedError(
+                "Downscaling in MBConv hasn't been implemented as it "
+                "isnt used in Metnet3"
+            )
 
 
-        self.conv_se_branch = nn.Sequential(
-            nn.LayerNorm(in_channels),  # Pre Norm
+        self.main_branch = nn.Sequential(
+            self.pre_norm,  # Pre Norm
             nn.Conv2d(  # Conv 1x1
                 in_channels=self.in_channels,
                 out_channels=out_channels,
                 kernel_size=1,
             ),
-            nn.LayerNorm(out_channels),  # Norm1
+            self.norm_layer_1,  # Norm1
             nn.Conv2d(  # Depth wise Conv kxk
                 out_channels,
                 out_channels,
                 kernel_size,
-                stride=2 if self.downscale else 1,
+                padding=conv_dw_padding,  # To maintain shapes
                 groups=out_channels,
             ),
-            nn.LayerNorm(out_channels),  # Norm2
+            self.norm_layer_2,  # Norm2
             SqueezeExcite(
-                in_channels=out_channels, act_layer=act_layer, rd_ratio=se_bottleneck_ratio
+                in_channels=out_channels,
+                act_layer=act_layer,
+                rd_ratio=se_bottleneck_ratio,
             ),
             nn.Conv2d(  # Conv 1x1
                 in_channels=out_channels,
-                out_channels=out_channels,
+                out_channels=in_channels,
                 kernel_size=1,
             ),
             # No Norm as this is the last convolution layer in this block
@@ -82,13 +104,6 @@ def __init__(
 
         self.stochastic_depth = StochasticDepth(drop_path) if drop_path > 0.0 else nn.Identity()
 
-        self.skip_path = nn.Identity()
-        if self.downscale:
-            self.skip_path = nn.Sequential(
-                nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)),
-                nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(1, 1)),
-            )
-
     def forward(self, X: torch.Tensor) -> torch.Tensor:
         """
         Forward step
@@ -96,13 +111,11 @@ def forward(self, X: torch.Tensor) -> torch.Tensor:
         Parameters
         ----------
         X : torch.Tensor
-            Input Tensor
+            Input Tensor of shape [N, C, H, W]
 
         Returns:
         -------
         torch.Tensor
             MBConv output
         """
-        conv_se_output = self.conv_se_branch(X)
-        conv_se_output = self.stochastic_depth(conv_se_output)
-        return conv_se_output + self.skip_path(X)
+        return X + self.stochastic_depth(self.main_branch(X))
diff --git a/tests/test_layers.py b/tests/test_layers.py
new file mode 100644
index 0000000..83ca449
--- /dev/null
+++ b/tests/test_layers.py
@@ -0,0 +1,30 @@
+from metnet.layers.StochasticDepth import StochasticDepth
+from metnet.layers.SqueezeExcitation import SqueezeExcite
+from metnet.layers.MBConv import MBConv
+import torch
+
+
+def test_stochastic_depth():
+    test_tensor = torch.ones(1)
+
+    stochastic_depth = StochasticDepth(drop_prob=0)
+    assert test_tensor == stochastic_depth(test_tensor)
+
+    stochastic_depth = StochasticDepth(drop_prob=1)
+    assert torch.zeros_like(test_tensor) == stochastic_depth(test_tensor)
+
+
+def test_squeeze_excitation():
+    n, c, h, w = 1, 3, 16, 16
+    test_tensor = torch.rand(n, c, h, w)
+
+    squeeze_excite = SqueezeExcite(in_channels=c)
+    assert test_tensor.shape == squeeze_excite(test_tensor).shape
+
+
+def test_mbconv():
+    n, c, h, w = 1, 3, 16, 16
+    test_tensor = torch.rand(n, c, h, w)
+    mb_conv = MBConv(c)
+
+    assert test_tensor.shape == mb_conv(test_tensor).shape