Commit

clean up test model
dakinggg committed Nov 19, 2023
1 parent eea448f commit d3d3bfe
Showing 1 changed file with 17 additions and 51 deletions.
68 changes: 17 additions & 51 deletions tests/test_model.py
@@ -2,7 +2,6 @@
 # SPDX-License-Identifier: Apache-2.0
 import contextlib
 import copy
-import gc
 import os
 import pathlib
 import warnings
@@ -864,10 +863,9 @@ def test_generate(attention_impl: str, precision: str, pos_emb_config: dict,
 
 @pytest.mark.gpu
 @pytest.mark.parametrize('world_size', [1, 2])
-@pytest.mark.parametrize('use_cache', [False, True])
 @pytest.mark.parametrize('tie_word_embeddings', [True, False])
 def test_generate_with_device_map(tmp_path: pathlib.Path, world_size: int,
-                                  use_cache: bool, tie_word_embeddings: bool):
+                                  tie_word_embeddings: bool):
     if not torch.cuda.device_count() >= world_size:
         pytest.skip(f'This test requires {world_size} GPUs.')
 
@@ -884,7 +882,7 @@ def test_generate_with_device_map(tmp_path: pathlib.Path, world_size: int,
         attn_config={
             'attn_impl': 'torch',
         },
-        use_cache=use_cache,
+        use_cache=True,
         tie_word_embeddings=tie_word_embeddings,
     )
     mpt = MPTForCausalLM(hf_config)
@@ -970,7 +968,6 @@ def test_save_from_pretrained(tmp_path: pathlib.Path):
     'torch',
     pytest.param('flash', marks=pytest.mark.gpu),
     pytest.param('triton', marks=pytest.mark.gpu),
-    pytest.param('torch', marks=pytest.mark.gpu),
 ])
 @pytest.mark.parametrize('pos_emb_config', [{
     'alibi': False,
@@ -998,9 +995,7 @@ def test_save_from_pretrained(tmp_path: pathlib.Path):
         'factor': 1.0,
     },
 }])
-@pytest.mark.parametrize('tie_word_embeddings', [True, False])
-def test_forward_with_cache_and_padding(attn_impl: str, pos_emb_config: dict,
-                                        tie_word_embeddings: bool):
+def test_forward_with_cache_and_padding(attn_impl: str, pos_emb_config: dict):
     # Tests that the result is the same with or without padding when using kv caching
     if pos_emb_config['alibi'] and attn_impl == 'flash':
         pytest.skip(f'alibi only implemented with torch and triton attention.')
@@ -1029,7 +1024,7 @@ def test_forward_with_cache_and_padding(attn_impl: str, pos_emb_config: dict,
             'name': 'baseline_',
             'init_std': 0.02,
         },
-        tie_word_embeddings=tie_word_embeddings,
+        tie_word_embeddings=True,
     )
 
     mpt = MPTForCausalLM(hf_config)
@@ -1107,7 +1102,6 @@ def test_forward_with_cache_and_padding(attn_impl: str, pos_emb_config: dict,
     'torch',
     pytest.param('flash', marks=pytest.mark.gpu),
     pytest.param('triton', marks=pytest.mark.gpu),
-    pytest.param('torch', marks=pytest.mark.gpu),
 ])
 @pytest.mark.parametrize('pos_emb_config', [{
     'alibi': False,
@@ -1247,7 +1241,6 @@ def test_forward_with_cache(attn_impl: str, pos_emb_config: dict,
     'torch',
     pytest.param('flash', marks=pytest.mark.gpu),
     pytest.param('triton', marks=pytest.mark.gpu),
-    pytest.param('torch', marks=pytest.mark.gpu),
 ])
 @pytest.mark.parametrize('pos_emb_config', [{
     'alibi': False,
@@ -1347,18 +1340,12 @@ def test_generate_with_past_kv(attn_impl: str, pos_emb_config: dict,
     'torch',
     pytest.param('flash', marks=pytest.mark.gpu),
     pytest.param('triton', marks=pytest.mark.gpu),
-    pytest.param('torch', marks=pytest.mark.gpu),
 ])
 @pytest.mark.parametrize('generation_kwargs', [{
     'max_new_tokens': 2,
-    'num_beams': 4
-}, {
-    'max_new_tokens': 2,
+    'num_beams': 4,
     'top_k': 5,
     'penalty_alpha': 0.4
-}, {
-    'do_sample': True,
-    'top_p': 0.95
 }])
 @pytest.mark.parametrize('pos_emb_config', [{
     'alibi': False,
@@ -1425,7 +1412,6 @@ def test_generation_kwargs_dont_crash(attn_impl: str,
 
     with get_precision_context('amp_bf16' if composer_device.name ==
                                'gpu' else 'fp32'):
-        # no padding in the input
         no_padding_input_ids = torch.tensor([[11274, 16390, 11]])
         no_padding_input_ids = composer_device.tensor_to_device(
             no_padding_input_ids)
@@ -1442,7 +1428,6 @@ def test_generation_kwargs_dont_crash(attn_impl: str,
 
 
 @pytest.mark.gpu
-@pytest.mark.parametrize('attention_impl', ['torch', 'flash', 'triton'])
 @pytest.mark.parametrize('pos_emb_config', [{
     'alibi': False,
     'rope': False
@@ -1470,12 +1455,8 @@ def test_generation_kwargs_dont_crash(attn_impl: str,
     },
 }])
 @pytest.mark.parametrize('tie_word_embeddings', [True, False])
-def test_model_to(attention_impl: str, pos_emb_config: dict,
-                  tie_word_embeddings: bool):
+def test_model_to(pos_emb_config: dict, tie_word_embeddings: bool):
     # test that moving the model to diff devices and dtypes in diff ways does not break the model
-    if pos_emb_config['alibi'] and attention_impl == 'flash':
-        pytest.skip(f'alibi only implemented with torch and triton attention.')
-
     if pos_emb_config['rope'] and pos_emb_config[
             'rope_impl'] == 'dail' and not is_flash_v2_installed():
         pytest.skip(f'dail implementation of rope requires flash attention 2.')
@@ -1490,7 +1471,7 @@ def test_model_to(attention_impl: str, pos_emb_config: dict,
         emb_pdrop=0.1,
         resid_pdrop=0.2,
         attn_config={
-            'attn_impl': attention_impl,
+            'attn_impl': 'torch',
             **pos_emb_config,
         },
         init_config={
@@ -1514,17 +1495,15 @@ def test_model_to(attention_impl: str, pos_emb_config: dict,
     mpt = mpt.to('cpu')
 
     # verify the model still works
-    if attention_impl == 'torch' and not (
-            pos_emb_config['rope'] and pos_emb_config['rope_impl'] == 'dail'):
+    if not (pos_emb_config['rope'] and pos_emb_config['rope_impl'] == 'dail'):
         with torch.autocast('cpu', dtype=torch.bfloat16, enabled=True):
             _ = mpt(input_ids.to('cpu'),
                     attention_mask=attention_mask.to('cpu'))
 
     mpt = mpt.float()
 
     # verify the model still works
-    if attention_impl == 'torch' and not (
-            pos_emb_config['rope'] and pos_emb_config['rope_impl'] == 'dail'):
+    if not (pos_emb_config['rope'] and pos_emb_config['rope_impl'] == 'dail'):
         _ = mpt(input_ids.to('cpu'), attention_mask=attention_mask.to('cpu'))
 
     mpt = mpt.to(0)  # move to rank0
@@ -1586,16 +1565,11 @@ def test_alibi_vs_hf():
         'factor': 1.0,
     },
 }])
-@pytest.mark.parametrize('output_attentions', [True, False])
-@pytest.mark.parametrize('output_hidden_states', [True, False])
-@pytest.mark.parametrize('tie_word_embeddings', [True, False])
 def test_forward_with_output_attentions_and_output_hidden_states(
-        attn_impl: str, pos_emb_config: dict, output_attentions: bool,
-        output_hidden_states: bool, tie_word_embeddings: bool):
-    # Test that model forward with output_attentions_and_output_hidden_states
+        attn_impl: str, pos_emb_config: dict):
     if pos_emb_config['alibi'] and attn_impl == 'flash':
         pytest.skip(f'alibi only implemented with torch and triton attention.')
-    if output_attentions and attn_impl in ['flash', 'triton']:
+    if attn_impl in ['flash', 'triton']:
         pytest.skip(f'output_attentions only implemented with torch attention.')
     if pos_emb_config['rope'] and pos_emb_config[
             'rope_impl'] == 'dail' and not is_flash_v2_installed():
@@ -1624,7 +1598,7 @@ def test_forward_with_output_attentions_and_output_hidden_states(
             'name': 'baseline_',
             'init_std': 0.02,
         },
-        tie_word_embeddings=tie_word_embeddings,
+        tie_word_embeddings=True,
     )
     mpt = MPTForCausalLM(hf_config)
     mpt = composer_device.module_to_device(mpt)
@@ -1637,20 +1611,16 @@ def test_forward_with_output_attentions_and_output_hidden_states(
     attention_mask = torch.tensor([[1, 1, 1]]).bool()
     attention_mask = composer_device.tensor_to_device(attention_mask)
 
-    # start with passing the first three tokens through
     outputs = mpt(
         input_ids,
         attention_mask=attention_mask,
-        output_attentions=output_attentions,
-        output_hidden_states=output_hidden_states,
+        output_attentions=True,
+        output_hidden_states=True,
     )
 
-    if output_attentions:
-        assert len(outputs.attentions) == n_layers
-        assert all(
-            attn.shape == (1, 4, 3, 3) for attn in outputs.attentions)
-    if output_hidden_states:
-        assert len(outputs.hidden_states) == n_layers + 1
+    assert len(outputs.attentions) == n_layers
+    assert all(attn.shape == (1, 4, 3, 3) for attn in outputs.attentions)
+    assert len(outputs.hidden_states) == n_layers + 1
 
 
 @pytest.mark.gpu
@@ -1663,10 +1633,6 @@ def test_hf_init(tmp_path: pathlib.Path,
     if not torch.cuda.device_count() >= world_size:
         pytest.skip(f'This test requires {world_size} GPUs.')
 
-    torch.cuda.empty_cache()
-    gc.collect() #just in case
-    torch.cuda.synchronize()
-
     test_cfg = get_config(conf_path='scripts/train/yamls/pretrain/testing.yaml')
     test_cfg.device = torch.cuda.current_device()
 
