diff --git a/tests/test_readme.py b/tests/test_readme.py index 258a0477..06787ee1 100644 --- a/tests/test_readme.py +++ b/tests/test_readme.py @@ -1,8 +1,12 @@ +import os +os.environ['TYPECHECK'] = 'True' + import torch import pytest from alphafold3_pytorch import ( - PairformerStack + PairformerStack, + MSAModule ) def test_pairformer(): @@ -24,3 +28,25 @@ def test_pairformer(): assert single.shape == single_out.shape assert pairwise.shape == pairwise_out.shape + +def test_msa_module(): + + single = torch.randn(2, 16, 512) + msa = torch.randn(2, 7, 16, 64) + pairwise = torch.randn(2, 16, 16, 128) + mask = torch.randint(0, 2, (2, 16)).bool() + + msa_module = MSAModule( + dim_single = 512, + dim_pairwise = 128, + dim_msa = 64 + ) + + pairwise_out = msa_module( + single_repr = single, + msa = msa, + pairwise_repr = pairwise, + mask = mask + ) + + assert pairwise.shape == pairwise_out.shape