Skip to content

Commit

Permalink
Adds tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Raahul-Singh committed Nov 2, 2023
1 parent 182c9f3 commit f49113d
Showing 1 changed file with 12 additions and 10 deletions.
22 changes: 12 additions & 10 deletions tests/test_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from metnet.layers.SqueezeExcitation import SqueezeExcite
from metnet.layers.MBConv import MBConv
from metnet.layers.MultiheadSelfAttention2D import MultiheadSelfAttention2D
from metnet.layers.PartitionAttention import BlockAttention, GridAttention
import torch


Expand Down Expand Up @@ -38,16 +39,17 @@ def test_multiheaded_self_attention_2D():
assert test_tensor.shape == rel_self_attention(test_tensor).shape


# def test_block_attention():
# n, c, h, w = 1, 3, 16, 16
# test_tensor = torch.rand(n, c, h, w)
# block_attention = BlockAttention(c)
def test_block_attention():
n, c, h, w = 1, 3, 16, 16
test_tensor = torch.rand(n, c, h, w)
block_attention = BlockAttention(c)

assert test_tensor.shape == block_attention(test_tensor).shape

# assert test_tensor.shape == block_attention(test_tensor).shape

# def test_grid_attention():
# n, c, h, w = 1, 3, 16, 16
# test_tensor = torch.rand(n, c, h, w)
# grid_attention = GridAttention(c)
def test_grid_attention():
n, c, h, w = 1, 3, 16, 16
test_tensor = torch.rand(n, c, h, w)
grid_attention = GridAttention(c)

# assert test_tensor.shape == grid_attention(test_tensor).shape
assert test_tensor.shape == grid_attention(test_tensor).shape

0 comments on commit f49113d

Please sign in to comment.