Commit a43181b (parent 1db1bab): showing 1 changed file with 253 additions and 0 deletions.
@@ -0,0 +1,253 @@
from typing import Tuple, Type

import torch
from timm.layers import DropPath, Mlp
from torch import Tensor, nn

from metnet.layers.RelativeSelfAttention import RelativeSelfAttention


class PartitionAttention(nn.Module):
    """PartitionAttention block.

    With block partition:
    x ← x + Unblock(RelAttention(Block(LN(x))))
    x ← x + MLP(LN(x))

    With grid partition:
    x ← x + Ungrid(RelAttention(Grid(LN(x))))
    x ← x + MLP(LN(x))

    Layer Normalization (LN) is applied after the grid/window partition to prevent multiple
    reshaping operations, and the grid/window reverse (Unblock/Ungrid) is performed on the final
    output for the same reason. The partition and reverse operations themselves are supplied by
    the subclasses (BlockAttention and GridAttention) via ``partition_function`` and
    ``reverse_function``.

    Args:
        in_channels (int): Number of input channels.
        num_heads (int, optional): Number of attention heads. Default: 32
        partition_window_size (Tuple[int, int], optional): Grid/window size to be utilized. Default: (7, 7)
        attn_drop (float, optional): Dropout ratio of attention weights. Default: 0.0
        drop (float, optional): Dropout ratio of the output. Default: 0.0
        drop_path (float, optional): Drop path (stochastic depth) rate. Default: 0.0
        norm_layer (Type[nn.Module], optional): Normalization layer to be utilized. Default: nn.LayerNorm
        mlp_ratio (float, optional): Ratio of MLP hidden dim to embedding dim. Default: 4.0
        act_layer (Type[nn.Module], optional): Activation layer of the MLP. Default: nn.GELU
    """

    def __init__(
        self,
        in_channels: int,
        num_heads: int = 32,
        partition_window_size: Tuple[int, int] = (7, 7),
        attn_drop: float = 0.0,
        drop: float = 0.0,
        drop_path: float = 0.0,
        norm_layer: Type[nn.Module] = nn.LayerNorm,
        mlp_ratio: float = 4.0,
        act_layer: Type[nn.Module] = nn.GELU,
    ) -> None:
        """Constructor method"""
        super().__init__()
        # Save parameters
        self.partition_window_size: Tuple[int, int] = partition_window_size
        # Init layers
        self.norm_1 = norm_layer(in_channels)
        self.attention = RelativeSelfAttention(
            in_channels=in_channels,
            num_heads=num_heads,
            partition_window_size=partition_window_size,
            attn_drop=attn_drop,
            drop=drop,
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm_2 = norm_layer(in_channels)
        # MLP applied in forward(); built with timm's Mlp so the documented
        # mlp_ratio / act_layer arguments are actually used.
        self.mlp = Mlp(
            in_features=in_channels,
            hidden_features=int(mlp_ratio * in_channels),
            act_layer=act_layer,
            drop=drop,
        )

    def partition_function(self, input: torch.Tensor):
        """Partition the input tensor; implemented by the subclasses."""
        raise NotImplementedError

    def reverse_function(
        self,
        partitioned_input: torch.Tensor,
        original_size: Tuple[int, int],
    ):
        """Reverse the partition; implemented by the subclasses."""
        raise NotImplementedError

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        Args:
            input (torch.Tensor): Input tensor of the shape [B, C, H, W].

        Returns:
            output (torch.Tensor): Output tensor of the shape [B, C, H, W].
        """
        # Save original shape
        B, C, H, W = input.shape
        # Perform partition to [B * partitions, partition_window_size[0], partition_window_size[1], C]
        input_partitioned = self.partition_function(input)
        # Flatten each partition into a token sequence of length partition_window_size[0] * partition_window_size[1]
        input_partitioned = input_partitioned.view(
            -1, self.partition_window_size[0] * self.partition_window_size[1], C
        )
        # Perform normalization, attention, and dropout
        output = input_partitioned + self.drop_path(self.attention(self.norm_1(input_partitioned)))
        # Perform normalization, MLP, and dropout
        output = output + self.drop_path(self.mlp(self.norm_2(output)))
        # Reverse partition back to [B, C, H, W]
        output = self.reverse_function(output, (H, W))
        return output


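# Illustrative shape walk-through (a reading aid, assuming the default 7x7 partition window):
# for an input of shape [2, 64, 14, 14], ``partition_function`` yields
# 2 * (14 // 7) * (14 // 7) = 8 windows/grids, the view in ``forward`` flattens them to
# [8, 49, 64] token sequences for RelativeSelfAttention, and ``reverse_function`` folds the
# result back to [2, 64, 14, 14]. H and W must therefore be divisible by the partition window size.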
class BlockAttention(PartitionAttention):
    """Block (window) partition based self-attention: attention is computed within
    non-overlapping ``partition_window_size`` windows of the feature map."""

    def __init__(
        self,
        in_channels: int,
        num_heads: int = 32,
        partition_window_size: Tuple[int, int] = (7, 7),
        attn_drop: float = 0.0,
        drop: float = 0.0,
        drop_path: float = 0.0,
        norm_layer: Type[nn.Module] = nn.LayerNorm,
    ) -> None:
        super().__init__(
            in_channels,
            num_heads,
            partition_window_size,
            attn_drop,
            drop,
            drop_path,
            norm_layer,
        )

    def partition_function(self, input: Tensor):
        """Window partition function.

        Args:
            input (torch.Tensor): Input tensor of the shape [B, C, H, W].

        Returns:
            windows (torch.Tensor): Unfolded input tensor of the shape
                [B * windows, partition_window_size[0], partition_window_size[1], C].
        """
        # Get size of input
        B, C, H, W = input.shape
        # Unfold input into non-overlapping windows
        windows = input.view(
            B,
            C,
            H // self.partition_window_size[0],
            self.partition_window_size[0],
            W // self.partition_window_size[1],
            self.partition_window_size[1],
        )
        # Permute and reshape to [B * windows, partition_window_size[0], partition_window_size[1], C]
        windows = (
            windows.permute(0, 2, 4, 3, 5, 1)
            .contiguous()
            .view(-1, self.partition_window_size[0], self.partition_window_size[1], C)
        )
        return windows

    def reverse_function(
        self,
        partitioned_input: torch.Tensor,
        original_size: Tuple[int, int],
    ):
        """Reverses the window partition.

        Args:
            partitioned_input (torch.Tensor): Window tensor of the shape
                [B * windows, partition_window_size[0], partition_window_size[1], C].
            original_size (Tuple[int, int]): Original spatial size (H, W).

        Returns:
            output (torch.Tensor): Folded output tensor of the shape [B, C, original_size[0], original_size[1]].
        """
        # Get height and width
        H, W = original_size
        # Compute original batch size
        B = int(
            partitioned_input.shape[0]
            / (H * W / self.partition_window_size[0] / self.partition_window_size[1])
        )
        # Fold window tensor
        output = partitioned_input.view(
            B,
            H // self.partition_window_size[0],
            W // self.partition_window_size[1],
            self.partition_window_size[0],
            self.partition_window_size[1],
            -1,
        )
        output = output.permute(0, 5, 1, 3, 2, 4).contiguous().view(B, -1, H, W)
        return output


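# Illustrative round trip for BlockAttention (assuming the default 7x7 window): an input of
# shape [1, 16, 28, 21] is partitioned into (28 // 7) * (21 // 7) = 12 windows of shape
# [7, 7, 16], and ``reverse_function`` with original_size=(28, 21) restores [1, 16, 28, 21].
# The reverse also accepts the flattened [12, 49, 16] layout produced in
# PartitionAttention.forward, since the fold only relies on .view() over a contiguous tensor
# with the same element count and ordering.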
class GridAttention(PartitionAttention):
    """Grid partition based self-attention: attention is computed within a
    ``partition_window_size`` grid whose elements are strided across the feature map,
    giving a sparse, global counterpart to BlockAttention."""

    def __init__(
        self,
        in_channels: int,
        num_heads: int = 32,
        partition_window_size: Tuple[int, int] = (7, 7),
        attn_drop: float = 0.0,
        drop: float = 0.0,
        drop_path: float = 0.0,
        norm_layer: Type[nn.Module] = nn.LayerNorm,
    ) -> None:
        super().__init__(
            in_channels,
            num_heads,
            partition_window_size,
            attn_drop,
            drop,
            drop_path,
            norm_layer,
        )

    def partition_function(self, input: Tensor):
        """Grid partition function.

        Args:
            input (torch.Tensor): Input tensor of the shape [B, C, H, W].

        Returns:
            grid (torch.Tensor): Unfolded input tensor of the shape
                [B * grids, partition_window_size[0], partition_window_size[1], C].
        """
        # Get size of input
        B, C, H, W = input.shape
        # Unfold input into grids
        grid = input.view(
            B,
            C,
            self.partition_window_size[0],
            H // self.partition_window_size[0],
            self.partition_window_size[1],
            W // self.partition_window_size[1],
        )
        # Permute and reshape to [B * grids, partition_window_size[0], partition_window_size[1], C]
        grid = (
            grid.permute(0, 3, 5, 2, 4, 1)
            .contiguous()
            .view(-1, self.partition_window_size[0], self.partition_window_size[1], C)
        )
        return grid

    def reverse_function(
        self,
        partitioned_input: torch.Tensor,
        original_size: Tuple[int, int],
    ):
        """Reverses the grid partition.

        Args:
            partitioned_input (torch.Tensor): Grid tensor of the shape
                [B * grids, partition_window_size[0], partition_window_size[1], C].
            original_size (Tuple[int, int]): Original spatial size (H, W).

        Returns:
            output (torch.Tensor): Folded output tensor of the shape [B, C, original_size[0], original_size[1]].
        """
        # Get height, width, and channels
        (H, W), C = original_size, partitioned_input.shape[-1]
        # Compute original batch size
        B = int(
            partitioned_input.shape[0]
            / (H * W / self.partition_window_size[0] / self.partition_window_size[1])
        )
        # Fold grid tensor
        output = partitioned_input.view(
            B,
            H // self.partition_window_size[0],
            W // self.partition_window_size[1],
            self.partition_window_size[0],
            self.partition_window_size[1],
            C,
        )
        output = output.permute(0, 5, 3, 1, 4, 2).contiguous().view(B, C, H, W)
        return output
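

# Minimal smoke-test sketch, assuming the repository environment is available, i.e. that
# ``metnet.layers.RelativeSelfAttention`` and ``timm`` are importable. It only checks that both
# attention variants preserve the [B, C, H, W] shape when H and W are divisible by the 7x7
# partition window; the channel count and head count below are arbitrary example values.
if __name__ == "__main__":
    x = torch.randn(2, 64, 28, 28)
    block_attn = BlockAttention(in_channels=64, num_heads=8)
    grid_attn = GridAttention(in_channels=64, num_heads=8)
    assert block_attn(x).shape == x.shape
    assert grid_attn(x).shape == x.shape
    print("BlockAttention / GridAttention preserve the input shape:", tuple(x.shape))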