
Commit: temp
Raahul-Singh committed Oct 24, 2023
1 parent 11861f5 commit 63d9596
Showing 1 changed file with 87 additions and 51 deletions.
metnet/layers/PartitionAttention.py: 138 changes (87 additions, 51 deletions)
@@ -8,89 +8,125 @@


class PartitionAttention(nn.Module):
"""PartitionAttention block.
With block partition:
x ← x + Unblock(RelAttention(Block(LN(x))))
x ← x + MLP(LN(x))
With grid partition:
x ← x + Ungrid(RelAttention(Grid(LN(x))))
x ← x + MLP(LN(x))
Layer Normalization (LN) is applied after the grid/window partition to prevent multiple reshaping operations.
Grid/window reverse (Unblock/Ungrid) is performed on the final output for the same reason.
Args:
in_channels (int): Number of input channels.
partition_function (Callable): Partition function to be utilized (grid or window partition).
reverse_function (Callable): Reverse function to be utilized (grid or window reverse).
num_heads (int, optional): Number of attention heads. Default 32
partition_window_size (Tuple[int, int], optional): Grid/Window size to be utilized. Default (7, 7)
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
drop (float, optional): Dropout ratio of output. Default: 0.0
drop_path (float, optional): Dropout ratio of path. Default: 0.0
mlp_ratio (float, optional): Ratio of mlp hidden dim to embedding dim. Default: 4.0
act_layer (Type[nn.Module], optional): Type of activation layer to be utilized. Default: nn.GELU
norm_layer (Type[nn.Module], optional): Type of normalization layer to be utilized. Default: nn.BatchNorm2d
"""

def __init__(
self,
in_channels: int,
num_heads: int = 32,
partition_window_size: Tuple[int, int] = (7, 7),
attn_grid_window_size: Tuple[int, int] = (8, 8),
attn_drop: float = 0.0,
drop: float = 0.0,
proj_drop: float = 0.0,
drop_path: float = 0.0,
norm_layer: Type[nn.Module] = nn.LayerNorm,
use_normalised_qk: bool = True
) -> None:
"""Constructor method"""
"""
PartitionAttention block
Implements the common functionality for block and grid attention.
With block partition:
x ← x + Unblock(RelAttention(Block(x)))
With grid partition:
x ← x + Ungrid(RelAttention(Grid(x)))
Parameters
----------
in_channels : int
Number of input channels.
num_heads : int, optional
Number of attention heads, by default 32
attn_grid_window_size : Tuple[int, int], optional
Grid/Window size to be utilized, by default (8, 8)
attn_drop : float, optional
Dropout ratio of attention weight, by default 0.0
proj_drop : float, optional
Dropout ratio of output, by default 0.0
drop_path : float, optional
Stochastic depth, by default 0.0
use_normalised_qk : bool, optional
Normalise queries and keys as done in Metnet 3, by default True.
Notes
-----
Specific to Metnet 3 implementation
TODO: Add the MLP as an optional parameter.
"""
super().__init__()
# Save parameters
self.partition_window_size: Tuple[int, int] = partition_window_size
self.attn_grid_window_size: Tuple[int, int] = attn_grid_window_size
# Init layers
self.norm_1 = norm_layer(in_channels)
self.attention = RelativeSelfAttention(
in_channels=in_channels,
num_heads=num_heads,
partition_window_size=partition_window_size,
attn_grid_window_size=attn_grid_window_size,
attn_drop=attn_drop,
drop=drop,
proj_drop=proj_drop,
use_normalised_qk=use_normalised_qk
)
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm_2 = norm_layer(in_channels)

def partition_function(self, input: torch.Tensor):
def partition_function(self, X: torch.Tensor):
"""
Partition function.
To be overridden by block or grid partition
Parameters
----------
X : torch.Tensor
Input tensor
Raises
------
NotImplementedError
Should not be called without implementation in the child class
"""
raise NotImplementedError

def reverse_function(
self,
partitioned_input: torch.Tensor,
original_size: Tuple[int, int],
):
"""
Undo Partition
To be overridden by functions reversing the block or grid partitions
Parameters
----------
partitioned_input : torch.Tensor
Partitioned input
original_size : Tuple[int, int]
Original Input size
Raises
------
NotImplementedError
Should not be called without implementation in the child class
"""
raise NotImplementedError

def forward(self, input: torch.Tensor) -> torch.Tensor:
"""Forward pass.
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass.
Args:
input (torch.Tensor): Input tensor of the shape [B, C_in, H, W].
Parameters
----------
x : torch.Tensor
Input tensor of the shape [B, C_in, H, W].
Returns:
output (torch.Tensor): Output tensor of the shape [B, C_out, H (// 2), W (// 2)].
Returns
-------
torch.Tensor
Output tensor of the shape [B, C_out, H (// 2), W (// 2)].
"""
# Save original shape
B, C, H, W = input.shape
_, C, H, W = x.shape
# Perform partition
input_partitioned = self.partition_function(input)
input_partitioned = self.partition_function(x)
input_partitioned = input_partitioned.view(
-1, self.partition_window_size[0] * self.partition_window_size[1], C
-1, self.attn_grid_window_size[0] * self.attn_grid_window_size[1], C
)
# Perform normalization, attention, and dropout
output = input_partitioned + self.drop_path(self.attention(self.norm_1(input_partitioned)))
# Perform normalization, MLP, and dropout
output = output + self.drop_path(self.mlp(self.norm_2(output)))
# Reverse partition
output = self.reverse_function(output, (H, W))
return output
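Taken together, the new forward pass partitions the feature map into windows, flattens each window into a token sequence, applies pre-norm relative self-attention with a residual connection and stochastic depth, and reverses the partition at the end. The stand-alone sketch below illustrates that flow in plain PyTorch; nn.MultiheadAttention stands in for the repository's RelativeSelfAttention, DropPath and the MLP branch are omitted, and every name and default here is illustrative rather than the metnet API.

import torch
import torch.nn as nn


class WindowAttentionSketch(nn.Module):
    """Minimal illustration of: partition -> flatten -> attend -> reverse."""

    def __init__(self, channels: int, num_heads: int = 4, window: int = 8):
        super().__init__()
        self.window = window
        self.norm = nn.LayerNorm(channels)
        # Stand-in for RelativeSelfAttention (no relative position bias here).
        self.attn = nn.MultiheadAttention(channels, num_heads, batch_first=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, C, H, W = x.shape
        w = self.window
        # Block partition: [B, C, H, W] -> [B * num_windows, w * w, C]
        tokens = (
            x.view(B, C, H // w, w, W // w, w)
            .permute(0, 2, 4, 3, 5, 1)
            .reshape(-1, w * w, C)
        )
        # Pre-norm attention with a residual connection.
        normed = self.norm(tokens)
        attn_out, _ = self.attn(normed, normed, normed, need_weights=False)
        tokens = tokens + attn_out
        # Undo the partition: [B * num_windows, w * w, C] -> [B, C, H, W]
        return (
            tokens.view(B, H // w, W // w, w, w, C)
            .permute(0, 5, 1, 3, 2, 4)
            .reshape(B, C, H, W)
        )


x = torch.randn(2, 32, 16, 16)  # H and W must be multiples of the window size
out = WindowAttentionSketch(channels=32, window=8)(x)
print(out.shape)  # torch.Size([2, 32, 16, 16])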
@@ -101,7 +137,7 @@ def __init__(
self,
in_channels: int,
num_heads: int = 32,
partition_window_size: Tuple[int, int] = (7, 7),
attn_grid_window_size: Tuple[int, int] = (7, 7),
attn_drop: float = 0,
drop: float = 0,
drop_path: float = 0,
@@ -110,7 +146,7 @@ def __init__(
super().__init__(
in_channels,
num_heads,
partition_window_size,
attn_grid_window_size,
attn_drop,
drop,
drop_path,
@@ -132,7 +168,7 @@ def partition_function(self, input: Tensor):
windows = input.view(
B,
C,
H // self.partition_window_size[0],
H // self.attn_grid_window_size[0],
self.partition_window_size[0],
W // self.partition_window_size[1],
self.partition_window_size[1],
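The last hunk is cut off in the middle of the block partition's first reshape. For orientation, a complete block (window) partition and its inverse usually follow the standard MaxViT pattern sketched below; the function names, the (8, 8) default, and the channels-last output layout are assumptions consistent with the reshapes visible in this diff, not necessarily the exact code in this commit.

from typing import Tuple

import torch


def block_partition(x: torch.Tensor, window_size: Tuple[int, int] = (8, 8)) -> torch.Tensor:
    """[B, C, H, W] -> [B * num_windows, window_h, window_w, C] (channels last)."""
    B, C, H, W = x.shape
    wh, ww = window_size
    windows = x.view(B, C, H // wh, wh, W // ww, ww)
    # Group the two window axes together and move channels last, so the
    # attention block can flatten each window to [window_h * window_w, C].
    return windows.permute(0, 2, 4, 3, 5, 1).contiguous().view(-1, wh, ww, C)


def block_reverse(
    windows: torch.Tensor,
    original_size: Tuple[int, int],
    window_size: Tuple[int, int] = (8, 8),
) -> torch.Tensor:
    """Inverse of block_partition: back to [B, C, H, W]."""
    H, W = original_size
    wh, ww = window_size
    C = windows.shape[-1]
    B = windows.shape[0] // ((H // wh) * (W // ww))
    x = windows.view(B, H // wh, W // ww, wh, ww, C)
    return x.permute(0, 5, 1, 3, 2, 4).contiguous().view(B, C, H, W)


# Round trip: reversing the partition should restore the input exactly.
x = torch.randn(2, 16, 32, 32)
assert torch.equal(block_reverse(block_partition(x), (32, 32)), x)

The grid partition used for the Grid/Ungrid path is the same idea with the block and grid axes swapped in the initial view (B, C, grid_h, H // grid_h, grid_w, W // grid_w), so each group of tokens samples the whole image at a fixed stride and the attention gets a dilated, image-wide receptive field.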
