Revert "fix?"
This reverts commit 259cc76.
dakinggg committed Sep 25, 2024
1 parent: 259cc76 · commit: e27bb7b
Showing 1 changed file with 5 additions and 9 deletions.
tests/models/test_model.py (5 additions, 9 deletions)

@@ -2551,17 +2551,17 @@ def test_hf_init(
         betas=(0.9, 0.99),
     )
 
+    model = BaseHuggingFaceModel(str(save_path), tokenizer)
+
     prepare_fsdp_module(
-        model,
+        model.model,
         optimizer,
         FSDPConfig(**fsdp_config),
         precision,
         device,
         False,
     )
 
-    model = BaseHuggingFaceModel(model, tokenizer)
-
     batch = gen_random_batch(batch_size, test_cfg)
 
     original_params = next(model.parameters()).clone().data
@@ -2579,10 +2579,7 @@ def test_hf_init(


 @pytest.mark.gpu
-def test_head_dim_8_flash_mqa_attn(
-    tmp_path: pathlib.Path,
-    batch_size: int = 2,
-):
+def test_head_dim_8_flash_mqa_attn(batch_size: int = 2):
     test_cfg = get_config(conf_path='scripts/train/yamls/pretrain/testing.yaml')
     test_cfg.device = torch.cuda.current_device()
 
@@ -2611,9 +2608,8 @@ def test_head_dim_8_flash_mqa_attn(
     )
 
     mpt = MPTForCausalLM(hf_config)
-    mpt.save_pretrained(tmp_path)
 
-    model = BaseHuggingFaceModel(str(tmp_path), tokenizer, shift_labels=True)
+    model = BaseHuggingFaceModel(mpt, tokenizer, shift_labels=True)
 
     model = model.to(test_cfg.device)
     batch = gen_random_batch(batch_size, test_cfg)
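
For reference, a minimal sketch of the test_hf_init flow this revert restores: the HuggingFace wrapper is built from the checkpoint saved to disk first, and FSDP is then prepared on the inner model.model. The surrounding variables (save_path, tokenizer, optimizer, fsdp_config, precision, device) come from earlier in the test and are only assumed here; the import paths are guesses at typical llm-foundry / Composer layouts and are not part of this diff.

# Sketch, not the test file itself: anything not visible in the diff above is an assumption.
from composer.distributed import prepare_fsdp_module   # assumed import path
from composer.utils import FSDPConfig                  # assumed import path
from llmfoundry.models.hf import BaseHuggingFaceModel  # assumed import path


def restored_hf_init_flow(save_path, tokenizer, optimizer, fsdp_config,
                          precision, device):
    # Build the HuggingFace wrapper from the checkpoint on disk first
    # (the reverted "fix?" instead wrapped an in-memory module after FSDP
    # preparation).
    model = BaseHuggingFaceModel(str(save_path), tokenizer)

    # Prepare FSDP on the inner transformer rather than on the wrapper,
    # mirroring the positional call in the hunk above; the trailing False is
    # simply the last positional flag as written in the diff.
    prepare_fsdp_module(
        model.model,
        optimizer,
        FSDPConfig(**fsdp_config),
        precision,
        device,
        False,
    )
    return model

In short, after this revert the wrapper exists before FSDP preparation, so the rest of the test can call model.parameters() and run a forward pass on the already-wrapped module, as the hunks above show.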