
Commit

restore tests
bmosaicml committed Dec 13, 2023
1 parent aeda62c commit d8b7aa0
Showing 1 changed file with 25 additions and 10 deletions.
tests/models/test_model.py (35 changes: 25 additions & 10 deletions)
@@ -350,14 +350,19 @@ def test_full_forward_and_backward_t5_small(batch_size: int = 2):
     [('torch', torch.float16), ('torch', torch.bfloat16),
      pytest.param('flash', torch.float16, marks=pytest.mark.gpu),
      pytest.param('flash', torch.bfloat16, marks=pytest.mark.gpu)])
-def test_determinism(attn_impl: str, precision: torch.dtype):
+@pytest.mark.parametrize('ffn_type', ['mptmlp', 'mptgeglu'])
+def test_determinism(attn_impl: str, precision: torch.dtype, ffn_type: str):
     conf_path = 'scripts/train/yamls/pretrain/testing.yaml'
     with open(conf_path) as f:
         test_cfg = om.load(f)
 
     test_cfg.model.attn_config = {
         'attn_impl': attn_impl,
     }
+    if hasattr(test_cfg.model, 'ffn_config'):
+        test_cfg.model.ffn_config['ffn_type'] = ffn_type
+    else:
+        test_cfg.model.setdefault('ffn_config', {'ffn_type': ffn_type})
     test_cfg.model.init_device = 'cuda:0'
     test_cfg.device = 'cuda:0'
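
The new ffn_type parameter makes the determinism test cover both feed-forward variants that ffn_config can name, 'mptmlp' and 'mptgeglu', while the hasattr/setdefault branch only decides how that name is injected into configs that may or may not already carry an ffn_config section. As a rough sketch of the difference being exercised (an illustration only, not llm-foundry's actual implementation of either block, and assuming an integer expansion_ratio):

import torch
import torch.nn as nn

class SketchMLP(nn.Module):
    """Plain feed-forward block in the spirit of 'mptmlp': up -> GELU -> down."""

    def __init__(self, d_model: int, expansion_ratio: int):
        super().__init__()
        self.up_proj = nn.Linear(d_model, d_model * expansion_ratio)
        self.act = nn.GELU()
        self.down_proj = nn.Linear(d_model * expansion_ratio, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down_proj(self.act(self.up_proj(x)))

class SketchGeGLU(nn.Module):
    """Gated variant in the spirit of 'mptgeglu': a GELU-activated gate
    projection multiplies the up projection before the down projection."""

    def __init__(self, d_model: int, expansion_ratio: int):
        super().__init__()
        self.gate_proj = nn.Linear(d_model, d_model * expansion_ratio)
        self.up_proj = nn.Linear(d_model, d_model * expansion_ratio)
        self.act = nn.GELU()
        self.down_proj = nn.Linear(d_model * expansion_ratio, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))

Either way, the determinism check itself is unchanged; only the FFN type carried by the YAML config varies.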

@@ -398,16 +403,22 @@ def test_determinism(attn_impl: str, precision: torch.dtype):
 
 
 @pytest.mark.gpu
-def test_loss_fn():
+@pytest.mark.parametrize('ce_loss_implementation',
+                         ['FA_v1_copied', 'FA_imported'])
+def test_loss_fn(ce_loss_implementation: str):
     """Tests the Fused CrossEntropy vs torch.nn.CrossEntropy loss function.
 
     We provide non-zero tolerances to account for small numerics differences
     between the two loss implementations.
     """
-    try:
-        from flash_attn.losses.cross_entropy import CrossEntropyLoss as FusedCrossEntropyLoss  # type: ignore # isort: skip
-    except:
-        pytest.skip('Fused cross entropy was not installed')
+    if ce_loss_implementation == 'FA_imported':
+        try:
+            from flash_attn.losses.cross_entropy import CrossEntropyLoss as FusedCrossEntropyLoss  # type: ignore # isort: skip
+        except:
+            pytest.skip('Fused cross entropy was not installed')
+    else:
+        from llmfoundry.models.layers.cross_entropy_loss import \
+            CrossEntropyLoss as FusedCrossEntropyLoss
 
     # run numerical test in pure fp32
     from torch.backends import cuda, cudnn
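
The loss test is now parametrized over where the fused implementation comes from: 'FA_imported' pulls CrossEntropyLoss from an installed flash_attn package (skipping if it is missing), while 'FA_v1_copied' uses the copy shipped in llmfoundry.models.layers.cross_entropy_loss, presumably vendored from FlashAttention v1. The comparison the test name describes follows roughly this pattern (a sketch with illustrative shapes and tolerances, not the values the test actually uses):

import torch
import torch.nn as nn

def compare_ce_losses(fused_loss_fn: nn.Module, vocab_size: int = 128) -> None:
    # Sketch only: both implementations should agree up to small numeric
    # differences, hence non-zero tolerances.
    torch_loss_fn = nn.CrossEntropyLoss()
    logits = torch.randn(8, vocab_size, device='cuda')
    targets = torch.randint(0, vocab_size, (8,), device='cuda')
    fused = fused_loss_fn(logits, targets)
    reference = torch_loss_fn(logits, targets)
    torch.testing.assert_close(fused, reference, rtol=1e-2, atol=1e-2)
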
@@ -546,11 +557,15 @@ def test_mpt_creation(norm_type: str, no_bias: bool, tie_word_embeddings: bool):
         assert block.norm_2 is not None
         assert block.norm_2.weight.shape == torch.Size([d_model])
         assert isinstance(block.ffn.up_proj, nn.Linear)
-        assert block.ffn.up_proj.weight.shape == torch.Size(
-            [hf_config.d_model * hf_config.expansion_ratio, hf_config.d_model])
+        assert block.ffn.up_proj.weight.shape == torch.Size([
+            int(hf_config.d_model * hf_config.expansion_ratio),
+            hf_config.d_model
+        ])
         assert isinstance(block.ffn.down_proj, nn.Linear)
-        assert block.ffn.down_proj.weight.shape == torch.Size(
-            [hf_config.d_model, hf_config.d_model * hf_config.expansion_ratio])
+        assert block.ffn.down_proj.weight.shape == torch.Size([
+            hf_config.d_model,
+            int(hf_config.d_model * hf_config.expansion_ratio)
+        ])
         assert block.resid_attn_dropout.p == 0.2
         assert block.resid_ffn_dropout.p == 0.2

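Both shape assertions now wrap d_model * expansion_ratio in int(...), presumably to allow non-integer expansion ratios. If expansion_ratio is stored as a float in the config, the product is a float, and a float cannot be used as a tensor dimension, so the cast keeps the expected torch.Size well formed. With illustrative values (not the test's config):

import torch

d_model, expansion_ratio = 128, 4.0        # float ratio, e.g. read from YAML
ffn_dim = int(d_model * expansion_ratio)   # 512; tensor dimensions must be ints
assert torch.Size([ffn_dim, d_model]) == torch.Size([512, 128])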
