diff --git a/metnet/layers/MBConv.py b/metnet/layers/MBConv.py index c98fcad..86e99a4 100644 --- a/metnet/layers/MBConv.py +++ b/metnet/layers/MBConv.py @@ -16,10 +16,12 @@ def __init__( kernel_size=3, se_bottleneck_ratio=0.25, ): + # TODO: Verify implementation super().__init__() self.in_channels = in_channels self.drop_path_rate = drop_path self.expansion_rate = expansion_rate + self.downscale = downscale expanded_channels = self.in_channels * self.expansion_rate self.conv_se_branch = nn.Sequential( @@ -34,11 +36,13 @@ def __init__( expanded_channels, expanded_channels, kernel_size, - stride=1, + stride=2 if self.downscale else 1, groups=expanded_channels, ), nn.LayerNorm(expanded_channels), # Norm2 - SqueezeExcite(in_chs=expanded_channels, rd_ratio=se_bottleneck_ratio), + SqueezeExcite( + in_chs=expanded_channels, act_layer=act_layer, rd_ratio=se_bottleneck_ratio + ), nn.Conv2d( # Conv 1x1 in_channels=expanded_channels, out_channels=expanded_channels, @@ -50,6 +54,13 @@ def __init__( self.stochastic_depth = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() self.skip_path = nn.Identity() + if self.downscale: + self.skip_path = nn.Sequential( + nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)), + nn.Conv2d( + in_channels=in_channels, out_channels=expanded_channels, kernel_size=(1, 1) + ), + ) def forward(self, X): conv_se_output = self.conv_se_branch(X)