diff --git a/tests/models/test_model.py b/tests/models/test_model.py
index 16f02d63e2..92effffdd8 100644
--- a/tests/models/test_model.py
+++ b/tests/models/test_model.py
@@ -2551,10 +2551,8 @@ def test_hf_init(
         betas=(0.9, 0.99),
     )
 
-    model = BaseHuggingFaceModel(str(save_path), tokenizer)
-
     prepare_fsdp_module(
-        model.model,
+        model,
         optimizer,
         FSDPConfig(**fsdp_config),
         precision,
@@ -2562,6 +2560,8 @@ def test_hf_init(
         False,
     )
 
+    model = BaseHuggingFaceModel(model, tokenizer)
+
     batch = gen_random_batch(batch_size, test_cfg)
 
     original_params = next(model.parameters()).clone().data
@@ -2579,10 +2579,7 @@ def test_hf_init(
 
 
 @pytest.mark.gpu
-def test_head_dim_8_flash_mqa_attn(
-    tmp_path: pathlib.Path,
-    batch_size: int = 2,
-):
+def test_head_dim_8_flash_mqa_attn(batch_size: int = 2):
     test_cfg = get_config(conf_path='scripts/train/yamls/pretrain/testing.yaml')
     test_cfg.device = torch.cuda.current_device()
 
@@ -2611,9 +2608,8 @@ def test_head_dim_8_flash_mqa_attn(
     )
 
     mpt = MPTForCausalLM(hf_config)
-    mpt.save_pretrained(tmp_path)
 
-    model = BaseHuggingFaceModel(str(tmp_path), tokenizer, shift_labels=True)
+    model = BaseHuggingFaceModel(mpt, tokenizer, shift_labels=True)
     model = model.to(test_cfg.device)
 
    batch = gen_random_batch(batch_size, test_cfg)
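
For context on the pattern these tests move to: BaseHuggingFaceModel is now handed an in-memory transformers model rather than a checkpoint path, which drops the save_pretrained/tmp_path round-trip, and in test_hf_init the raw model is FSDP-prepared before being wrapped. Below is a minimal sketch of the new construction order, using only identifiers that appear in the diff; the import paths are assumptions and are not confirmed by the diff itself.

# Sketch only: import paths are assumed, not taken from the diff.
from llmfoundry.models.hf import BaseHuggingFaceModel
from llmfoundry.models.mpt import MPTForCausalLM

def build_wrapped_model(hf_config, tokenizer):
    # Build the transformers model directly in memory.
    mpt = MPTForCausalLM(hf_config)
    # Hand the model instance (not a str checkpoint path) to the wrapper,
    # so no intermediate save_pretrained(tmp_path) directory is needed.
    return BaseHuggingFaceModel(mpt, tokenizer, shift_labels=True)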