From d3efaf8b1473b4bcf6371b0f2fb92e27708b7068 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 15 Nov 2024 19:10:14 -0500 Subject: [PATCH] support for schedule free and e2e ci smoke test (#2066) [skip ci] * support for schedule free and e2e ci smoke test * set default lr scheduler to constant in test * ignore duplicate code * fix quotes for config/dict --- requirements.txt | 1 + tests/e2e/test_optimizers.py | 34 ++++++++++++++++++++++++++++++++++ 2 files changed, 35 insertions(+) diff --git a/requirements.txt b/requirements.txt index aab4d0eade..f352fecda9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -53,3 +53,4 @@ immutabledict==4.2.0 antlr4-python3-runtime==4.13.2 torchao==0.5.0 +schedulefree==1.3.0 diff --git a/tests/e2e/test_optimizers.py b/tests/e2e/test_optimizers.py index 3b68ec5ad5..b9fa368f6f 100644 --- a/tests/e2e/test_optimizers.py +++ b/tests/e2e/test_optimizers.py @@ -108,3 +108,37 @@ def test_adopt_adamw(self, temp_dir): train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) assert (Path(temp_dir) / "adapter_model.bin").exists() + + @with_temp_dir + def test_fft_schedule_free_adamw(self, temp_dir): + cfg = DictDefault( + { + "base_model": "HuggingFaceTB/SmolLM-135M", + "sequence_len": 1024, + "val_set_size": 0.1, + "special_tokens": { + "pad_token": "<|endoftext|>", + }, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + }, + ], + "num_epochs": 1, + "micro_batch_size": 4, + "gradient_accumulation_steps": 2, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "schedule_free_adamw", + "lr_scheduler": "constant", + "save_safetensors": True, + } + ) + # pylint: disable=duplicate-code + normalize_config(cfg) + cli_args = TrainerCliArgs() + dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + + train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + assert (Path(temp_dir) / "model.safetensors").exists()