Commit
fixing tests
Raahul-Singh committed Nov 2, 2023
1 parent 73a1230 commit e4a458c
Showing 2 changed files with 7 additions and 8 deletions.
8 changes: 4 additions & 4 deletions metnet/layers/MultiheadSelfAttention2D.py
@@ -15,8 +15,8 @@ class MultiheadSelfAttention2D(nn.Module):
     def __init__(
         self,
         in_channels: int,
-        attention_channels: int,
-        num_heads: int,
+        attention_channels: int = 64,
+        num_heads: int = 16,
         attn_drop: float = 0.0,
         proj_drop: float = 0.0,
         use_normalised_qk: bool = True,
@@ -31,9 +31,9 @@ def __init__(
             Number of channels in the input image.
         attention_channels : int
             Number of channels used for attention computations.
-            It should be divisible by num_heads.
+            It should be divisible by num_heads, by default 64
         num_heads : int
-            Number of attention heads.
+            Number of attention heads, by default 16
         attn_drop : float, optional
             attention dropout rate, by default 0.0
         proj_drop : float, optional
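With the new defaults, attention_channels=64 and num_heads=16 satisfy the divisibility requirement the docstring states. A minimal sketch of the implied arithmetic (head_dim is an illustrative name, not taken from the source):

attention_channels, num_heads = 64, 16      # the new defaults
assert attention_channels % num_heads == 0  # docstring: must be divisible
head_dim = attention_channels // num_heads  # 64 // 16 = 4 channels per head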
7 changes: 3 additions & 4 deletions tests/test_layers.py
@@ -1,8 +1,7 @@
 from metnet.layers.StochasticDepth import StochasticDepth
 from metnet.layers.SqueezeExcitation import SqueezeExcite
 from metnet.layers.MBConv import MBConv
 from metnet.layers.PartitionAttention import BlockAttention, GridAttention
-from metnet.layers.RelativeSelfAttention import RelativeSelfAttention
+from metnet.layers.MultiheadSelfAttention2D import MultiheadSelfAttention2D
 import torch


@@ -32,10 +31,10 @@ def test_mbconv():
     assert test_tensor.shape == mb_conv(test_tensor).shape


-def test_relative_attention():
+def test_multiheaded_self_attention_2D():
     n, c, h, w = 1, 3, 16, 16
     test_tensor = torch.rand(n, c, h, w)
-    rel_self_attention = RelativeSelfAttention(c)
+    rel_self_attention = MultiheadSelfAttention2D(c)
     assert test_tensor.shape == rel_self_attention(test_tensor).shape


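A minimal usage sketch mirroring the updated test, assuming the metnet package at this commit is importable; with the new defaults only in_channels needs to be passed:

import torch
from metnet.layers.MultiheadSelfAttention2D import MultiheadSelfAttention2D

x = torch.rand(1, 3, 16, 16)        # (batch, channels, height, width)
attn = MultiheadSelfAttention2D(3)  # defaults: attention_channels=64, num_heads=16
assert attn(x).shape == x.shape     # the test asserts the layer preserves input shape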
