add tests
vchiley committed Nov 11, 2023
1 parent 6c96bd1 commit e740386
Showing 5 changed files with 118 additions and 34 deletions.
6 changes: 4 additions & 2 deletions llmfoundry/models/mpt/configuration_mpt.py
@@ -167,8 +167,10 @@ def __init__(
f'alibi or rope is turned on, setting `learned_pos_emb` to `False.`'
)
# tie_word_embeddings is set in Huggingface's PretrainedConfig __init__
kwargs['tie_word_embeddings'] = tie_word_embeddings
super().__init__(**kwargs)
super().__init__(
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)

self._validate_config()

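Both the old and the new form hand the flag to Huggingface's PretrainedConfig.__init__, which stores it on the config object; the new form simply passes it explicitly rather than mutating kwargs first.

A minimal sketch (not part of the commit) of what the forwarded tie_word_embeddings flag is expected to control, mirroring the assertions the commit adds in tests/test_model.py below. The import path and the tiny hyperparameters are assumptions chosen for illustration.

# Illustrative only; the import path is an assumption -- adjust to wherever
# MPTConfig and MPTForCausalLM are exported in your checkout.
from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM

def build_tiny_mpt(tie_word_embeddings: bool) -> MPTForCausalLM:
    # Small hyperparameters mirroring the test configs in tests/test_model.py.
    hf_config = MPTConfig(
        init_device='cpu',
        d_model=128,
        n_heads=4,
        n_layers=2,
        expansion_ratio=2,
        max_seq_len=64,
        vocab_size=50368,
        attn_config={'attn_impl': 'torch'},
        tie_word_embeddings=tie_word_embeddings,
    )
    return MPTForCausalLM(hf_config)

tied = build_tiny_mpt(tie_word_embeddings=True)
assert tied.config.tie_word_embeddings is True  # stored by PretrainedConfig.__init__

untied = build_tiny_mpt(tie_word_embeddings=False)
# With untied embeddings the model carries a separate output projection whose
# shape matches the input embedding table.
assert untied.lm_head is not None
assert untied.lm_head.weight.shape == untied.transformer.wte.weight.shape
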
41 changes: 29 additions & 12 deletions tests/test_hf_conversion_script.py
@@ -248,20 +248,21 @@ def test_callback_inits_with_defaults():

@pytest.mark.world_size(2)
@pytest.mark.gpu
@pytest.mark.parametrize('model', ['mpt', 'neo', 'llama2'])
@pytest.mark.parametrize(
'model,tie_word_embeddings',
[('mpt', True), ('mpt', False), ('neo', None), ('llama2', None)],
)
@pytest.mark.parametrize('fsdp_state_dict_type', ['full', 'sharded', None])
@pytest.mark.parametrize('log_to_mlflow', [True, False])
@pytest.mark.parametrize(
'hf_save_interval,save_interval,max_duration,expected_hf_checkpoints,expected_normal_checkpoints',
[('3ba', '2ba', '7ba', 3, 4), ('1dur', '2ba', '1ep', 1, 4)])
@patch('os.cpu_count', MagicMock(return_value=None))
def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path,
fsdp_state_dict_type: Optional[str],
log_to_mlflow: bool,
hf_save_interval: str,
save_interval: str, max_duration: str,
expected_hf_checkpoints: int,
expected_normal_checkpoints: int):
def test_huggingface_conversion_callback(
model: str, tmp_path: pathlib.Path, tie_word_embeddings: bool,
fsdp_state_dict_type: Optional[str], log_to_mlflow: bool,
hf_save_interval: str, save_interval: str, max_duration: str,
expected_hf_checkpoints: int, expected_normal_checkpoints: int):
delete_transformers_cache()

dist.initialize_dist(get_device('gpu'))
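
The parametrization change above pairs each model with a tie_word_embeddings value instead of parametrizing the model name alone; None marks the Huggingface models for which the flag does not apply, and the new asserts further down guard that convention. A self-contained sketch of the pattern (illustrative only; the test and argument names here are hypothetical):

from typing import Optional

import pytest

@pytest.mark.parametrize(
    'model,tie_word_embeddings',
    [('mpt', True), ('mpt', False), ('neo', None)],
)
@pytest.mark.parametrize('fsdp_state_dict_type', ['full', 'sharded', None])
def test_parametrization_sketch(model: str,
                                tie_word_embeddings: Optional[bool],
                                fsdp_state_dict_type: Optional[str]):
    # Stacked parametrize decorators take the cross product of their argument
    # lists: 3 (model, tie) pairs x 3 state-dict types = 9 generated cases.
    if model != 'mpt':
        # The flag is only meaningful for MPT configs.
        assert tie_word_embeddings is None
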
@@ -298,9 +299,11 @@ def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path,
'attn_impl': 'torch',
},
'loss_fn': 'torch_crossentropy',
'tie_word_embeddings': tie_word_embeddings,
}
tokenizer_name = 'EleutherAI/gpt-neox-20b'
elif model == 'neo':
assert tie_word_embeddings is None
model_cfg = {
'name': 'hf_causal_lm',
'pretrained_model_name_or_path': 'EleutherAI/gpt-neo-125M',
@@ -313,6 +316,7 @@ def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path,
}
tokenizer_name = 'EleutherAI/gpt-neo-125M'
elif model == 'llama2':
assert tie_word_embeddings is None
if 'HUGGING_FACE_HUB_TOKEN' not in os.environ:
pytest.skip(
'The CI cluster does not have access to the Llama models, so skip this test.'
@@ -489,19 +493,26 @@ def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path,
delete_transformers_cache()


@pytest.mark.parametrize('model', ['mpt', 'neo', 'llama2'])
def test_convert_and_generate(model: str, tmp_path: pathlib.Path):
@pytest.mark.parametrize(
'model,tie_word_embeddings',
[('mpt', True), ('mpt', False), ('neo', None), ('llama2', None)],
)
def test_convert_and_generate(model: str, tie_word_embeddings: bool,
tmp_path: pathlib.Path):
delete_transformers_cache()

om_cfg = None
if model == 'mpt':
om_cfg = get_config(
conf_path='scripts/train/yamls/pretrain/testing.yaml')
om_cfg['tie_word_embeddings'] = tie_word_embeddings
elif model == 'neo':
assert tie_word_embeddings is None
om_cfg = get_config(
conf_path='scripts/train/yamls/pretrain/gpt-neo-125m.yaml')
om_cfg['model']['config_overrides']['hidden_size'] = 36
elif model == 'llama2':
assert tie_word_embeddings is None
if 'HUGGING_FACE_HUB_TOKEN' not in os.environ:
pytest.skip(
'The CI cluster does not have access to the Llama models, so skip this test.'
@@ -561,11 +572,14 @@ def test_convert_and_generate(model: str, tmp_path: pathlib.Path):


@pytest.mark.gpu
def test_convert_and_generate_triton(tmp_path: pathlib.Path):
@pytest.mark.parametrize('tie_word_embeddings', [True, False])
def test_convert_and_generate_triton(tie_word_embeddings: str,
tmp_path: pathlib.Path):
delete_transformers_cache()

cfg = get_config()
cfg['model']['init_device'] = 'cpu'
cfg['tie_word_embeddings'] = tie_word_embeddings
tokenizer = transformers.AutoTokenizer.from_pretrained(
'EleutherAI/gpt-neox-20b')
model = ComposerMPTCausalLM(cfg['model'], tokenizer)
@@ -600,7 +614,9 @@ def test_convert_and_generate_triton(tmp_path: pathlib.Path):
delete_transformers_cache()


def test_convert_and_generate_meta(tmp_path: pathlib.Path):
@pytest.mark.parametrize('tie_word_embeddings', [True, False])
def test_convert_and_generate_meta(tie_word_embeddings: str,
tmp_path: pathlib.Path):
delete_transformers_cache()

from composer.utils import dist
@@ -610,6 +626,7 @@ def test_convert_and_generate_meta(tmp_path: pathlib.Path):
om_cfg = get_config(conf_path='scripts/train/yamls/pretrain/testing.yaml')

om_cfg['model']['init_device'] = 'cpu'
om_cfg['tie_word_embeddings'] = tie_word_embeddings
tokenizer = transformers.AutoTokenizer.from_pretrained(
om_cfg.tokenizer.name)
original_model = COMPOSER_MODEL_REGISTRY[om_cfg['model'].name](
69 changes: 57 additions & 12 deletions tests/test_model.py
@@ -466,7 +466,8 @@ def test_opt_wrapping():

@pytest.mark.parametrize('norm_type', NORM_CLASS_REGISTRY.keys())
@pytest.mark.parametrize('no_bias', [False, True])
def test_mpt_creation(norm_type: str, no_bias: bool):
@pytest.mark.parametrize('tie_word_embeddings', [True, False])
def test_mpt_creation(norm_type: str, no_bias: bool, tie_word_embeddings: bool):
# Test that the config constructs the model as expected.
hf_config = MPTConfig(
init_device='cpu',
@@ -482,6 +483,7 @@ def test_mpt_creation(norm_type: str, no_bias: bool):
},
norm_type=norm_type,
no_bias=no_bias,
tie_word_embeddings=tie_word_embeddings,
)
mpt = MPTForCausalLM(hf_config)

@@ -493,6 +495,9 @@ def test_mpt_creation(norm_type: str, no_bias: bool):

assert mpt.transformer.wte.weight.shape == torch.Size(
[hf_config.vocab_size, hf_config.d_model])
if not tie_word_embeddings:
assert mpt.lm_head is not None
assert mpt.lm_head.weight.shape == mpt.transformer.wte.weight.shape
assert mpt.transformer.wpe.weight.shape == torch.Size(
[hf_config.max_seq_len, hf_config.d_model])
assert mpt.transformer.emb_drop.p == 0.1
@@ -544,8 +549,9 @@ def test_mpt_creation(norm_type: str, no_bias: bool):
'factor': 1.0,
},
}])
@pytest.mark.parametrize('tie_word_embeddings', [True, False])
def test_forward_with_padding(attention_impl: str, device: str,
pos_emb_config: dict):
pos_emb_config: dict, tie_word_embeddings: bool):
# Test that different placement of padding does not affect the output.
if not torch.cuda.is_available() and device == 'gpu':
pytest.skip(
@@ -580,6 +586,7 @@ def test_forward_with_padding(attention_impl: str, device: str,
'name': 'baseline_',
'init_std': 0.02,
},
tie_word_embeddings=tie_word_embeddings,
)
mpt = MPTForCausalLM(hf_config)
mpt.eval()
@@ -766,7 +773,9 @@ def test_advanced_mask_building(attention_impl: str):
'factor': 1.0,
},
}])
def test_generate(attention_impl: str, device: str, pos_emb_config: dict):
@pytest.mark.parametrize('tie_word_embeddings', [True, False])
def test_generate(attention_impl: str, device: str, pos_emb_config: dict,
tie_word_embeddings: bool):
# Test that generate works, and produces the same output with or without
# padding in the input.
if not torch.cuda.is_available() and device == 'gpu':
@@ -796,10 +805,15 @@ def test_generate(attention_impl: str, device: str, pos_emb_config: dict):
'attn_impl': attention_impl,
**pos_emb_config,
},
tie_word_embeddings=tie_word_embeddings,
)
mpt = MPTForCausalLM(hf_config)
mpt.eval()
if not tie_word_embeddings:
assert mpt.lm_head is not None
with torch.no_grad():
mpt.lm_head.weight.copy_(mpt.transformer.wte.weight)
mpt = composer_device.module_to_device(mpt)
mpt.eval()

# padding on the left of the input
left_padding_input_ids = torch.tensor(
@@ -861,8 +875,9 @@ def test_generate(attention_impl: str, device: str, pos_emb_config: dict):
@pytest.mark.gpu
@pytest.mark.parametrize('world_size', [1, 2])
@pytest.mark.parametrize('use_cache', [False, True])
@pytest.mark.parametrize('tie_word_embeddings', [True, False])
def test_generate_with_device_map(tmp_path: pathlib.Path, world_size: int,
use_cache: bool):
use_cache: bool, tie_word_embeddings: bool):
if not torch.cuda.is_available():
pytest.skip(f'This test requires CUDA to be available.')
if not torch.cuda.device_count() >= world_size:
@@ -882,6 +897,7 @@ def test_generate_with_device_map(tmp_path: pathlib.Path, world_size: int,
'attn_impl': 'torch',
},
use_cache=use_cache,
tie_word_embeddings=tie_word_embeddings,
)
mpt = MPTForCausalLM(hf_config)
mpt.save_pretrained(save_path)
@@ -938,7 +954,9 @@ def check_hf_model_equivalence(model1: PreTrainedModel,
torch.testing.assert_close(p1, p2)


def test_save_from_pretrained(tmp_path: pathlib.Path):
@pytest.mark.parametrize('tie_word_embeddings', [True, False])
def test_save_from_pretrained(tie_word_embeddings: bool,
tmp_path: pathlib.Path):
# Test that MPT can be used with the HuggingFace
# save_pretrained/from_pretrained api.
hf_config = MPTConfig(
@@ -953,10 +971,12 @@ def test_save_from_pretrained(tmp_path: pathlib.Path):
attn_config={
'attn_impl': 'torch',
},
tie_word_embeddings=tie_word_embeddings,
)
mpt = MPTForCausalLM(hf_config)

mpt.save_pretrained(tmp_path / 'test-save-pretrained')
print(tmp_path / 'test-save-pretrained')
mpt2 = MPTForCausalLM.from_pretrained(tmp_path / 'test-save-pretrained')

check_hf_model_equivalence(mpt, mpt2)
@@ -994,8 +1014,10 @@ def test_save_from_pretrained(tmp_path: pathlib.Path):
'factor': 1.0,
},
}])
@pytest.mark.parametrize('tie_word_embeddings', [True, False])
def test_forward_with_cache_and_padding(attn_impl: str, device: str,
pos_emb_config: dict):
pos_emb_config: dict,
tie_word_embeddings: bool):
# Tests that the result is the same with or without padding when using kv caching
if not torch.cuda.is_available() and device == 'gpu':
pytest.skip(
@@ -1028,6 +1050,7 @@ def test_forward_with_cache_and_padding(attn_impl: str, device: str,
'name': 'baseline_',
'init_std': 0.02,
},
tie_word_embeddings=tie_word_embeddings,
)

mpt = MPTForCausalLM(hf_config)
@@ -1133,7 +1156,9 @@ def test_forward_with_cache_and_padding(attn_impl: str, device: str,
'factor': 1.0,
},
}])
def test_forward_with_cache(attn_impl: str, device: str, pos_emb_config: dict):
@pytest.mark.parametrize('tie_word_embeddings', [True, False])
def test_forward_with_cache(attn_impl: str, device: str, pos_emb_config: dict,
tie_word_embeddings: bool):
# Test that model forward with and without the key-value cache produces the
# same output.
if not torch.cuda.is_available() and device == 'gpu':
@@ -1168,8 +1193,13 @@ def test_forward_with_cache(attn_impl: str, device: str, pos_emb_config: dict):
'name': 'baseline_',
'init_std': 0.02,
},
tie_word_embeddings=tie_word_embeddings,
)
mpt = MPTForCausalLM(hf_config)
if not tie_word_embeddings:
assert mpt.lm_head is not None
with torch.no_grad():
mpt.lm_head.weight.copy_(mpt.transformer.wte.weight)
mpt = composer_device.module_to_device(mpt)
mpt.eval()

@@ -1274,8 +1304,9 @@ def test_forward_with_cache(attn_impl: str, device: str, pos_emb_config: dict):
'factor': 1.0,
},
}])
@pytest.mark.parametrize('tie_word_embeddings', [True, False])
def test_generate_with_past_kv(attn_impl: str, device: str,
pos_emb_config: dict):
pos_emb_config: dict, tie_word_embeddings: bool):
if not torch.cuda.is_available() and device == 'gpu':
pytest.skip(
f'This test requires CUDA to be available in order to run with {attn_impl} attention.'
@@ -1307,8 +1338,13 @@ def test_generate_with_past_kv(attn_impl: str, device: str,
'name': 'baseline_',
'init_std': 0.02,
},
tie_word_embeddings=tie_word_embeddings,
)
mpt = MPTForCausalLM(hf_config)
if not tie_word_embeddings:
assert mpt.lm_head is not None
with torch.no_grad():
mpt.lm_head.weight.copy_(mpt.transformer.wte.weight)
mpt = composer_device.module_to_device(mpt)
mpt.eval()
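
The untied branches added in these generate/cache tests seed lm_head from the shared embedding table before the checks run. Mechanically, copying wte.weight into lm_head.weight makes an untied head produce exactly the logits a tied head would, as this self-contained sketch shows (plain torch; the module names are stand-ins for the MPT attributes, not the real model):

import torch

# Stand-ins for transformer.wte and the untied lm_head on MPTForCausalLM.
wte = torch.nn.Embedding(num_embeddings=16, embedding_dim=8)
lm_head = torch.nn.Linear(in_features=8, out_features=16, bias=False)

with torch.no_grad():
    lm_head.weight.copy_(wte.weight)  # same pattern as the test branches above

hidden = torch.randn(2, 4, 8)
tied_logits = hidden @ wte.weight.t()  # what a weight-tied output layer computes
untied_logits = lm_head(hidden)        # the untied head after the copy
torch.testing.assert_close(tied_logits, untied_logits)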

@@ -1386,9 +1422,11 @@ def test_generate_with_past_kv(attn_impl: str, device: str,
'factor': 1.0,
},
}])
@pytest.mark.parametrize('tie_word_embeddings', [True, False])
def test_generation_kwargs_dont_crash(attn_impl: str, device: str,
generation_kwargs: Dict[str, Any],
pos_emb_config: dict):
pos_emb_config: dict,
tie_word_embeddings: bool):
if not torch.cuda.is_available() and device == 'gpu':
pytest.skip(
f'This test requires CUDA to be available in order to run with {attn_impl} attention.'
@@ -1417,6 +1455,7 @@ def test_generation_kwargs_dont_crash(attn_impl: str, device: str,
**pos_emb_config,
},
use_cache=True,
tie_word_embeddings=tie_word_embeddings,
)
mpt = MPTForCausalLM(hf_config)
mpt = composer_device.module_to_device(mpt)
@@ -1467,7 +1506,9 @@ def test_generation_kwargs_dont_crash(attn_impl: str, device: str,
'factor': 1.0,
},
}])
def test_model_to(attention_impl: str, pos_emb_config: dict):
@pytest.mark.parametrize('tie_word_embeddings', [True, False])
def test_model_to(attention_impl: str, pos_emb_config: dict,
tie_word_embeddings: bool):
# test that moving the model to diff devices and dtypes in diff ways does not break the model
if not torch.cuda.is_available():
pytest.skip(
@@ -1498,6 +1539,7 @@ def test_model_to(attention_impl: str, pos_emb_config: dict):
'name': 'baseline_',
'init_std': 0.02,
},
tie_word_embeddings=tie_word_embeddings,
)
mpt = MPTForCausalLM(hf_config)
mpt = mpt.bfloat16()
@@ -1600,9 +1642,11 @@ def test_alibi_vs_hf():
}])
@pytest.mark.parametrize('output_attentions', [True, False])
@pytest.mark.parametrize('output_hidden_states', [True, False])
@pytest.mark.parametrize('tie_word_embeddings', [True, False])
def test_forward_with_output_attentions_and_output_hidden_states(
attn_impl: str, device: str, pos_emb_config: dict,
output_attentions: bool, output_hidden_states: bool):
output_attentions: bool, output_hidden_states: bool,
tie_word_embeddings: bool):
# Test that model forward with output_attentions_and_output_hidden_states
if not torch.cuda.is_available() and device == 'gpu':
pytest.skip(
@@ -1639,6 +1683,7 @@ def test_forward_with_output_attentions_and_output_hidden_states(
'name': 'baseline_',
'init_std': 0.02,
},
tie_word_embeddings=tie_word_embeddings,
)
mpt = MPTForCausalLM(hf_config)
mpt = composer_device.module_to_device(mpt)
(Diffs for the remaining two changed files did not load and are not shown here.)
