add e2e check for lora w/o flash attention for mixtral to check gate
winglian committed Jan 12, 2024
1 parent 05f4555 commit ff2a221
Showing 1 changed file with 43 additions and 1 deletion.
44 changes: 43 additions & 1 deletion tests/e2e/test_mixtral.py
@@ -27,7 +27,7 @@ class TestMixtral(unittest.TestCase):
     """
 
     @with_temp_dir
-    def test_qlora(self, temp_dir):
+    def test_qlora_w_fa2(self, temp_dir):
         # pylint: disable=duplicate-code
         cfg = DictDefault(
             {
@@ -68,6 +68,48 @@ def test_qlora(self, temp_dir):
         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
         assert (Path(temp_dir) / "adapter_model.bin").exists()
 
+    @with_temp_dir
+    def test_lora_wo_fa2(self, temp_dir):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "hf-internal-testing/Mixtral-tiny",
+                "tokenizer_config": "mistralai/Mixtral-8x7B-v0.1",
+                "flash_attention": False,
+                "sequence_len": 1024,
+                "load_in_8bit": True,
+                "adapter": "lora",
+                "lora_r": 4,
+                "lora_alpha": 8,
+                "lora_dropout": 0.1,
+                "lora_target_linear": True,
+                "val_set_size": 0.1,
+                "special_tokens": {},
+                "datasets": [
+                    {
+                        "path": "mhenrichsen/alpaca_2k_test",
+                        "type": "alpaca",
+                    },
+                ],
+                "num_epochs": 2,
+                "micro_batch_size": 2,
+                "gradient_accumulation_steps": 1,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "adamw_bnb_8bit",
+                "lr_scheduler": "cosine",
+                "max_steps": 20,
+                "save_steps": 10,
+                "eval_steps": 10,
+            }
+        )
+        normalize_config(cfg)
+        cli_args = TrainerCliArgs()
+        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "adapter_model.bin").exists()
+
     @with_temp_dir
     def test_ft(self, temp_dir):
         # pylint: disable=duplicate-code

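For readers skimming the diff, a brief note that is not part of the commit: the new test exercises the LoRA path with flash attention disabled, which per the commit message is meant to cover the Mixtral gate. Below is a minimal, illustrative summary of the config values most relevant to that goal, copied from the diff above; the cfg_highlights name is hypothetical.

# Illustrative sketch, not part of the commit: settings from test_lora_wo_fa2
# that drive the non-flash-attention LoRA path described in the commit message.
cfg_highlights = {
    "flash_attention": False,    # take the non-flash-attention code path
    "adapter": "lora",           # plain LoRA on an 8-bit base model (load_in_8bit: True)
    "lora_target_linear": True,  # target all linear layers, presumably including the Mixtral gate
}

Assuming a working development install of axolotl, the new test can be run on its own with: pytest tests/e2e/test_mixtral.py -k test_lora_wo_fa2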