From b432889256bf6c127b0dcc45c6c32fb0834ddecc Mon Sep 17 00:00:00 2001
From: NanoCode012
Date: Thu, 11 Jan 2024 22:43:41 +0900
Subject: [PATCH] feat: enable trl's autounwrap (#1060)

* feat: test trl's autounwrap

* fix: add check for adapter

* feat: add config to disable autounwrap

* chore: fix lint
---
 .vscode/launch.json  |  2 +-
 devtools/README.md   |  2 +-
 docs/debugging.md    |  8 ++++----
 docs/rlhf.md         |  9 +++++++++
 src/axolotl/train.py | 13 +++++++++----
 5 files changed, 24 insertions(+), 10 deletions(-)

diff --git a/.vscode/launch.json b/.vscode/launch.json
index ff4d63924d..ec1914dc8c 100644
--- a/.vscode/launch.json
+++ b/.vscode/launch.json
@@ -11,7 +11,7 @@
             "request": "launch",
             "args": [
                 "-m", "axolotl.cli.train", "dev_sharegpt.yml",
-                // The flags below simplify debugging by overriding the axolotl config 
+                // The flags below simplify debugging by overriding the axolotl config
                 // with the debugging tips above. Modify as needed.
                 "--dataset_processes=1", // limits data preprocessing to one process
                 "--max_steps=1", // limits training to just one step
diff --git a/devtools/README.md b/devtools/README.md
index 3b5d11e227..1d727ed8bb 100644
--- a/devtools/README.md
+++ b/devtools/README.md
@@ -1 +1 @@
-This directory contains example config files that might be useful for debugging. Please see [docs/debugging.md](../docs/debugging.md) for more information.
\ No newline at end of file
+This directory contains example config files that might be useful for debugging. Please see [docs/debugging.md](../docs/debugging.md) for more information.
diff --git a/docs/debugging.md b/docs/debugging.md
index 48459bc67d..f40b12dd4b 100644
--- a/docs/debugging.md
+++ b/docs/debugging.md
@@ -30,13 +30,13 @@ While debugging it's helpful to simplify your test scenario as much as possible
 3. **Use a small model**: A good example of a small model is [TinyLlama/TinyLlama-1.1B-Chat-v1.0](https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v1.0).
 4. **Minimize iteration time**: Make sure the training loop finishes as fast as possible, with these settings.
     - `micro_batch_size: 1`
-    - `max_steps: 1` 
+    - `max_steps: 1`
     - `val_set_size: 0`
 5. **Clear Caches:** Axolotl caches certain steps and so does the underlying HuggingFace trainer. You may want to clear some of these caches when debugging.
     - Data preprocessing: When debugging data preprocessing, which includes prompt template formation, you may want to delete the directory set in `dataset_prepared_path:` in your axolotl config. If you didn't set this value, the default is `last_run_prepared`.
     - HF Hub: If you are debugging data preprocessing, you should clear the relevant HF cache [HuggingFace cache](https://huggingface.co/docs/datasets/cache), by deleting the appropriate `~/.cache/huggingface/datasets/...` folder(s).
     - **The recommended approach is to redirect all outputs and caches to a temporary folder and delete selected subfolders before each run. This is demonstrated in the example configuration below.**
- 
+

 ## Debugging with VSCode

@@ -74,7 +74,7 @@ For example, to mimic the command `cd devtools && CUDA_VISIBLE_DEVICES=0 acceler
             "request": "launch",
             "args": [
                 "-m", "axolotl.cli.train", "dev_sharegpt.yml",
-                // The flags below simplify debugging by overriding the axolotl config 
+                // The flags below simplify debugging by overriding the axolotl config
                 // with the debugging tips above. Modify as needed.
"--dataset_processes=1", // limits data preprocessing to one process "--max_steps=1", // limits training to just one step @@ -101,7 +101,7 @@ For example, to mimic the command `cd devtools && CUDA_VISIBLE_DEVICES=0 acceler - The argument `justMyCode` is set to `true` such that you step through only the axolotl code. If you want to step into dependencies, set this to `false`. - The `preLaunchTask`: `cleanup-for-dataprep` is defined in [.vscode/tasks.json](../.vscode/tasks.json) and is used to delete the following folders before debugging, which is essential to ensure that the data pre-processing code is run from scratch: - - `./devtools/temp_debug/axolotl_outputs` + - `./devtools/temp_debug/axolotl_outputs` - `./devtools/temp_debug/.hf-cache/datasets` >[!Tip] diff --git a/docs/rlhf.md b/docs/rlhf.md index 371a40dbf7..9957eb3a6f 100644 --- a/docs/rlhf.md +++ b/docs/rlhf.md @@ -33,3 +33,12 @@ datasets: ```yaml rl: ipo ``` + +#### Trl autounwrap for peft + +Trl supports autounwrapping peft models, so that a ref model does not need to be additionally loaded, leading to less VRAM needed. This is on by default. To turn it off, pass the following config. + +```yaml +# load ref model when adapter training. +rl_adapter_ref_model: true +``` diff --git a/src/axolotl/train.py b/src/axolotl/train.py index cf3382c896..36e598de43 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -63,10 +63,15 @@ def train( model, peft_config = load_model(cfg, tokenizer, inference=cli_args.inference) model_ref = None if cfg.rl: - # load the model again for model_ref/baseline - model_ref, _ = load_model( - cfg, tokenizer, inference=cli_args.inference, reference_model=True - ) + if cfg.adapter and not cfg.rl_adapter_ref_model: + # use built-in trl autounwrap + LOG.debug("Passing model_ref: None to RL trainer") + model_ref = None # explicit setting to None + else: + # load the model again for model_ref/baseline + model_ref, _ = load_model( + cfg, tokenizer, inference=cli_args.inference, reference_model=True + ) safe_serialization = cfg.save_safetensors is True