diff --git a/metnet/layers/SqueezeExcitation.py b/metnet/layers/SqueezeExcitation.py index 961ff67..f6b2b27 100644 --- a/metnet/layers/SqueezeExcitation.py +++ b/metnet/layers/SqueezeExcitation.py @@ -56,7 +56,7 @@ def forward(self, X: torch.Tensor) -> torch.Tensor: torch.Tensor Output Tensor """ - x_se = X.mean((2, 3), keepdim=True) + x_se = X.mean((2, 3), keepdim=True) # Mean along H, W dim x_se = self.conv_reduce(x_se) x_se = self.act1(x_se) x_se = self.conv_expand(x_se) diff --git a/metnet/layers/StochasticDepth.py b/metnet/layers/StochasticDepth.py index 5dfa3f0..30c2f98 100644 --- a/metnet/layers/StochasticDepth.py +++ b/metnet/layers/StochasticDepth.py @@ -22,7 +22,7 @@ def __init__(self, drop_prob: float = 0.0) -> None: probability to drop the network path, by default 0.0 """ super().__init__() - assert 0 < drop_prob < 1.0 + assert 0 <= drop_prob <= 1.0 self.drop_prob = drop_prob def forward(self, X: torch.Tensor) -> torch.Tensor: