add partitioned attention blocks
Raahul-Singh committed Oct 24, 2023
1 parent 1db1bab commit a43181b
Showing 1 changed file with 253 additions and 0 deletions: metnet/layers/PartitionAttention.py
@@ -0,0 +1,253 @@
from typing import Tuple, Type

import torch
from timm.layers import DropPath, Mlp
from torch import Tensor, nn

from metnet.layers.RelativeSelfAttention import RelativeSelfAttention


class PartitionAttention(nn.Module):
"""PartitionAttention block.
With block partition:
x ← x + Unblock(RelAttention(Block(LN(x))))
x ← x + MLP(LN(x))
With grid partition:
x ← x + Ungrid(RelAttention(Grid(LN(x))))
x ← x + MLP(LN(x))
Layer Normalization (LN) is applied after the grid/window partition to prevent multiple reshaping operations.
Grid/window reverse (Unblock/Ungrid) is performed on the final output for the same reason.
Args:
in_channels (int): Number of input channels.
partition_function (Callable): Partition function to be utilized (grid or window partition).
reverse_function (Callable): Reverse function to be utilized (grid or window reverse).
num_heads (int, optional): Number of attention heads. Default 32
partition_window_size (Tuple[int, int], optional): Grid/Window size to be utilized. Default (7, 7)
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
drop (float, optional): Dropout ratio of output. Default: 0.0
drop_path (float, optional): Dropout ratio of path. Default: 0.0
mlp_ratio (float, optional): Ratio of mlp hidden dim to embedding dim. Default: 4.0
act_layer (Type[nn.Module], optional): Type of activation layer to be utilized. Default: nn.GELU
norm_layer (Type[nn.Module], optional): Type of normalization layer to be utilized. Default: nn.BatchNorm2d
"""

    def __init__(
        self,
        in_channels: int,
        num_heads: int = 32,
        partition_window_size: Tuple[int, int] = (7, 7),
        attn_drop: float = 0.0,
        drop: float = 0.0,
        drop_path: float = 0.0,
        norm_layer: Type[nn.Module] = nn.LayerNorm,
        mlp_ratio: float = 4.0,
        act_layer: Type[nn.Module] = nn.GELU,
    ) -> None:
        """Constructor method"""
        super().__init__()
        # Save parameters
        self.partition_window_size: Tuple[int, int] = partition_window_size
        # Init layers
        self.norm_1 = norm_layer(in_channels)
        self.attention = RelativeSelfAttention(
            in_channels=in_channels,
            num_heads=num_heads,
            partition_window_size=partition_window_size,
            attn_drop=attn_drop,
            drop=drop,
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm_2 = norm_layer(in_channels)
        # MLP used in the second residual branch of forward(); timm's Mlp is assumed here
        # so that `self.mlp` is defined, following the documented mlp_ratio/act_layer.
        self.mlp = Mlp(
            in_features=in_channels,
            hidden_features=int(mlp_ratio * in_channels),
            act_layer=act_layer,
            drop=drop,
        )

    def partition_function(self, input: torch.Tensor) -> torch.Tensor:
        """Partition the input into windows/grids. Implemented by subclasses."""
        raise NotImplementedError

    def reverse_function(
        self,
        partitioned_input: torch.Tensor,
        original_size: Tuple[int, int],
    ) -> torch.Tensor:
        """Reverse the partition back to the original layout. Implemented by subclasses."""
        raise NotImplementedError

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        Args:
            input (torch.Tensor): Input tensor of the shape [B, C_in, H, W].

        Returns:
            output (torch.Tensor): Output tensor of the same shape [B, C_in, H, W].
        """
# Save original shape
B, C, H, W = input.shape
# Perform partition
input_partitioned = self.partition_function(input)
input_partitioned = input_partitioned.view(
-1, self.partition_window_size[0] * self.partition_window_size[1], C
)
# Perform normalization, attention, and dropout
output = input_partitioned + self.drop_path(self.attention(self.norm_1(input_partitioned)))
# Perform normalization, MLP, and dropout
output = output + self.drop_path(self.mlp(self.norm_2(output)))
# Reverse partition
output = self.reverse_function(output, (H, W))
return output


class BlockAttention(PartitionAttention):
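    """Block (window) partition attention.

    Partitions the input into non-overlapping local windows of size
    `partition_window_size` and applies relative self-attention within each window.
    """
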
def __init__(
self,
in_channels: int,
num_heads: int = 32,
partition_window_size: Tuple[int, int] = (7, 7),
attn_drop: float = 0,
drop: float = 0,
drop_path: float = 0,
norm_layer: Type[nn.Module] = nn.LayerNorm,
) -> None:
super().__init__(
in_channels,
num_heads,
partition_window_size,
attn_drop,
drop,
drop_path,
norm_layer,
)

def partition_function(self, input: Tensor):
"""Window partition function.
Args:
input (torch.Tensor): Input tensor of the shape [B, C, H, W].
Returns:
windows (torch.Tensor): Unfolded input tensor of the shape [B * windows, partition_size[0], partition_size[1], C].
"""
# Get size of input
B, C, H, W = input.shape
# Unfold input
windows = input.view(
B,
C,
H // self.partition_window_size[0],
self.partition_window_size[0],
W // self.partition_window_size[1],
self.partition_window_size[1],
)
# Permute and reshape to [B * windows, self.partition_window_size[0], self.partition_window_size[1], channels]
windows = (
windows.permute(0, 2, 4, 3, 5, 1)
.contiguous()
.view(-1, self.partition_window_size[0], self.partition_window_size[1], C)
)
return windows

def reverse_function(
self,
partitioned_input: torch.Tensor,
original_size: Tuple[int, int],
):
"""Reverses the window partition.
Args:
partitioned_input (torch.Tensor): Window tensor of the shape [B * partitioned_input, partition_size[0], partition_size[1], C].
original_size (Tuple[int, int]): Original shape.
Returns:
output (torch.Tensor): Folded output tensor of the shape [B, C, original_size[0], original_size[1]].
"""
# Get height and width
H, W = original_size
# Compute original batch size
B = int(
partitioned_input.shape[0]
/ (H * W / self.partition_window_size[0] / self.partition_window_size[1])
)
# Fold grid tensor
output = partitioned_input.view(
B,
H // self.partition_window_size[0],
W // self.partition_window_size[1],
self.partition_window_size[0],
self.partition_window_size[1],
-1,
)
output = output.permute(0, 5, 1, 3, 2, 4).contiguous().view(B, -1, H, W)
return output


class GridAttention(PartitionAttention):
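    """Grid partition attention.

    Partitions the input into a regular grid of size `partition_window_size`, gathering
    spatially strided locations into each group, and applies relative self-attention
    within each group.
    """
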
def __init__(
self,
in_channels: int,
num_heads: int = 32,
partition_window_size: Tuple[int, int] = (7, 7),
attn_drop: float = 0,
drop: float = 0,
drop_path: float = 0,
norm_layer: Type[nn.Module] = nn.LayerNorm,
) -> None:
super().__init__(
in_channels,
num_heads,
partition_window_size,
attn_drop,
drop,
drop_path,
norm_layer,
)

def partition_function(self, input: Tensor):
"""Grid partition function.
Args:
input (torch.Tensor): Input tensor of the shape [B, C, H, W].
Returns:
grid (torch.Tensor): Unfolded input tensor of the shape [B * grids, grid_size[0], grid_size[1], C].
"""
# Get size of input
B, C, H, W = input.shape
# Unfold input
grid = input.view(
B,
C,
self.partition_window_size[0],
H // self.partition_window_size[0],
self.partition_window_size[1],
W // self.partition_window_size[1],
)
        # Permute and reshape to [B * (H // partition_window_size[0]) * (W // partition_window_size[1]), partition_window_size[0], partition_window_size[1], C]
grid = (
grid.permute(0, 3, 5, 2, 4, 1)
.contiguous()
.view(-1, self.partition_window_size[0], self.partition_window_size[1], C)
)
return grid

def reverse_function(
self,
partitioned_input: torch.Tensor,
original_size: Tuple[int, int],
):
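        """Reverses the grid partition.

        Args:
            partitioned_input (torch.Tensor): Grid tensor of the shape [B * grids, partition_size[0], partition_size[1], C].
            original_size (Tuple[int, int]): Original shape.

        Returns:
            output (torch.Tensor): Folded output tensor of the shape [B, C, original_size[0], original_size[1]].
        """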
# Get height, width, and channels
(H, W), C = original_size, partitioned_input.shape[-1]
# Compute original batch size
B = int(
partitioned_input.shape[0]
/ (H * W / self.partition_window_size[0] / self.partition_window_size[1])
)
# Fold partitioned_input tensor
output = partitioned_input.view(
B,
H // self.partition_window_size[0],
W // self.partition_window_size[1],
self.partition_window_size[0],
self.partition_window_size[1],
C,
)
output = output.permute(0, 5, 3, 1, 4, 2).contiguous().view(B, C, H, W)
return output
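
For reference, a minimal usage sketch of the two blocks (not part of the commit). It assumes RelativeSelfAttention, which is not shown in this diff, operates on partitioned inputs of shape [num_partitions, window_area, C] as used in forward() above and accepts in_channels divisible by num_heads; H and W must be divisible by the partition window size.

# Minimal usage sketch (illustration only; not part of this commit).
import torch

from metnet.layers.PartitionAttention import BlockAttention, GridAttention

x = torch.randn(2, 64, 28, 28)  # [B, C, H, W], with 28 = 4 * 7

block_attn = BlockAttention(in_channels=64, num_heads=8, partition_window_size=(7, 7))
grid_attn = GridAttention(in_channels=64, num_heads=8, partition_window_size=(7, 7))

y = block_attn(x)  # local attention within 7x7 windows
z = grid_attn(y)   # sparse attention across strided 7x7 grid groups
assert y.shape == x.shape and z.shape == x.shape  # spatial shape is preserved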
