testing MB Conv
Raahul-Singh committed Oct 29, 2023
1 parent 2dce68d commit 3fd9b91
Showing 2 changed files with 61 additions and 18 deletions.
49 changes: 31 additions & 18 deletions metnet/layers/MBConv.py
@@ -21,6 +21,7 @@ def __init__(
         expansion_rate: int = 4,
         downscale: bool = False,
         act_layer: Type[nn.Module] = nn.GELU,
+        norm_layer: Type[nn.Module] = nn.BatchNorm2d,
         drop_path: float = 0.0,
         kernel_size: int = 3,
         se_bottleneck_ratio: float = 0.25,
@@ -36,8 +37,14 @@ def __init__(
             Expansion rate for the output channels, by default 4
         downscale : bool, optional
             Flag to denote downscaling in the conv branch, by default False
+            Currently not implemented, as it is not specified in MetNet-3
         act_layer : Type[nn.Module], optional
             activation layer, by default nn.GELU
+        norm_layer : Type[nn.Module], optional
+            normalisation layer, by default nn.BatchNorm2d
+            TODO: Verify if LayerNorm is to be used inside MBConv
+            NOTE: Most implementations use nn.BatchNorm2d. If LayerNorm is to be
+            used, the intermediate h and w would need to be computed.
         drop_path : float, optional
             Stochastic Depth ratio, by default 0.0
         kernel_size : int, optional
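An aside on the TODO above, not part of this commit: nn.LayerNorm must be given its full normalized shape up front, which is why using it here would require knowing the intermediate h and w, while nn.GroupNorm with a single group normalizes over the same channel and spatial dimensions without knowing them in advance (differing only in its per-channel rather than per-element affine parameters). A minimal sketch of that trade-off:

import torch
import torch.nn as nn

n, c, h, w = 1, 8, 16, 16
x = torch.rand(n, c, h, w)

# LayerNorm over (C, H, W) needs the spatial size at construction time
layer_norm = nn.LayerNorm([c, h, w])
# GroupNorm with one group normalizes the same dimensions, shape-agnostically
group_norm = nn.GroupNorm(num_groups=1, num_channels=c)

assert layer_norm(x).shape == group_norm(x).shape == x.shape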
@@ -51,58 +58,64 @@ def __init__(
         self.drop_path_rate = drop_path
         self.expansion_rate = expansion_rate
         self.downscale = downscale

         out_channels = self.in_channels * self.expansion_rate
         conv_dw_padding = ((kernel_size - 1) // 2, (kernel_size - 1) // 2)

+        self.pre_norm = norm_layer(in_channels)
+        self.norm_layer_1 = norm_layer(out_channels)
+        self.norm_layer_2 = norm_layer(out_channels)
+
+        if self.downscale:
+            # TODO: Check if downscaling is needed at all. May impact layer normalisation.
+            raise NotImplementedError(
+                "Downscaling in MBConv hasn't been implemented as it \
+                isn't used in MetNet-3"
+            )

-        self.conv_se_branch = nn.Sequential(
-            nn.LayerNorm(in_channels),  # Pre Norm
+        self.main_branch = nn.Sequential(
+            self.pre_norm,  # Pre Normalize over the last three dimensions (i.e. the channel and spatial dimensions)  # noqa
             nn.Conv2d(  # Conv 1x1
                 in_channels=self.in_channels,
                 out_channels=out_channels,
                 kernel_size=1,
             ),
-            nn.LayerNorm(out_channels),  # Norm1
+            self.norm_layer_1,  # Norm1
             nn.Conv2d(  # Depth wise Conv kxk
                 out_channels,
                 out_channels,
                 kernel_size,
                 stride=2 if self.downscale else 1,
                 padding=conv_dw_padding,  # To maintain shapes
                 groups=out_channels,
             ),
-            nn.LayerNorm(out_channels),  # Norm2
+            self.norm_layer_2,  # Norm2
             SqueezeExcite(
-                in_channels=out_channels, act_layer=act_layer, rd_ratio=se_bottleneck_ratio
+                in_channels=out_channels,
+                act_layer=act_layer,
+                rd_ratio=se_bottleneck_ratio,
             ),
             nn.Conv2d(  # Conv 1x1
                 in_channels=out_channels,
-                out_channels=out_channels,
+                out_channels=in_channels,
                 kernel_size=1,
             ),
+            # No Norm as this is the last convolution layer in this block
         )

         self.stochastic_depth = StochasticDepth(drop_path) if drop_path > 0.0 else nn.Identity()

-        self.skip_path = nn.Identity()
-        if self.downscale:
-            self.skip_path = nn.Sequential(
-                nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)),
-                nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(1, 1)),
-            )

     def forward(self, X: torch.Tensor) -> torch.Tensor:
         """
         Forward step
         Parameters
         ----------
         X : torch.Tensor
-            Input Tensor
+            Input Tensor of shape [N, C, H, W]
         Returns:
         -------
         torch.Tensor
             MBConv output
         """
-        conv_se_output = self.conv_se_branch(X)
-        conv_se_output = self.stochastic_depth(conv_se_output)
-        return conv_se_output + self.skip_path(X)
+        return X + self.stochastic_depth(self.main_branch(X))
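For orientation, a minimal usage sketch of the block as it stands after this change (the input shape follows the forward docstring; the residual add is why the final 1x1 conv maps back to in_channels):

import torch
from metnet.layers.MBConv import MBConv

x = torch.rand(2, 16, 32, 32)  # [N, C, H, W]
block = MBConv(in_channels=16)  # expansion_rate=4, nn.BatchNorm2d by default
out = block(x)  # X + stochastic_depth(main_branch(X))
assert out.shape == x.shape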
30 changes: 30 additions & 0 deletions tests/test_layers.py
@@ -0,0 +1,30 @@
+from metnet.layers.StochasticDepth import StochasticDepth
+from metnet.layers.SqueezeExcitation import SqueezeExcite
+from metnet.layers.MBConv import MBConv
+import torch
+
+
+def test_stochastic_depth():
+    test_tensor = torch.ones(1)
+
+    stochastic_depth = StochasticDepth(drop_prob=0)
+    assert test_tensor == stochastic_depth(test_tensor)
+
+    stochastic_depth = StochasticDepth(drop_prob=1)
+    assert torch.zeros_like(test_tensor) == stochastic_depth(test_tensor)
+
+
+def test_squeeze_excitation():
+    n, c, h, w = 1, 3, 16, 16
+    test_tensor = torch.rand(n, c, h, w)
+
+    squeeze_excite = SqueezeExcite(in_channels=c)
+    assert test_tensor.shape == squeeze_excite(test_tensor).shape
+
+
+def test_mbconv():
+    n, c, h, w = 1, 3, 16, 16
+    test_tensor = torch.rand(n, c, h, w)
+    mb_conv = MBConv(c)
+
+    assert test_tensor.shape == mb_conv(test_tensor).shape
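Two paths the new tests leave unexercised are the drop_path branch and the norm_layer override; a hedged sketch of what such a test could look like, assuming the constructor signature shown in the diff above (the test name and the GroupNorm stand-in are illustrative, not part of this commit):

import torch
import torch.nn as nn
from metnet.layers.MBConv import MBConv


def test_mbconv_drop_path_and_norm_override():
    n, c, h, w = 1, 3, 16, 16
    test_tensor = torch.rand(n, c, h, w)

    # drop_path > 0 swaps nn.Identity for StochasticDepth; shape is unchanged
    mb_conv = MBConv(c, drop_path=0.5)
    assert test_tensor.shape == mb_conv(test_tensor).shape

    # any callable returning a norm module can stand in for nn.BatchNorm2d
    mb_conv = MBConv(c, norm_layer=lambda channels: nn.GroupNorm(1, channels))
    assert test_tensor.shape == mb_conv(test_tensor).shape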