From 3c8471d5eaab81cac12e287bc5259a503c125625 Mon Sep 17 00:00:00 2001
From: root
Date: Mon, 29 Jul 2024 22:29:56 +0000
Subject: [PATCH] fix comments

---
 tests/models/layers/test_ffn.py | 16 +++++++++++++---
 1 file changed, 13 insertions(+), 3 deletions(-)

diff --git a/tests/models/layers/test_ffn.py b/tests/models/layers/test_ffn.py
index 4bd28de169..d6098bc80c 100644
--- a/tests/models/layers/test_ffn.py
+++ b/tests/models/layers/test_ffn.py
@@ -7,6 +7,7 @@ import torch.distributed as dist
 
 from llmfoundry.models.layers.layer_builders import build_ffn
+from llmfoundry.models.layers.ffn import quickgelu_activation
 
 
 @pytest.mark.gpu
@@ -31,6 +32,9 @@ def test_quickgelu_activation():
         bias=not no_bias,
         ffn_kwargs=ffn_config,
     )
+    assert (
+        ffn1.act == quickgelu_activation
+    ), f"Expected quick_gelu activation function, got {ffn1.act}"
 
     ffn_config={
         'ffn_act_fn': {
@@ -53,10 +57,16 @@ def num_params(model: nn.Module) -> int:
 
     ffn1_numparams = num_params(ffn1)
     ffn2_numparams = num_params(ffn2)
-    assert ffn1_numparams == ffn2_numparams, "Only activation paths should have changed, re-check modeling!"
+    assert (
+        ffn1_numparams == ffn2_numparams
+    ), "Only activation paths should have changed, re-check modeling!"
 
     input_ = torch.rand(1, d_model, device=device)
     output1 = ffn1(input_)
     output2 = ffn2(input_)
-    assert output1.numel() == output2.numel(), "Only activation paths should have changed, re-check modeling!"
-    assert not torch.allclose(output1, output2)
+    assert (
+        output1.numel() == output2.numel()
+    ), "Only activation paths should have changed, re-check modeling!"
+    assert (
+        not torch.allclose(output1, output2)
+    ), "Functions are different, outputs should not match!"