precommit
dakinggg committed Nov 18, 2023
1 parent: 3a863ca · commit: 84220de
Showing 1 changed file with 23 additions and 40 deletions.
tests/test_model.py (63 changes: 23 additions & 40 deletions)
@@ -16,7 +16,7 @@
 from composer.core.precision import Precision, get_precision_context
 from composer.optim import DecoupledAdamW
 from composer.trainer.dist_strategy import prepare_fsdp_module
-from composer.utils import dist, get_device, reproducibility
+from composer.utils import dist, get_device
 from omegaconf import DictConfig, ListConfig
 from omegaconf import OmegaConf as om
 from transformers import (AutoModelForCausalLM, AutoTokenizer, PreTrainedModel,
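
The `reproducibility` import can be dropped because the tests below no longer toggle determinism by hand, and explicit device strings are replaced by Composer's device auto-detection. A minimal sketch of the new pattern, assuming Composer's `get_device` behavior that passing `None` selects a GPU device when CUDA is available and falls back to CPU otherwise:

    import torch
    from composer.utils import get_device

    # Let Composer pick the device: DeviceGPU if CUDA is available, else DeviceCPU.
    composer_device = get_device(None)
    # Move a test tensor onto that device, as the tests below do with their inputs.
    x = composer_device.tensor_to_device(torch.zeros(2, 4))
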
@@ -550,20 +550,18 @@ def test_mpt_creation(norm_type: str, no_bias: bool, tie_word_embeddings: bool):
 @pytest.mark.parametrize('tie_word_embeddings', [True, False])
 def test_forward_with_padding(attention_impl: str, pos_emb_config: dict,
                               tie_word_embeddings: bool):
-    device = 'gpu' if torch.cuda.is_available() else 'cpu'
-
     # Test that different placement of padding does not affect the output.
     alibi = pos_emb_config['alibi']
     if alibi and attention_impl == 'flash':
         pytest.skip(f'alibi only implemented with torch and triton attention.')
 
     rope = pos_emb_config['rope']
-    if rope and pos_emb_config['rope_impl'] == 'dail' and (
-            device != 'gpu' or not is_flash_v2_installed()):
+    if rope and pos_emb_config[
+            'rope_impl'] == 'dail' and not is_flash_v2_installed():
         pytest.skip(
             f'dail implementation of rope requires gpu and flash attention 2.')
 
-    composer_device = get_device(device)
+    composer_device = get_device(None)
 
     hf_config = MPTConfig(
         init_device='cpu',
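
The refactored skip condition no longer inspects an explicit device string; it relies only on whether flash-attn v2 is importable. A rough sketch (not the repository's code) of the predicate these tests now share; the import path for `is_flash_v2_installed` is assumed and not shown in this diff:

    from llmfoundry.models.layers.attention import is_flash_v2_installed  # assumed import path

    # Skip dail-rope configurations whenever flash-attn v2 is unavailable.
    def should_skip_dail_rope(pos_emb_config: dict) -> bool:
        return (pos_emb_config['rope']
                and pos_emb_config['rope_impl'] == 'dail'
                and not is_flash_v2_installed())
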
@@ -775,21 +773,19 @@ def test_advanced_mask_building(attention_impl: str):
 @pytest.mark.parametrize('tie_word_embeddings', [True, False])
 def test_generate(attention_impl: str, precision: str, pos_emb_config: dict,
                   tie_word_embeddings: bool):
-    device = 'gpu' if torch.cuda.is_available() else 'cpu'
-
     # Test that generate works, and produces the same output with or without
     # padding in the input.
     if pos_emb_config['alibi'] and attention_impl == 'flash':
         pytest.skip(f'alibi only implemented with torch and triton attention.')
 
-    if pos_emb_config['rope'] and pos_emb_config['rope_impl'] == 'dail' and (
-            device != 'gpu' or not is_flash_v2_installed()):
+    if pos_emb_config['rope'] and pos_emb_config[
+            'rope_impl'] == 'dail' and not is_flash_v2_installed():
         pytest.skip(
             f'dail implementation of rope requires gpu and flash attention 2.')
     if attention_impl == 'torch' and precision == 'amp_bf16' and tie_word_embeddings == False:
         pytest.skip(f'This test configuration has precision / sampling issues.')
 
-    composer_device = get_device(device)
+    composer_device = get_device(None)
 
     hf_config = MPTConfig(
         init_device='cpu',
@@ -1005,17 +1001,15 @@ def test_save_from_pretrained(tmp_path: pathlib.Path):
 @pytest.mark.parametrize('tie_word_embeddings', [True, False])
 def test_forward_with_cache_and_padding(attn_impl: str, pos_emb_config: dict,
                                         tie_word_embeddings: bool):
-    device = 'gpu' if torch.cuda.is_available() else 'cpu'
-
     # Tests that the result is the same with or without padding when using kv caching
     if pos_emb_config['alibi'] and attn_impl == 'flash':
         pytest.skip(f'alibi only implemented with torch and triton attention.')
-    if pos_emb_config['rope'] and pos_emb_config['rope_impl'] == 'dail' and (
-            device != 'gpu' or not is_flash_v2_installed()):
+    if pos_emb_config['rope'] and pos_emb_config[
+            'rope_impl'] == 'dail' and not is_flash_v2_installed():
         pytest.skip(
             f'dail implementation of rope requires gpu and flash attention 2.')
 
-    composer_device = get_device(device)
+    composer_device = get_device(None)
 
     hf_config = MPTConfig(
         init_device='cpu',
@@ -1144,19 +1138,17 @@ def test_forward_with_cache_and_padding(attn_impl: str, pos_emb_config: dict,
 @pytest.mark.parametrize('tie_word_embeddings', [True, False])
 def test_forward_with_cache(attn_impl: str, pos_emb_config: dict,
                             tie_word_embeddings: bool):
-    device = 'gpu' if torch.cuda.is_available() else 'cpu'
-
     # Test that model forward with and without the key-value cache produces the
     # same output.
     if pos_emb_config['alibi'] and attn_impl == 'flash':
         pytest.skip(f'alibi only implemented with torch and triton attention.')
 
-    if pos_emb_config['rope'] and pos_emb_config['rope_impl'] == 'dail' and (
-            device != 'gpu' or not is_flash_v2_installed()):
+    if pos_emb_config['rope'] and pos_emb_config[
+            'rope_impl'] == 'dail' and not is_flash_v2_installed():
         pytest.skip(
             f'dail implementation of rope requires gpu and flash attention 2.')
 
-    composer_device = get_device(device)
+    composer_device = get_device(None)
 
     hf_config = MPTConfig(
         init_device='cpu',
@@ -1286,16 +1278,14 @@ def test_forward_with_cache(attn_impl: str, pos_emb_config: dict,
 @pytest.mark.parametrize('tie_word_embeddings', [True, False])
 def test_generate_with_past_kv(attn_impl: str, pos_emb_config: dict,
                                tie_word_embeddings: bool):
-    device = 'gpu' if torch.cuda.is_available() else 'cpu'
-
     if pos_emb_config['alibi'] and attn_impl == 'flash':
         pytest.skip(f'alibi only implemented with torch and triton attention.')
-    if pos_emb_config['rope'] and pos_emb_config['rope_impl'] == 'dail' and (
-            device != 'gpu' or not is_flash_v2_installed()):
+    if pos_emb_config['rope'] and pos_emb_config[
+            'rope_impl'] == 'dail' and not is_flash_v2_installed():
         pytest.skip(
             f'dail implementation of rope requires gpu and flash attention 2.')
 
-    composer_device = get_device(device)
+    composer_device = get_device(None)
 
     hf_config = MPTConfig(
         init_device='cpu',
@@ -1401,18 +1391,15 @@ def test_generation_kwargs_dont_crash(attn_impl: str,
                                       generation_kwargs: Dict[str, Any],
                                       pos_emb_config: dict,
                                       tie_word_embeddings: bool):
-    device = 'gpu' if torch.cuda.is_available() else 'cpu'
-
     if pos_emb_config['alibi'] and attn_impl == 'flash':
         pytest.skip(f'alibi only implemented with torch and triton attention.')
 
-    if pos_emb_config['rope'] and pos_emb_config['rope_impl'] == 'dail' and (
-            device != 'gpu' or not is_flash_v2_installed()):
+    if pos_emb_config['rope'] and pos_emb_config[
+            'rope_impl'] == 'dail' and not is_flash_v2_installed():
         pytest.skip(
             f'dail implementation of rope requires gpu and flash attention 2.')
-    composer_device = get_device(device)
-    if device == 'gpu':  # Switch deteminism off
-        torch.use_deterministic_algorithms(False)
+    composer_device = get_device(None)
 
     hf_config = MPTConfig(
         init_device='cpu',
         d_model=128,
@@ -1446,8 +1433,6 @@ def test_generation_kwargs_dont_crash(attn_impl: str,
     _ = mpt.generate(input_ids=no_padding_input_ids,
                      attention_mask=no_padding_attention_mask,
                      **generation_kwargs)
-    if device == 'gpu':  # Switch deteminism back on
-        reproducibility.configure_deterministic_mode()
 
 
 @pytest.mark.gpu
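
For context, the determinism toggling that this commit removes followed roughly the pattern sketched below; `configure_deterministic_mode` is the Composer helper the old code called to restore deterministic algorithms after generation. This is an illustrative sketch, not the repository's code:

    import torch
    from composer.utils import reproducibility

    # Old pattern: allow non-deterministic CUDA kernels while running generate,
    # then restore Composer's deterministic settings afterwards.
    torch.use_deterministic_algorithms(False)
    try:
        pass  # ... run model.generate(...) here ...
    finally:
        reproducibility.configure_deterministic_mode()

After this change the test simply relies on the default settings rather than switching determinism off and on around the generate call.
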
@@ -1614,19 +1599,17 @@ def test_alibi_vs_hf():
 def test_forward_with_output_attentions_and_output_hidden_states(
         attn_impl: str, pos_emb_config: dict, output_attentions: bool,
         output_hidden_states: bool, tie_word_embeddings: bool):
-    device = 'gpu' if torch.cuda.is_available() else 'cpu'
-
     # Test that model forward with output_attentions_and_output_hidden_states
     if pos_emb_config['alibi'] and attn_impl == 'flash':
         pytest.skip(f'alibi only implemented with torch and triton attention.')
     if output_attentions and attn_impl in ['flash', 'triton']:
         pytest.skip(f'output_attentions only implemented with torch attention.')
-    if pos_emb_config['rope'] and pos_emb_config['rope_impl'] == 'dail' and (
-            device != 'gpu' or not is_flash_v2_installed()):
+    if pos_emb_config['rope'] and pos_emb_config[
+            'rope_impl'] == 'dail' and not is_flash_v2_installed():
         pytest.skip(
             f'dail implementation of rope requires gpu and flash attention 2.')
 
-    composer_device = get_device(device)
+    composer_device = get_device(None)
 
     n_layers = 2
 