diff --git a/tests/test_layers.py b/tests/test_layers.py index 427f65b..c5eb82c 100644 --- a/tests/test_layers.py +++ b/tests/test_layers.py @@ -2,6 +2,7 @@ from metnet.layers.SqueezeExcitation import SqueezeExcite from metnet.layers.MBConv import MBConv from metnet.layers.MultiheadSelfAttention2D import MultiheadSelfAttention2D +from metnet.layers.PartitionAttention import BlockAttention, GridAttention import torch @@ -38,16 +39,17 @@ def test_multiheaded_self_attention_2D(): assert test_tensor.shape == rel_self_attention(test_tensor).shape -# def test_block_attention(): -# n, c, h, w = 1, 3, 16, 16 -# test_tensor = torch.rand(n, c, h, w) -# block_attention = BlockAttention(c) +def test_block_attention(): + n, c, h, w = 1, 3, 16, 16 + test_tensor = torch.rand(n, c, h, w) + block_attention = BlockAttention(c) + + assert test_tensor.shape == block_attention(test_tensor).shape -# assert test_tensor.shape == block_attention(test_tensor).shape -# def test_grid_attention(): -# n, c, h, w = 1, 3, 16, 16 -# test_tensor = torch.rand(n, c, h, w) -# grid_attention = GridAttention(c) +def test_grid_attention(): + n, c, h, w = 1, 3, 16, 16 + test_tensor = torch.rand(n, c, h, w) + grid_attention = GridAttention(c) -# assert test_tensor.shape == grid_attention(test_tensor).shape + assert test_tensor.shape == grid_attention(test_tensor).shape