Skip to content

Commit

Permalink
adding MBConv
Browse files Browse the repository at this point in the history
  • Loading branch information
Raahul-Singh committed Oct 22, 2023
1 parent 41de6f6 commit c8fe101
Showing 1 changed file with 57 additions and 0 deletions.
57 changes: 57 additions & 0 deletions metnet/layers/MBConv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from torch import nn
from timm.layers import DropPath
from timm.models.efficientnet_blocks import SqueezeExcite


class MBConv(nn.Module):
def __init__(
self,
in_channels: int,
expansion_rate: int = 4,
downscale: bool = False, # TODO
act_layer: Type[nn.Module] = nn.GELU,
drop_path: float = 0.0,
kernel_size=3,
se_bottleneck_ratio=0.25,
):
super().__init__()
self.in_channels = in_channels
self.drop_path_rate = drop_path
self.expansion_rate = expansion_rate
expanded_channels = self.in_channels * self.expansion_rate

self.conv_se_branch = nn.Sequential(
nn.LayerNorm(in_channels), # Pre Norm
nn.Conv2d( # Conv 1x1
in_channels=self.in_channels,
out_channels=expanded_channels,
kernel_size=1,
),
nn.LayerNorm(expanded_channels), # Norm1
nn.Conv2d( # Depth wise Conv kxk
expanded_channels,
expanded_channels,
kernel_size,
stride=1,
groups=expanded_channels,
),
nn.LayerNorm(expanded_channels), # Norm2
SqueezeExcite(in_chs=expanded_channels, rd_ratio=se_bottleneck_ratio),
nn.Conv2d( # Conv 1x1
in_channels=expanded_channels,
out_channels=expanded_channels,
kernel_size=1,
),
# No Norm as this is the last convolution layer in this block
)

self.stochastic_depth = (
DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
)

self.skip_path = nn.Identity()

def forward(self, X):
conv_se_output = self.conv_se_branch(X)
conv_se_output = self.stochastic_depth(conv_se_output)
output = conv_se_output + self.skip_path(X)

0 comments on commit c8fe101

Please sign in to comment.