From f49113d5a06d7182d9740bebf72559a542bd288e Mon Sep 17 00:00:00 2001 From: Raahul Singh Date: Thu, 2 Nov 2023 22:57:52 +0530 Subject: [PATCH] Adds tests --- tests/test_layers.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/tests/test_layers.py b/tests/test_layers.py index 427f65b..c5eb82c 100644 --- a/tests/test_layers.py +++ b/tests/test_layers.py @@ -2,6 +2,7 @@ from metnet.layers.SqueezeExcitation import SqueezeExcite from metnet.layers.MBConv import MBConv from metnet.layers.MultiheadSelfAttention2D import MultiheadSelfAttention2D +from metnet.layers.PartitionAttention import BlockAttention, GridAttention import torch @@ -38,16 +39,17 @@ def test_multiheaded_self_attention_2D(): assert test_tensor.shape == rel_self_attention(test_tensor).shape -# def test_block_attention(): -# n, c, h, w = 1, 3, 16, 16 -# test_tensor = torch.rand(n, c, h, w) -# block_attention = BlockAttention(c) +def test_block_attention(): + n, c, h, w = 1, 3, 16, 16 + test_tensor = torch.rand(n, c, h, w) + block_attention = BlockAttention(c) + + assert test_tensor.shape == block_attention(test_tensor).shape -# assert test_tensor.shape == block_attention(test_tensor).shape -# def test_grid_attention(): -# n, c, h, w = 1, 3, 16, 16 -# test_tensor = torch.rand(n, c, h, w) -# grid_attention = GridAttention(c) +def test_grid_attention(): + n, c, h, w = 1, 3, 16, 16 + test_tensor = torch.rand(n, c, h, w) + grid_attention = GridAttention(c) -# assert test_tensor.shape == grid_attention(test_tensor).shape + assert test_tensor.shape == grid_attention(test_tensor).shape