diff --git a/metnet/layers/PartitionAttention.py b/metnet/layers/PartitionAttention.py
index 5999364..1335ab2 100644
--- a/metnet/layers/PartitionAttention.py
+++ b/metnet/layers/PartitionAttention.py
@@ -16,6 +16,9 @@ def __init__(
         attn_drop: float = 0.0,
         proj_drop: float = 0.0,
         drop_path: float = 0.0,
+        pre_norm_layer: Type[nn.Module] = nn.LayerNorm,
+        post_norm_layer: Type[nn.Module] = nn.LayerNorm,
+        mlp: Type[nn.Module] = None,
         use_normalised_qk: bool = True,
     ) -> None:
         """
@@ -62,6 +65,14 @@ def __init__(
             proj_drop=proj_drop,
             use_normalised_qk=use_normalised_qk,
         )
+        self.pre_norm_layer = pre_norm_layer(in_channels)
+        self.post_norm_layer = post_norm_layer(in_channels)
+
+        if mlp:
+            # TODO: allow for an mlp to be passed here
+            raise NotImplementedError("Metnet 3 does not use MLPs in MaxVit.")
+        else:
+            self.mlp = nn.Identity()
         self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

     def partition_function(self, X: torch.Tensor):
@@ -125,7 +136,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         input_partitioned = input_partitioned.view(
             -1, self.attn_grid_window_size[0] * self.attn_grid_window_size[1], C
         )
-        output = input_partitioned + self.drop_path(self.attention(self.norm_1(input_partitioned)))
+        # Perform normalization, attention, and dropout
+        output = input_partitioned + self.drop_path(
+            self.attention(self.pre_norm_layer(input_partitioned))
+        )
+
+        # Perform normalization, MLP, and dropout
+        output = output + self.drop_path(self.mlp(self.post_norm_layer(output)))
+
         # Reverse partition
         output = self.reverse_function(output, (H, W))
         return output
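
For reference, below is a minimal, self-contained sketch of the residual pattern this diff introduces: pre-norm, attention, drop-path residual, followed by post-norm, MLP, drop-path residual, with the MLP reduced to `nn.Identity()` as in MetNet-3. The `TinyBlock` name, the `nn.MultiheadAttention` stand-in for the repo's relative self-attention, and the `nn.Dropout` stand-in for `DropPath` are illustrative assumptions, not part of the metnet codebase.

```python
import torch
import torch.nn as nn


class TinyBlock(nn.Module):
    """Toy block mirroring the pre-norm / post-norm residual layout added above."""

    def __init__(self, channels: int, drop_path: float = 0.0) -> None:
        super().__init__()
        self.pre_norm_layer = nn.LayerNorm(channels)
        self.post_norm_layer = nn.LayerNorm(channels)
        # Stand-in for the repo's attention module: any (B, N, C) -> (B, N, C) module works.
        self.attention = nn.MultiheadAttention(channels, num_heads=1, batch_first=True)
        # MetNet-3 configuration: no MLP, so the second branch is a plain identity.
        self.mlp = nn.Identity()
        # nn.Dropout used here as a simple stand-in for a DropPath (stochastic depth) layer.
        self.drop_path = nn.Dropout(drop_path) if drop_path > 0.0 else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalization, attention, and drop-path residual.
        normed = self.pre_norm_layer(x)
        attn_out, _ = self.attention(normed, normed, normed)
        x = x + self.drop_path(attn_out)
        # Normalization, MLP (identity here), and drop-path residual.
        x = x + self.drop_path(self.mlp(self.post_norm_layer(x)))
        return x


if __name__ == "__main__":
    block = TinyBlock(channels=32)
    tokens = torch.randn(4, 64, 32)  # (batch, window tokens, channels)
    print(block(tokens).shape)  # torch.Size([4, 64, 32])
```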