Skip to content

Commit

Permalink
fix optimizer reset for relora sft (#1414)
Browse files Browse the repository at this point in the history
* fix optimizer reset

* set states to reset for 8bit optimizers and handle quantile runtime error for embeddings

* fix relora test to check grad_norm

* use flash attn for relora and tweak hyperparams for test

* fix messages field for test dataset
  • Loading branch information
winglian authored and bursteratom committed Dec 4, 2024
1 parent 1969fa3 commit 7aa5780
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 30 deletions.
1 change: 0 additions & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,3 @@ pre-commit
black
mypy
types-requests
tbparse
1 change: 1 addition & 0 deletions requirements-tests.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ pytest
pytest-xdist
pytest-retry
pytest-sugar
tbparse
36 changes: 23 additions & 13 deletions src/axolotl/monkeypatch/relora.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,26 +46,33 @@ def reset_optimizer(
*,
reset_params: List[str], # where str is the key to a torch.nn.Parameter
optimizer_state_keys: List[str],
prune_ratio: float = 0.9,
optimizer_magnitude_pruning: float = 0.9,
):
pruning_fn = partial(magnitude_pruning_, prune_ratio=prune_ratio)
# pylint:disable=unused-argument
pruning_fn = partial(magnitude_pruning_, prune_ratio=optimizer_magnitude_pruning)
n_zeros = 0
n_total = 0

optimizer_state = optimizer.state
if isinstance(optimizer, ZeroRedundancyOptimizer):
optimizer_state = optimizer.optim.state

for param in reset_params:
param_state = optimizer_state[param]
if len(param_state) == 0: # no state for this param, happens for ZeRo optimizer
continue
for key in optimizer_state_keys:
pruning_fn(
param_state[key]
) # pruning fn has to be inplace to keep the same keys in the dict
n_total += param_state[key].numel()
n_zeros += torch.sum(param_state[key] == 0).item()
for group in optimizer.param_groups:
for param in group["params"]:
state = optimizer_state[param]
for key, value in state.items():
if key not in optimizer_state_keys:
continue
if torch.is_tensor(value):
try:
pruning_fn(value)
n_total += value.numel()
n_zeros += torch.sum(value == 0).item()
except RuntimeError as exc:
if "quantile() input tensor is too large" in str(exc):
pass
else:
raise exc

_zeroed = n_zeros / (1e-7 + n_total) * 100
LOG.info(f"Percent of optimizer states zeroed: {_zeroed:.2f}")
Expand Down Expand Up @@ -129,6 +136,9 @@ def on_step_begin(

if "adam" in args.optim.lower():
optimizer_state_keys = ["exp_avg", "exp_avg_sq"]
if "8bit" in args.optim.lower():
optimizer_state_keys.append("state1")
optimizer_state_keys.append("state2")
else:
raise ValueError(f"Optimizer {args.optim} not supported with ReLoRA")

Expand Down Expand Up @@ -160,7 +170,7 @@ def on_step_begin(
optimizer,
reset_params=lora_params,
optimizer_state_keys=optimizer_state_keys,
prune_ratio=args.relora_prune_ratio,
optimizer_magnitude_pruning=args.relora_prune_ratio,
)

if self.quantized:
Expand Down
56 changes: 40 additions & 16 deletions tests/e2e/test_relora_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@
import unittest
from pathlib import Path

from tbparse import SummaryReader

from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault

from .utils import with_temp_dir
from .utils import most_recent_subdir, with_temp_dir

LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
Expand All @@ -29,41 +31,63 @@ def test_relora(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,
"base_model": "HuggingFaceTB/SmolLM2-135M",
"sequence_len": 2048,
"sample_packing": True,
"pad_to_sequence_len": True,
"flash_attention": True,
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 32,
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_modules": ["q_proj", "v_proj"],
"relora_steps": 25,
"relora_warmup_steps": 5,
"relora_anneal_steps": 5,
"relora_steps": 100,
"relora_warmup_steps": 20,
"relora_anneal_steps": 10,
"relora_prune_ratio": 0.9,
"relora_cpu_offload": True,
"val_set_size": 0.0,
"special_tokens": {},
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"chat_template": "chatml",
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
"path": "mlabonne/FineTome-100k",
"type": "chat_template",
"split": "train[:10%]",
"field_messages": "conversations",
"message_field_role": "from",
"message_field_content": "value",
},
],
"warmup_steps": 15,
"warmup_steps": 20,
"num_epochs": 2,
"max_steps": 51, # at least 2x relora_steps
"micro_batch_size": 4,
"max_steps": 205, # at least 2x relora_steps
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"save_safetensors": True,
"use_tensorboard": True,
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "model.safetensors").exists()
assert (
Path(temp_dir) / "checkpoint-100/adapter/adapter_model.safetensors"
).exists()
assert (Path(temp_dir) / "checkpoint-100/relora/model.safetensors").exists()

tb_log_path = most_recent_subdir(temp_dir + "/runs")
event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
reader = SummaryReader(event_file)
df = reader.scalars # pylint: disable=invalid-name
df = df[(df.tag == "train/grad_norm")] # pylint: disable=invalid-name
assert df.value.values[-1] < 0.2, "grad_norm is too high"

0 comments on commit 7aa5780

Please sign in to comment.