diff --git a/requirements.txt b/requirements.txt
index 76c3273fcf..8474e04255 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -53,3 +53,4 @@ immutabledict==4.2.0
 antlr4-python3-runtime==4.13.2
 
 torchao==0.5.0
+schedulefree==1.3.0
diff --git a/tests/e2e/test_optimizers.py b/tests/e2e/test_optimizers.py
index 3b68ec5ad5..b9fa368f6f 100644
--- a/tests/e2e/test_optimizers.py
+++ b/tests/e2e/test_optimizers.py
@@ -108,3 +108,37 @@ def test_adopt_adamw(self, temp_dir):
 
         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
         assert (Path(temp_dir) / "adapter_model.bin").exists()
+
+    @with_temp_dir
+    def test_fft_schedule_free_adamw(self, temp_dir):
+        cfg = DictDefault(  # one-epoch run using schedule_free_adamw
+            {
+                "base_model": "HuggingFaceTB/SmolLM-135M",
+                "sequence_len": 1024,
+                "val_set_size": 0.1,
+                "special_tokens": {
+                    "pad_token": "<|endoftext|>",
+                },
+                "datasets": [
+                    {
+                        "path": "mhenrichsen/alpaca_2k_test",
+                        "type": "alpaca",
+                    },
+                ],
+                "num_epochs": 1,
+                "micro_batch_size": 4,
+                "gradient_accumulation_steps": 2,
+                "output_dir": temp_dir,  # the final assert looks in this dir
+                "learning_rate": 0.00001,
+                "optimizer": "schedule_free_adamw",  # optimizer under test
+                "lr_scheduler": "constant",  # presumably no LR decay needed; confirm
+                "save_safetensors": True,  # checked below via model.safetensors
+            }
+        )
+        # pylint: disable=duplicate-code
+        normalize_config(cfg)
+        cli_args = TrainerCliArgs()
+        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "model.safetensors").exists()  # expected in output_dir