
Commit

fixing mlps and norms in attention
Raahul-Singh committed Oct 25, 2023
1 parent 4627b62 commit 24aab24
Showing 1 changed file with 19 additions and 1 deletion.
20 changes: 19 additions & 1 deletion metnet/layers/PartitionAttention.py
@@ -16,6 +16,9 @@ def __init__(
         attn_drop: float = 0.0,
         proj_drop: float = 0.0,
         drop_path: float = 0.0,
+        pre_norm_layer: Type[nn.Module] = nn.LayerNorm,
+        post_norm_layer: Type[nn.Module] = nn.LayerNorm,
+        mlp: Type[nn.Module] = None,
         use_normalised_qk: bool = True,
     ) -> None:
         """
@@ -62,6 +65,14 @@ def __init__(
             proj_drop=proj_drop,
             use_normalised_qk=use_normalised_qk,
         )
+        self.pre_norm_layer = pre_norm_layer(in_channels)
+        self.post_norm_layer = post_norm_layer(in_channels)
+
+        if mlp:
+            # TODO: allow for an mlp to be passed here
+            raise NotImplementedError("MetNet 3 does not use MLPs in MaxVit.")
+        else:
+            self.mlp = nn.Identity()
         self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
 
     def partition_function(self, X: torch.Tensor):
@@ -125,7 +136,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         input_partitioned = input_partitioned.view(
             -1, self.attn_grid_window_size[0] * self.attn_grid_window_size[1], C
         )
-        output = input_partitioned + self.drop_path(self.attention(self.norm_1(input_partitioned)))
+        # Perform normalization, attention, and drop path
+        output = input_partitioned + self.drop_path(
+            self.attention(self.pre_norm_layer(input_partitioned))
+        )
+
+        # Perform normalization, MLP, and drop path
+        output = output + self.drop_path(self.mlp(self.post_norm_layer(output)))
+
         # Reverse partition
         output = self.reverse_function(output, (H, W))
         return output
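For context on what the new forward pass computes, below is a minimal, self-contained sketch of the pre-norm residual pattern this commit wires into PartitionAttention: norm, attention, drop path, residual add, then norm, MLP (an nn.Identity here, matching the diff), drop path, residual add. The SimpleSelfAttention and PreNormResidualBlock names are illustrative stand-ins, not classes from this repository, and DropPath is replaced with nn.Identity to keep the sketch dependency-free.

```python
# Hedged sketch of the pre-norm residual block structure from the diff above.
# Class names here are hypothetical stand-ins, not part of the metnet package.
import torch
from torch import nn


class SimpleSelfAttention(nn.Module):
    """Stand-in for the partitioned attention module used in the real layer."""

    def __init__(self, dim: int, num_heads: int = 4) -> None:
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out, _ = self.attn(x, x, x)
        return out


class PreNormResidualBlock(nn.Module):
    """Pre-norm attention + identity MLP residuals, mirroring the new forward()."""

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.pre_norm_layer = nn.LayerNorm(dim)
        self.post_norm_layer = nn.LayerNorm(dim)
        self.attention = SimpleSelfAttention(dim)
        self.mlp = nn.Identity()        # MetNet 3 does not use MLPs in MaxVit blocks
        self.drop_path = nn.Identity()  # DropPath (stochastic depth) omitted for brevity

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalization, attention, drop path, residual add
        x = x + self.drop_path(self.attention(self.pre_norm_layer(x)))
        # Normalization, MLP (identity here), drop path, residual add
        x = x + self.drop_path(self.mlp(self.post_norm_layer(x)))
        return x


if __name__ == "__main__":
    block = PreNormResidualBlock(dim=32)
    tokens = torch.randn(2, 16, 32)  # (batch * partitions, window tokens, channels)
    print(block(tokens).shape)       # torch.Size([2, 16, 32])
```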
