Skip to content

Commit

Permalink
mbconv fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Raahul-Singh committed Oct 25, 2023
1 parent 82a9b08 commit 4627b62
Showing 1 changed file with 13 additions and 2 deletions.
15 changes: 13 additions & 2 deletions metnet/layers/MBConv.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@ def __init__(
kernel_size=3,
se_bottleneck_ratio=0.25,
):
# TODO: Verify implemtnetation
super().__init__()
self.in_channels = in_channels
self.drop_path_rate = drop_path
self.expansion_rate = expansion_rate
self.downscale = downscale
expanded_channels = self.in_channels * self.expansion_rate

self.conv_se_branch = nn.Sequential(
Expand All @@ -34,11 +36,13 @@ def __init__(
expanded_channels,
expanded_channels,
kernel_size,
stride=1,
stride=2 if self.downscale else 1,
groups=expanded_channels,
),
nn.LayerNorm(expanded_channels), # Norm2
SqueezeExcite(in_chs=expanded_channels, rd_ratio=se_bottleneck_ratio),
SqueezeExcite(
in_chs=expanded_channels, act_layer=act_layer, rd_ratio=se_bottleneck_ratio
),
nn.Conv2d( # Conv 1x1
in_channels=expanded_channels,
out_channels=expanded_channels,
Expand All @@ -50,6 +54,13 @@ def __init__(
self.stochastic_depth = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

self.skip_path = nn.Identity()
if self.downscale:
self.skip_path = nn.Sequential(
nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)),
nn.Conv2d(
in_channels=in_channels, out_channels=expanded_channels, kernel_size=(1, 1)
),
)

def forward(self, X):
conv_se_output = self.conv_se_branch(X)
Expand Down

0 comments on commit 4627b62

Please sign in to comment.