fix the parameter name, was using the old disco name
winglian committed Jan 12, 2024
1 parent 1003283 commit 4f0d078
Showing 1 changed file with 16 additions and 4 deletions.
tests/e2e/test_mixtral.py (16 additions, 4 deletions)
@@ -67,7 +67,10 @@ def test_qlora_w_fa2(self, temp_dir):
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
 
         model, _ = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert model.base_model.model.model.layers[0].mlp.gate.dtype == torch.float32
+        assert (
+            model.base_model.model.model.layers[0].block_sparse_moe.gate.dtype
+            == torch.float32
+        )
         assert (Path(temp_dir) / "adapter_model.bin").exists()
 
     @with_temp_dir
@@ -110,7 +113,10 @@ def test_qlora_wo_fa2(self, temp_dir):
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
 
         model, _ = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert model.base_model.model.model.layers[0].mlp.gate.dtype == torch.float32
+        assert (
+            model.base_model.model.model.layers[0].block_sparse_moe.gate.dtype
+            == torch.float32
+        )
         assert (Path(temp_dir) / "adapter_model.bin").exists()
 
     @with_temp_dir
@@ -152,7 +158,10 @@ def test_16bit_lora_w_fa2(self, temp_dir):
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
 
         model, _ = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert model.base_model.model.model.layers[0].mlp.gate.dtype == torch.float32
+        assert (
+            model.base_model.model.model.layers[0].block_sparse_moe.gate.dtype
+            == torch.float32
+        )
         assert (Path(temp_dir) / "adapter_model.bin").exists()
 
     @with_temp_dir
@@ -194,7 +203,10 @@ def test_16bit_lora_wo_fa2(self, temp_dir):
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
 
         model, _ = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert model.base_model.model.model.layers[0].mlp.gate.dtype == torch.float32
+        assert (
+            model.base_model.model.model.layers[0].block_sparse_moe.gate.dtype
+            == torch.float32
+        )
         assert (Path(temp_dir) / "adapter_model.bin").exists()
 
     @with_temp_dir
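For reference, the corrected attribute path follows Hugging Face's Mixtral implementation, where each decoder layer exposes its MoE router as block_sparse_moe.gate; the extra base_model.model. prefix in the tests comes from the PEFT adapter wrapper. Below is a minimal sketch of the same dtype check on a plain, non-PEFT model, assuming transformers>=4.36 (the first release with Mixtral support) and a deliberately tiny, illustrative config so the model builds quickly on CPU:

import torch
from transformers import MixtralConfig, MixtralForCausalLM

# Tiny illustrative config; these values are arbitrary, not from the tests.
config = MixtralConfig(
    vocab_size=256,
    hidden_size=32,
    intermediate_size=64,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=2,
    num_local_experts=4,
    num_experts_per_tok=2,
)
model = MixtralForCausalLM(config)

# The router is block_sparse_moe.gate, not mlp.gate (the old name this
# commit replaces); it is a plain nn.Linear projecting hidden states to
# per-expert routing logits.
gate = model.model.layers[0].block_sparse_moe.gate
assert isinstance(gate, torch.nn.Linear)
assert gate.weight.dtype == torch.float32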
