diff --git a/metnet/layers/MultiheadSelfAttention2D.py b/metnet/layers/MultiheadSelfAttention2D.py
index 9c47c7c..6f3a100 100644
--- a/metnet/layers/MultiheadSelfAttention2D.py
+++ b/metnet/layers/MultiheadSelfAttention2D.py
@@ -15,8 +15,8 @@ class MultiheadSelfAttention2D(nn.Module):
     def __init__(
         self,
         in_channels: int,
-        attention_channels: int,
-        num_heads: int,
+        attention_channels: int = 64,
+        num_heads: int = 16,
         attn_drop: float = 0.0,
         proj_drop: float = 0.0,
         use_normalised_qk: bool = True,
@@ -31,9 +31,9 @@ def __init__(
             Number of channels in the input image.
         attention_channels : int
             Number of channels used for attention computations.
-            It should be divisible by num_heads.
+            It should be divisible by num_heads, by default 64
         num_heads : int
-            Number of attention heads.
+            Number of attention heads, by default 16
         attn_drop : float, optional
             attention dropout rate, by default 0.0
         proj_drop : float, optional
diff --git a/tests/test_layers.py b/tests/test_layers.py
index 4ec219c..427f65b 100644
--- a/tests/test_layers.py
+++ b/tests/test_layers.py
@@ -1,8 +1,7 @@
 from metnet.layers.StochasticDepth import StochasticDepth
 from metnet.layers.SqueezeExcitation import SqueezeExcite
 from metnet.layers.MBConv import MBConv
-from metnet.layers.PartitionAttention import BlockAttention, GridAttention
-from metnet.layers.RelativeSelfAttention import RelativeSelfAttention
+from metnet.layers.MultiheadSelfAttention2D import MultiheadSelfAttention2D
 import torch
 
 
@@ -32,10 +31,10 @@ def test_mbconv():
     assert test_tensor.shape == mb_conv(test_tensor).shape
 
 
-def test_relative_attention():
+def test_multiheaded_self_attention_2D():
     n, c, h, w = 1, 3, 16, 16
     test_tensor = torch.rand(n, c, h, w)
-    rel_self_attention = RelativeSelfAttention(c)
+    rel_self_attention = MultiheadSelfAttention2D(c)
     assert test_tensor.shape == rel_self_attention(test_tensor).shape
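
For context, here is a minimal usage sketch of the new defaults. This is not part of the diff; it assumes metnet is importable and that the layer preserves the (N, C, H, W) input shape, as the updated test asserts.

# Minimal usage sketch (illustration only, not part of this change).
import torch

from metnet.layers.MultiheadSelfAttention2D import MultiheadSelfAttention2D

x = torch.rand(1, 3, 16, 16)  # (N, C, H, W)

# Only in_channels is now required: attention_channels defaults to 64 and
# num_heads to 16, and 64 is divisible by 16 as the docstring requires.
attn = MultiheadSelfAttention2D(in_channels=3)

assert attn(x).shape == x.shape  # output shape matches input shape

Keeping attention_channels divisible by num_heads matters because each head presumably operates on attention_channels / num_heads channels; the chosen defaults (64 / 16 = 4 channels per head) satisfy that constraint out of the box.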