Commit

fix comments
gupta-abhay committed Jul 29, 2024
1 parent d142aa7 commit 3c8471d
Showing 1 changed file with 13 additions and 3 deletions.
16 changes: 13 additions & 3 deletions tests/models/layers/test_ffn.py
@@ -7,6 +7,7 @@
 import torch.distributed as dist
 
 from llmfoundry.models.layers.layer_builders import build_ffn
+from llmfoundry.models.layers.ffn import quickgelu_activation
 
 
 @pytest.mark.gpu
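
For context, the newly imported quickgelu_activation is presumably the sigmoid-based approximation of GELU popularized by CLIP; a minimal sketch of what the helper likely computes (an assumption, not code copied from llm-foundry):

import torch

def quickgelu_activation(input: torch.Tensor) -> torch.Tensor:
    # QuickGELU: x * sigmoid(1.702 * x), a cheap approximation of exact GELU.
    return input * torch.sigmoid(1.702 * input)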
@@ -31,6 +32,9 @@ def test_quickgelu_activation():
         bias=not no_bias,
         ffn_kwargs=ffn_config,
     )
+    assert (
+        ffn1.act == quickgelu_activation
+    ), f"Expected quick_gelu activation function, got {ffn1.act}"
 
     ffn_config = {
         'ffn_act_fn': {
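
The added assert verifies config plumbing by identity: building the FFN with 'ffn_act_fn': {'name': 'quick_gelu'} should leave the imported function itself on the module's act attribute. A hypothetical miniature of that pattern (TinyMLP and its fields are invented for illustration, not llm-foundry code):

import torch
import torch.nn as nn

def quickgelu_activation(x: torch.Tensor) -> torch.Tensor:
    return x * torch.sigmoid(1.702 * x)

class TinyMLP(nn.Module):
    # Stand-in for the built FFN: it keeps the activation callable as an
    # attribute, so a test can compare by identity rather than by output.
    def __init__(self, d_model: int, act):
        super().__init__()
        self.up = nn.Linear(d_model, d_model, bias=False)
        self.act = act
        self.down = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down(self.act(self.up(x)))

mlp = TinyMLP(32, quickgelu_activation)
assert mlp.act == quickgelu_activation  # the same identity check as the test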
@@ -53,10 +57,16 @@ def num_params(model: nn.Module) -> int:
 
     ffn1_numparams = num_params(ffn1)
     ffn2_numparams = num_params(ffn2)
-    assert ffn1_numparams == ffn2_numparams, "Only activation paths should have changed, re-check modeling!"
+    assert (
+        ffn1_numparams == ffn2_numparams
+    ), "Only activation paths should have changed, re-check modeling!"
 
     input_ = torch.rand(1, d_model, device=device)
     output1 = ffn1(input_)
     output2 = ffn2(input_)
-    assert output1.numel() == output2.numel(), "Only activation paths should have changed, re-check modeling!"
-    assert not torch.allclose(output1, output2)
+    assert (
+        output1.numel() == output2.numel()
+    ), "Only activation paths should have changed, re-check modeling!"
+    assert (
+        not torch.allclose(output1, output2)
+    ), "Functions are different, outputs should not match!"
