From 63d9596ae24417ae5fbf37db30a4c3abb0ac0ba7 Mon Sep 17 00:00:00 2001
From: Raahul Singh
Date: Tue, 24 Oct 2023 22:58:25 +0530
Subject: [PATCH] temp

---
 metnet/layers/PartitionAttention.py | 140 ++++++++++++++++++----------
 1 file changed, 88 insertions(+), 52 deletions(-)

diff --git a/metnet/layers/PartitionAttention.py b/metnet/layers/PartitionAttention.py
index f0ecd2e..cc86ea2 100644
--- a/metnet/layers/PartitionAttention.py
+++ b/metnet/layers/PartitionAttention.py
@@ -8,60 +8,78 @@


 class PartitionAttention(nn.Module):
-    """PartitionAttention block.
-
-    With block partition:
-    x ← x + Unblock(RelAttention(Block(LN(x))))
-    x ← x + MLP(LN(x))
-
-    With grid partition:
-    x ← x + Ungrid(RelAttention(Grid(LN(x))))
-    x ← x + MLP(LN(x))
-
-    Layer Normalization (LN) is applied after the grid/window partition to prevent multiple reshaping operations.
-    Grid/window reverse (Unblock/Ungrid) is performed on the final output for the same reason.
-
-    Args:
-        in_channels (int): Number of input channels.
-        partition_function (Callable): Partition function to be utilized (grid or window partition).
-        reverse_function (Callable): Reverse function to be utilized (grid or window reverse).
-        num_heads (int, optional): Number of attention heads. Default 32
-        partition_window_size (Tuple[int, int], optional): Grid/Window size to be utilized. Default (7, 7)
-        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
-        drop (float, optional): Dropout ratio of output. Default: 0.0
-        drop_path (float, optional): Dropout ratio of path. Default: 0.0
-        mlp_ratio (float, optional): Ratio of mlp hidden dim to embedding dim. Default: 4.0
-        act_layer (Type[nn.Module], optional): Type of activation layer to be utilized. Default: nn.GELU
-        norm_layer (Type[nn.Module], optional): Type of normalization layer to be utilized. Default: nn.BatchNorm2d
-    """

     def __init__(
         self,
         in_channels: int,
         num_heads: int = 32,
-        partition_window_size: Tuple[int, int] = (7, 7),
+        attn_grid_window_size: Tuple[int, int] = (8, 8),
         attn_drop: float = 0.0,
-        drop: float = 0.0,
+        proj_drop: float = 0.0,
         drop_path: float = 0.0,
-        norm_layer: Type[nn.Module] = nn.LayerNorm,
+        use_normalised_qk: bool = True
     ) -> None:
-        """Constructor method"""
+        """
+        PartitionAttention block
+        Implements the common functionality for block and grid attention.
+
+        With block partition:
+        x ← x + Unblock(RelAttention(Block(x)))
+
+        With grid partition:
+        x ← x + Ungrid(RelAttention(Grid(x)))
+
+        Parameters
+        ----------
+        in_channels : int
+            Number of input channels.
+        num_heads : int, optional
+            Number of attention heads, by default 32
+        attn_grid_window_size : Tuple[int, int], optional
+            Grid/Window size to be utilized, by default (8, 8)
+        attn_drop : float, optional
+            Dropout ratio of attention weight, by default 0.0
+        proj_drop : float, optional
+            Dropout ratio of output, by default 0.0
+        drop_path : float, optional
+            Stochastic depth, by default 0.0
+        use_normalised_qk : bool, optional
+            Normalise queries and keys as done in Metnet 3, by default True.
+
+        Notes
+        -----
+        Specific to Metnet 3 implementation
+        TODO: Add the MLP as an optional parameter.
+        """
         super().__init__()
         # Save parameters
-        self.partition_window_size: Tuple[int, int] = partition_window_size
+        self.attn_grid_window_size: Tuple[int, int] = attn_grid_window_size
         # Init layers
-        self.norm_1 = norm_layer(in_channels)
         self.attention = RelativeSelfAttention(
             in_channels=in_channels,
             num_heads=num_heads,
-            partition_window_size=partition_window_size,
+            attn_grid_window_size=attn_grid_window_size,
             attn_drop=attn_drop,
-            drop=drop,
+            proj_drop=proj_drop,
+            use_normalised_qk=use_normalised_qk
         )
         self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
-        self.norm_2 = norm_layer(in_channels)

-    def partition_function(self, input: torch.Tensor):
+    def partition_function(self, X: torch.Tensor):
+        """
+        Partition function.
+        To be overridden by block or grid partition
+
+        Parameters
+        ----------
+        X : torch.Tensor
+            Input tensor
+
+        Raises
+        ------
+        NotImplementedError
+            Should not be called without implementation in the child class
+        """
         raise NotImplementedError

     def reverse_function(
@@ -69,28 +87,46 @@ def reverse_function(
         partitioned_input: torch.Tensor,
         original_size: Tuple[int, int],
     ):
+        """
+        Undo Partition
+        To be overridden by functions reversing the block or grid partitions
+
+        Parameters
+        ----------
+        partitioned_input : torch.Tensor
+            Partitioned input
+        original_size : Tuple[int, int]
+            Original Input size
+
+        Raises
+        ------
+        NotImplementedError
+            Should not be called without implementation in the child class
+        """
         raise NotImplementedError

-    def forward(self, input: torch.Tensor) -> torch.Tensor:
-        """Forward pass.
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Forward pass.

-        Args:
-            input (torch.Tensor): Input tensor of the shape [B, C_in, H, W].
+        Parameters
+        ----------
+        x : torch.Tensor
+            Input tensor of the shape [B, C_in, H, W].

-        Returns:
-            output (torch.Tensor): Output tensor of the shape [B, C_out, H (// 2), W (// 2)].
+        Returns
+        -------
+        torch.Tensor
+            Output tensor of the shape [B, C_out, H (// 2), W (// 2)].
         """
         # Save original shape
-        B, C, H, W = input.shape
+        _, C, H, W = x.shape
         # Perform partition
-        input_partitioned = self.partition_function(input)
+        input_partitioned = self.partition_function(x)
         input_partitioned = input_partitioned.view(
-            -1, self.partition_window_size[0] * self.partition_window_size[1], C
+            -1, self.attn_grid_window_size[0] * self.attn_grid_window_size[1], C
         )
-        # Perform normalization, attention, and dropout
-        output = input_partitioned + self.drop_path(self.attention(self.norm_1(input_partitioned)))
-        # Perform normalization, MLP, and dropout
-        output = output + self.drop_path(self.mlp(self.norm_2(output)))
+        output = input_partitioned + self.drop_path(self.attention(input_partitioned))
         # Reverse partition
         output = self.reverse_function(output, (H, W))
         return output
@@ -101,7 +137,7 @@ def __init__(
         self,
         in_channels: int,
         num_heads: int = 32,
-        partition_window_size: Tuple[int, int] = (7, 7),
+        attn_grid_window_size: Tuple[int, int] = (7, 7),
         attn_drop: float = 0,
         drop: float = 0,
         drop_path: float = 0,
@@ -110,7 +146,7 @@ def __init__(
         super().__init__(
             in_channels,
             num_heads,
-            partition_window_size,
+            attn_grid_window_size,
             attn_drop,
             drop,
             drop_path,
@@ -132,7 +168,7 @@ def partition_function(self, input: Tensor):
         windows = input.view(
             B,
             C,
-            H // self.partition_window_size[0],
+            H // self.attn_grid_window_size[0],
             self.partition_window_size[0],
             W // self.partition_window_size[1],
             self.partition_window_size[1],