diff --git a/tests/models/test_model.py b/tests/models/test_model.py
index 12d7b3de37..13fe50d5cb 100644
--- a/tests/models/test_model.py
+++ b/tests/models/test_model.py
@@ -350,7 +350,8 @@ def test_full_forward_and_backward_t5_small(batch_size: int = 2):
     [('torch', torch.float16), ('torch', torch.bfloat16),
      pytest.param('flash', torch.float16, marks=pytest.mark.gpu),
      pytest.param('flash', torch.bfloat16, marks=pytest.mark.gpu)])
-def test_determinism(attn_impl: str, precision: torch.dtype):
+@pytest.mark.parametrize('ffn_type', ['mptmlp', 'mptgeglu'])
+def test_determinism(attn_impl: str, precision: torch.dtype, ffn_type: str):
     conf_path = 'scripts/train/yamls/pretrain/testing.yaml'
     with open(conf_path) as f:
         test_cfg = om.load(f)
@@ -358,6 +359,10 @@ def test_determinism(attn_impl: str, precision: torch.dtype):
     test_cfg.model.attn_config = {
         'attn_impl': attn_impl,
     }
+    if hasattr(test_cfg.model, 'ffn_config'):
+        test_cfg.model.ffn_config['ffn_type'] = ffn_type
+    else:
+        test_cfg.model.setdefault('ffn_config', {'ffn_type': ffn_type})
     test_cfg.model.init_device = 'cuda:0'
     test_cfg.device = 'cuda:0'
 
@@ -398,16 +403,22 @@ def test_determinism(attn_impl: str, precision: torch.dtype):
 
 
 @pytest.mark.gpu
-def test_loss_fn():
+@pytest.mark.parametrize('ce_loss_implementation',
+                         ['FA_v1_copied', 'FA_imported'])
+def test_loss_fn(ce_loss_implementation: str):
     """Tests the Fused CrossEntropy vs torch.nn.CrossEntropy loss function.
 
     We provide non-zero tolerances to account for small numerics differences
     between the two loss implementations.
     """
-    try:
-        from flash_attn.losses.cross_entropy import CrossEntropyLoss as FusedCrossEntropyLoss  # type: ignore # isort: skip
-    except:
-        pytest.skip('Fused cross entropy was not installed')
+    if ce_loss_implementation == 'FA_imported':
+        try:
+            from flash_attn.losses.cross_entropy import CrossEntropyLoss as FusedCrossEntropyLoss  # type: ignore # isort: skip
+        except:
+            pytest.skip('Fused cross entropy was not installed')
+    else:
+        from llmfoundry.models.layers.cross_entropy_loss import \
+            CrossEntropyLoss as FusedCrossEntropyLoss
 
     # run numerical test in pure fp32
     from torch.backends import cuda, cudnn
@@ -546,11 +557,15 @@ def test_mpt_creation(norm_type: str, no_bias: bool, tie_word_embeddings: bool):
         assert block.norm_2 is not None
         assert block.norm_2.weight.shape == torch.Size([d_model])
         assert isinstance(block.ffn.up_proj, nn.Linear)
-        assert block.ffn.up_proj.weight.shape == torch.Size(
-            [hf_config.d_model * hf_config.expansion_ratio, hf_config.d_model])
+        assert block.ffn.up_proj.weight.shape == torch.Size([
+            int(hf_config.d_model * hf_config.expansion_ratio),
+            hf_config.d_model
+        ])
         assert isinstance(block.ffn.down_proj, nn.Linear)
-        assert block.ffn.down_proj.weight.shape == torch.Size(
-            [hf_config.d_model, hf_config.d_model * hf_config.expansion_ratio])
+        assert block.ffn.down_proj.weight.shape == torch.Size([
+            hf_config.d_model,
+            int(hf_config.d_model * hf_config.expansion_ratio)
+        ])
         assert block.resid_attn_dropout.p == 0.2
         assert block.resid_ffn_dropout.p == 0.2
 