From 1ef70312bad2989a3e36e3a9d1f5bcff6243d32a Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 3 Dec 2024 08:58:23 -0500 Subject: [PATCH] fix optimizer reset for relora sft (#1414) * fix optimizer reset * set states to reset for 8bit optimizers and handle quantile runtime error for embeddings * fix relora test to check grad_norm * use flash attn for relora and tweak hyperparams for test * fix messages field for test dataset --- requirements-dev.txt | 1 - requirements-tests.txt | 1 + src/axolotl/monkeypatch/relora.py | 36 +++++++++++++------- tests/e2e/test_relora_llama.py | 56 ++++++++++++++++++++++--------- 4 files changed, 64 insertions(+), 30 deletions(-) diff --git a/requirements-dev.txt b/requirements-dev.txt index dcc729d1b2..4b5df167b6 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -2,4 +2,3 @@ pre-commit black mypy types-requests -tbparse diff --git a/requirements-tests.txt b/requirements-tests.txt index 0022980e90..a13f739231 100644 --- a/requirements-tests.txt +++ b/requirements-tests.txt @@ -2,3 +2,4 @@ pytest pytest-xdist pytest-retry pytest-sugar +tbparse diff --git a/src/axolotl/monkeypatch/relora.py b/src/axolotl/monkeypatch/relora.py index 9d246cb17f..3fda84b929 100644 --- a/src/axolotl/monkeypatch/relora.py +++ b/src/axolotl/monkeypatch/relora.py @@ -46,9 +46,10 @@ def reset_optimizer( *, reset_params: List[str], # where str is the key to a torch.nn.Parameter optimizer_state_keys: List[str], - prune_ratio: float = 0.9, + optimizer_magnitude_pruning: float = 0.9, ): - pruning_fn = partial(magnitude_pruning_, prune_ratio=prune_ratio) + # pylint:disable=unused-argument + pruning_fn = partial(magnitude_pruning_, prune_ratio=optimizer_magnitude_pruning) n_zeros = 0 n_total = 0 @@ -56,16 +57,22 @@ def reset_optimizer( if isinstance(optimizer, ZeroRedundancyOptimizer): optimizer_state = optimizer.optim.state - for param in reset_params: - param_state = optimizer_state[param] - if len(param_state) == 0: # no state for this param, happens for ZeRo optimizer - continue - for key in optimizer_state_keys: - pruning_fn( - param_state[key] - ) # pruning fn has to be inplace to keep the same keys in the dict - n_total += param_state[key].numel() - n_zeros += torch.sum(param_state[key] == 0).item() + for group in optimizer.param_groups: + for param in group["params"]: + state = optimizer_state[param] + for key, value in state.items(): + if key not in optimizer_state_keys: + continue + if torch.is_tensor(value): + try: + pruning_fn(value) + n_total += value.numel() + n_zeros += torch.sum(value == 0).item() + except RuntimeError as exc: + if "quantile() input tensor is too large" in str(exc): + pass + else: + raise exc _zeroed = n_zeros / (1e-7 + n_total) * 100 LOG.info(f"Percent of optimizer states zeroed: {_zeroed:.2f}") @@ -129,6 +136,9 @@ def on_step_begin( if "adam" in args.optim.lower(): optimizer_state_keys = ["exp_avg", "exp_avg_sq"] + if "8bit" in args.optim.lower(): + optimizer_state_keys.append("state1") + optimizer_state_keys.append("state2") else: raise ValueError(f"Optimizer {args.optim} not supported with ReLoRA") @@ -160,7 +170,7 @@ def on_step_begin( optimizer, reset_params=lora_params, optimizer_state_keys=optimizer_state_keys, - prune_ratio=args.relora_prune_ratio, + optimizer_magnitude_pruning=args.relora_prune_ratio, ) if self.quantized: diff --git a/tests/e2e/test_relora_llama.py b/tests/e2e/test_relora_llama.py index 5de5db11b7..56c2204677 100644 --- a/tests/e2e/test_relora_llama.py +++ b/tests/e2e/test_relora_llama.py @@ -7,13 +7,15 @@ import unittest from pathlib import Path +from tbparse import SummaryReader + from axolotl.cli import load_datasets from axolotl.common.cli import TrainerCliArgs from axolotl.train import train from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault -from .utils import with_temp_dir +from .utils import most_recent_subdir, with_temp_dir LOG = logging.getLogger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" @@ -29,36 +31,48 @@ def test_relora(self, temp_dir): # pylint: disable=duplicate-code cfg = DictDefault( { - "base_model": "JackFram/llama-68m", - "tokenizer_type": "LlamaTokenizer", - "sequence_len": 1024, + "base_model": "HuggingFaceTB/SmolLM2-135M", + "sequence_len": 2048, + "sample_packing": True, + "pad_to_sequence_len": True, + "flash_attention": True, "load_in_8bit": True, "adapter": "lora", - "lora_r": 32, + "lora_r": 8, "lora_alpha": 16, "lora_dropout": 0.05, "lora_target_modules": ["q_proj", "v_proj"], - "relora_steps": 25, - "relora_warmup_steps": 5, - "relora_anneal_steps": 5, + "relora_steps": 100, + "relora_warmup_steps": 20, + "relora_anneal_steps": 10, + "relora_prune_ratio": 0.9, "relora_cpu_offload": True, "val_set_size": 0.0, - "special_tokens": {}, + "special_tokens": { + "pad_token": "<|endoftext|>", + }, + "chat_template": "chatml", "datasets": [ { - "path": "mhenrichsen/alpaca_2k_test", - "type": "alpaca", + "path": "mlabonne/FineTome-100k", + "type": "chat_template", + "split": "train[:10%]", + "field_messages": "conversations", + "message_field_role": "from", + "message_field_content": "value", }, ], - "warmup_steps": 15, + "warmup_steps": 20, "num_epochs": 2, - "max_steps": 51, # at least 2x relora_steps - "micro_batch_size": 4, + "max_steps": 205, # at least 2x relora_steps + "micro_batch_size": 2, "gradient_accumulation_steps": 1, "output_dir": temp_dir, "learning_rate": 0.00001, - "optimizer": "adamw_torch", + "optimizer": "adamw_8bit", "lr_scheduler": "cosine", + "save_safetensors": True, + "use_tensorboard": True, } ) normalize_config(cfg) @@ -66,4 +80,14 @@ def test_relora(self, temp_dir): dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) - assert (Path(temp_dir) / "model.safetensors").exists() + assert ( + Path(temp_dir) / "checkpoint-100/adapter/adapter_model.safetensors" + ).exists() + assert (Path(temp_dir) / "checkpoint-100/relora/model.safetensors").exists() + + tb_log_path = most_recent_subdir(temp_dir + "/runs") + event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0]) + reader = SummaryReader(event_file) + df = reader.scalars # pylint: disable=invalid-name + df = df[(df.tag == "train/grad_norm")] # pylint: disable=invalid-name + assert df.value.values[-1] < 0.2, "grad_norm is too high"