From d1df05c97b78321c018a58f54fd015675c2b8d1c Mon Sep 17 00:00:00 2001 From: root Date: Mon, 13 Nov 2023 20:27:58 +0000 Subject: [PATCH] updt tests to guard against numerical issues --- tests/test_model.py | 37 +++++++++++++++---------------------- 1 file changed, 15 insertions(+), 22 deletions(-) diff --git a/tests/test_model.py b/tests/test_model.py index 7a7735e1c6..18ce7190a2 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -743,10 +743,13 @@ def test_advanced_mask_building(attention_impl: str): assert torch.equal(attn_bias, expected_attn_bias) -@pytest.mark.parametrize('attention_impl,device', [('torch', 'cpu'), - ('flash', 'gpu'), - ('triton', 'gpu'), - ('torch', 'gpu')]) +@pytest.mark.parametrize('attention_impl,device,precision', [ + ('torch', 'cpu', 'fp32'), + ('flash', 'gpu', 'amp_bf16'), + ('triton', 'gpu', 'amp_bf16'), + ('torch', 'gpu', 'amp_bf16'), + ('torch', 'gpu', 'fp32'), +]) @pytest.mark.parametrize('pos_emb_config', [{ 'alibi': False, 'rope': False @@ -774,8 +777,8 @@ def test_advanced_mask_building(attention_impl: str): }, }]) @pytest.mark.parametrize('tie_word_embeddings', [True, False]) -def test_generate(attention_impl: str, device: str, pos_emb_config: dict, - tie_word_embeddings: bool): +def test_generate(attention_impl: str, device: str, precision: str, + pos_emb_config: dict, tie_word_embeddings: bool): # Test that generate works, and produces the same output with or without # padding in the input. if not torch.cuda.is_available() and device == 'gpu': @@ -789,6 +792,8 @@ def test_generate(attention_impl: str, device: str, pos_emb_config: dict, device != 'gpu' or not is_flash_v2_installed()): pytest.skip( f'dail implementation of rope requires gpu and flash attention 2.') + if attention_impl == 'torch' and precision == 'amp_bf16' and tie_word_embeddings == False: + pytest.skip(f'This test configuration has precision / sampling issues.') composer_device = get_device(device) @@ -808,10 +813,6 @@ def test_generate(attention_impl: str, device: str, pos_emb_config: dict, tie_word_embeddings=tie_word_embeddings, ) mpt = MPTForCausalLM(hf_config) - if not tie_word_embeddings: - assert mpt.lm_head is not None - with torch.no_grad(): - mpt.lm_head.weight.copy_(mpt.transformer.wte.weight) mpt = composer_device.module_to_device(mpt) mpt.eval() @@ -844,8 +845,7 @@ def test_generate(attention_impl: str, device: str, pos_emb_config: dict, batched_attention_mask = composer_device.tensor_to_device( batched_attention_mask) - with get_precision_context('amp_bf16' if composer_device.name == - 'gpu' else 'fp32'): + with get_precision_context(precision): # check that a batch with different amounts of padding doesn't crash # and produces the right output shape batched_generation = mpt.generate(input_ids=batched_input_ids, @@ -1192,10 +1192,6 @@ def test_forward_with_cache(attn_impl: str, device: str, pos_emb_config: dict, tie_word_embeddings=tie_word_embeddings, ) mpt = MPTForCausalLM(hf_config) - if not tie_word_embeddings: - assert mpt.lm_head is not None - with torch.no_grad(): - mpt.lm_head.weight.copy_(mpt.transformer.wte.weight) mpt = composer_device.module_to_device(mpt) mpt.eval() @@ -1263,7 +1259,7 @@ def test_forward_with_cache(attn_impl: str, device: str, pos_emb_config: dict, torch.testing.assert_close( second_output.logits, full_output.logits[:, -1, :].unsqueeze(1), - atol=1e-2, + atol=1e-1, rtol=1e-2, ) @@ -1337,10 +1333,6 @@ def test_generate_with_past_kv(attn_impl: str, device: str, tie_word_embeddings=tie_word_embeddings, ) mpt = MPTForCausalLM(hf_config) - if not tie_word_embeddings: - assert mpt.lm_head is not None - with torch.no_grad(): - mpt.lm_head.weight.copy_(mpt.transformer.wte.weight) mpt = composer_device.module_to_device(mpt) mpt.eval() @@ -1357,7 +1349,8 @@ def test_generate_with_past_kv(attn_impl: str, device: str, with mock.patch.object(MPTForCausalLM, 'forward', autospec=True) as forward_mocked: forward_mocked.return_value = CausalLMOutputWithPast( - logits=torch.randn((1, 3, hf_config.vocab_size)), + logits=composer_device.tensor_to_device( + torch.randn((1, 3, hf_config.vocab_size))), past_key_values=[(torch.randn(1, 3, hf_config.d_model), torch.randn(1, 3, hf_config.d_model)) for _ in range(hf_config.n_layers)])