From d3d3bfef2a299a3f05d4700f5235404b7af0a10f Mon Sep 17 00:00:00 2001 From: Daniel King Date: Sun, 19 Nov 2023 05:24:32 +0000 Subject: [PATCH] clean up test model --- tests/test_model.py | 68 ++++++++++++--------------------------------- 1 file changed, 17 insertions(+), 51 deletions(-) diff --git a/tests/test_model.py b/tests/test_model.py index 51180a6c28..5e589dbd60 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -2,7 +2,6 @@ # SPDX-License-Identifier: Apache-2.0 import contextlib import copy -import gc import os import pathlib import warnings @@ -864,10 +863,9 @@ def test_generate(attention_impl: str, precision: str, pos_emb_config: dict, @pytest.mark.gpu @pytest.mark.parametrize('world_size', [1, 2]) -@pytest.mark.parametrize('use_cache', [False, True]) @pytest.mark.parametrize('tie_word_embeddings', [True, False]) def test_generate_with_device_map(tmp_path: pathlib.Path, world_size: int, - use_cache: bool, tie_word_embeddings: bool): + tie_word_embeddings: bool): if not torch.cuda.device_count() >= world_size: pytest.skip(f'This test requires {world_size} GPUs.') @@ -884,7 +882,7 @@ def test_generate_with_device_map(tmp_path: pathlib.Path, world_size: int, attn_config={ 'attn_impl': 'torch', }, - use_cache=use_cache, + use_cache=True, tie_word_embeddings=tie_word_embeddings, ) mpt = MPTForCausalLM(hf_config) @@ -970,7 +968,6 @@ def test_save_from_pretrained(tmp_path: pathlib.Path): 'torch', pytest.param('flash', marks=pytest.mark.gpu), pytest.param('triton', marks=pytest.mark.gpu), - pytest.param('torch', marks=pytest.mark.gpu), ]) @pytest.mark.parametrize('pos_emb_config', [{ 'alibi': False, @@ -998,9 +995,7 @@ def test_save_from_pretrained(tmp_path: pathlib.Path): 'factor': 1.0, }, }]) -@pytest.mark.parametrize('tie_word_embeddings', [True, False]) -def test_forward_with_cache_and_padding(attn_impl: str, pos_emb_config: dict, - tie_word_embeddings: bool): +def test_forward_with_cache_and_padding(attn_impl: str, pos_emb_config: dict): # Tests that the result is the same with or without padding when using kv caching if pos_emb_config['alibi'] and attn_impl == 'flash': pytest.skip(f'alibi only implemented with torch and triton attention.') @@ -1029,7 +1024,7 @@ def test_forward_with_cache_and_padding(attn_impl: str, pos_emb_config: dict, 'name': 'baseline_', 'init_std': 0.02, }, - tie_word_embeddings=tie_word_embeddings, + tie_word_embeddings=True, ) mpt = MPTForCausalLM(hf_config) @@ -1107,7 +1102,6 @@ def test_forward_with_cache_and_padding(attn_impl: str, pos_emb_config: dict, 'torch', pytest.param('flash', marks=pytest.mark.gpu), pytest.param('triton', marks=pytest.mark.gpu), - pytest.param('torch', marks=pytest.mark.gpu), ]) @pytest.mark.parametrize('pos_emb_config', [{ 'alibi': False, @@ -1247,7 +1241,6 @@ def test_forward_with_cache(attn_impl: str, pos_emb_config: dict, 'torch', pytest.param('flash', marks=pytest.mark.gpu), pytest.param('triton', marks=pytest.mark.gpu), - pytest.param('torch', marks=pytest.mark.gpu), ]) @pytest.mark.parametrize('pos_emb_config', [{ 'alibi': False, @@ -1347,18 +1340,12 @@ def test_generate_with_past_kv(attn_impl: str, pos_emb_config: dict, 'torch', pytest.param('flash', marks=pytest.mark.gpu), pytest.param('triton', marks=pytest.mark.gpu), - pytest.param('torch', marks=pytest.mark.gpu), ]) @pytest.mark.parametrize('generation_kwargs', [{ 'max_new_tokens': 2, - 'num_beams': 4 -}, { - 'max_new_tokens': 2, + 'num_beams': 4, 'top_k': 5, 'penalty_alpha': 0.4 -}, { - 'do_sample': True, - 'top_p': 0.95 }]) @pytest.mark.parametrize('pos_emb_config', [{ 'alibi': False, @@ -1425,7 +1412,6 @@ def test_generation_kwargs_dont_crash(attn_impl: str, with get_precision_context('amp_bf16' if composer_device.name == 'gpu' else 'fp32'): - # no padding in the input no_padding_input_ids = torch.tensor([[11274, 16390, 11]]) no_padding_input_ids = composer_device.tensor_to_device( no_padding_input_ids) @@ -1442,7 +1428,6 @@ def test_generation_kwargs_dont_crash(attn_impl: str, @pytest.mark.gpu -@pytest.mark.parametrize('attention_impl', ['torch', 'flash', 'triton']) @pytest.mark.parametrize('pos_emb_config', [{ 'alibi': False, 'rope': False @@ -1470,12 +1455,8 @@ def test_generation_kwargs_dont_crash(attn_impl: str, }, }]) @pytest.mark.parametrize('tie_word_embeddings', [True, False]) -def test_model_to(attention_impl: str, pos_emb_config: dict, - tie_word_embeddings: bool): +def test_model_to(pos_emb_config: dict, tie_word_embeddings: bool): # test that moving the model to diff devices and dtypes in diff ways does not break the model - if pos_emb_config['alibi'] and attention_impl == 'flash': - pytest.skip(f'alibi only implemented with torch and triton attention.') - if pos_emb_config['rope'] and pos_emb_config[ 'rope_impl'] == 'dail' and not is_flash_v2_installed(): pytest.skip(f'dail implementation of rope requires flash attention 2.') @@ -1490,7 +1471,7 @@ def test_model_to(attention_impl: str, pos_emb_config: dict, emb_pdrop=0.1, resid_pdrop=0.2, attn_config={ - 'attn_impl': attention_impl, + 'attn_impl': 'torch', **pos_emb_config, }, init_config={ @@ -1514,8 +1495,7 @@ def test_model_to(attention_impl: str, pos_emb_config: dict, mpt = mpt.to('cpu') # verify the model still works - if attention_impl == 'torch' and not ( - pos_emb_config['rope'] and pos_emb_config['rope_impl'] == 'dail'): + if not (pos_emb_config['rope'] and pos_emb_config['rope_impl'] == 'dail'): with torch.autocast('cpu', dtype=torch.bfloat16, enabled=True): _ = mpt(input_ids.to('cpu'), attention_mask=attention_mask.to('cpu')) @@ -1523,8 +1503,7 @@ def test_model_to(attention_impl: str, pos_emb_config: dict, mpt = mpt.float() # verify the model still works - if attention_impl == 'torch' and not ( - pos_emb_config['rope'] and pos_emb_config['rope_impl'] == 'dail'): + if not (pos_emb_config['rope'] and pos_emb_config['rope_impl'] == 'dail'): _ = mpt(input_ids.to('cpu'), attention_mask=attention_mask.to('cpu')) mpt = mpt.to(0) # move to rank0 @@ -1586,16 +1565,11 @@ def test_alibi_vs_hf(): 'factor': 1.0, }, }]) -@pytest.mark.parametrize('output_attentions', [True, False]) -@pytest.mark.parametrize('output_hidden_states', [True, False]) -@pytest.mark.parametrize('tie_word_embeddings', [True, False]) def test_forward_with_output_attentions_and_output_hidden_states( - attn_impl: str, pos_emb_config: dict, output_attentions: bool, - output_hidden_states: bool, tie_word_embeddings: bool): - # Test that model forward with output_attentions_and_output_hidden_states + attn_impl: str, pos_emb_config: dict): if pos_emb_config['alibi'] and attn_impl == 'flash': pytest.skip(f'alibi only implemented with torch and triton attention.') - if output_attentions and attn_impl in ['flash', 'triton']: + if attn_impl in ['flash', 'triton']: pytest.skip(f'output_attentions only implemented with torch attention.') if pos_emb_config['rope'] and pos_emb_config[ 'rope_impl'] == 'dail' and not is_flash_v2_installed(): @@ -1624,7 +1598,7 @@ def test_forward_with_output_attentions_and_output_hidden_states( 'name': 'baseline_', 'init_std': 0.02, }, - tie_word_embeddings=tie_word_embeddings, + tie_word_embeddings=True, ) mpt = MPTForCausalLM(hf_config) mpt = composer_device.module_to_device(mpt) @@ -1637,20 +1611,16 @@ def test_forward_with_output_attentions_and_output_hidden_states( attention_mask = torch.tensor([[1, 1, 1]]).bool() attention_mask = composer_device.tensor_to_device(attention_mask) - # start with passing the first three tokens through outputs = mpt( input_ids, attention_mask=attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, + output_attentions=True, + output_hidden_states=True, ) - if output_attentions: - assert len(outputs.attentions) == n_layers - assert all( - attn.shape == (1, 4, 3, 3) for attn in outputs.attentions) - if output_hidden_states: - assert len(outputs.hidden_states) == n_layers + 1 + assert len(outputs.attentions) == n_layers + assert all(attn.shape == (1, 4, 3, 3) for attn in outputs.attentions) + assert len(outputs.hidden_states) == n_layers + 1 @pytest.mark.gpu @@ -1663,10 +1633,6 @@ def test_hf_init(tmp_path: pathlib.Path, if not torch.cuda.device_count() >= world_size: pytest.skip(f'This test requires {world_size} GPUs.') - torch.cuda.empty_cache() - gc.collect() #just in case - torch.cuda.synchronize() - test_cfg = get_config(conf_path='scripts/train/yamls/pretrain/testing.yaml') test_cfg.device = torch.cuda.current_device()