diff --git a/tests/e2e/test_dpo.py b/tests/e2e/test_dpo.py index 0fd4e8b521..ac3c6d0693 100644 --- a/tests/e2e/test_dpo.py +++ b/tests/e2e/test_dpo.py @@ -66,7 +66,7 @@ def test_dpo_lora(self, temp_dir): dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args) train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) - assert (Path(temp_dir) / "adapter_model.bin").exists() + assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists() @with_temp_dir def test_kto_pair_lora(self, temp_dir): @@ -110,7 +110,7 @@ def test_kto_pair_lora(self, temp_dir): dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args) train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) - assert (Path(temp_dir) / "adapter_model.bin").exists() + assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists() @with_temp_dir def test_ipo_lora(self, temp_dir): @@ -154,4 +154,4 @@ def test_ipo_lora(self, temp_dir): dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args) train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) - assert (Path(temp_dir) / "adapter_model.bin").exists() + assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()