changes for pre-commit
gupta-abhay committed Jul 29, 2024
1 parent 3c8471d commit 1500164
Showing 1 changed file with 12 additions and 11 deletions.
23 changes: 12 additions & 11 deletions tests/models/layers/test_ffn.py
@@ -3,40 +3,41 @@
 
 import pytest
 import torch
-import torch.nn as nn
 import torch.distributed as dist
+import torch.nn as nn
 
-from llmfoundry.models.layers.layer_builders import build_ffn
 from llmfoundry.models.layers.ffn import quickgelu_activation
+from llmfoundry.models.layers.layer_builders import build_ffn
 
 
 @pytest.mark.gpu
 def test_quickgelu_activation():
     d_model = 32
     expansion_ratio = 1
     no_bias = True
-    ffn_config={
+    ffn_config = {
         'ffn_act_fn': {
             'name': 'quick_gelu',
         },
         'ffn_type': 'mptmlp',
     }
     rank: int = dist.get_rank()
-    device: torch.device = torch.device(f'cuda:{rank}')
+    device_str = f'cuda:{rank}'
+    device: torch.device = torch.device(device_str)
 
     ffn1 = build_ffn(
         name=ffn_config['ffn_type'],
         d_model=d_model,
         expansion_ratio=expansion_ratio,
-        device=device,
+        device=device_str,
         bias=not no_bias,
         ffn_kwargs=ffn_config,
     )
     assert (
         ffn1.act == quickgelu_activation
-    ), f"Expected quick_gelu activation function, got {ffn1.act}"
+    ), f'Expected quick_gelu activation function, got {ffn1.act}'
 
-    ffn_config={
+    ffn_config = {
         'ffn_act_fn': {
             'name': 'gelu',
         },
@@ -46,7 +47,7 @@ def test_quickgelu_activation():
         name=ffn_config['ffn_type'],
         d_model=d_model,
         expansion_ratio=expansion_ratio,
-        device=device,
+        device=device_str,
         bias=not no_bias,
         ffn_kwargs=ffn_config,
     )
@@ -59,14 +60,14 @@ def num_params(model: nn.Module) -> int:
     ffn2_numparams = num_params(ffn2)
     assert (
         ffn1_numparams == ffn2_numparams
-    ), "Only activation paths should have changed, re-check modeling!"
+    ), 'Only activation paths should have changed, re-check modeling!'
 
     input_ = torch.rand(1, d_model, device=device)
     output1 = ffn1(input_)
     output2 = ffn2(input_)
     assert (
         output1.numel() == output2.numel()
-    ), "Only activation paths should have changed, re-check modeling!"
+    ), 'Only activation paths should have changed, re-check modeling!'
     assert (
         not torch.allclose(output1, output2)
-    ), "Functions are different, outputs should not match!"
+    ), 'Functions are different, outputs should not match!'
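The edits above are typical of what a pre-commit run applies mechanically: imports re-sorted alphabetically, spacing normalized around `=`, and double-quoted strings rewritten to single quotes (the `device_str` change is the one exception, and reads more like a fix for a static type-check hook than a formatter edit). Below is a minimal sketch of a pre-commit config that would enforce the formatting fixes; the specific hook repos and pinned revs are assumptions for illustration, not necessarily what this repository actually pins.

# Sketch only: hook revs are illustrative, not the repo's pinned versions.
repos:
-   repo: https://github.com/PyCQA/isort
    rev: 5.12.0
    hooks:
    -   id: isort            # sorts imports (torch.distributed before torch.nn)
-   repo: https://github.com/google/yapf
    rev: v0.32.0
    hooks:
    -   id: yapf             # normalizes spacing, e.g. ffn_config={ -> ffn_config = {
        additional_dependencies: [toml]
-   repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.4.0
    hooks:
    -   id: double-quote-string-fixer   # rewrites "..." strings to '...'

With a config like this, `pre-commit run --all-files` (or the commit-time hook) rewrites files in place and produces exactly this kind of diff.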
