fix optimizer reset for relora sft #1414

Merged · 5 commits · Dec 3, 2024
1 change: 0 additions & 1 deletion requirements-dev.txt
@@ -2,4 +2,3 @@ pre-commit
black
mypy
types-requests
tbparse
1 change: 1 addition & 0 deletions requirements-tests.txt
@@ -2,3 +2,4 @@ pytest
pytest-xdist
pytest-retry
pytest-sugar
tbparse
36 changes: 23 additions & 13 deletions src/axolotl/monkeypatch/relora.py
@@ -46,26 +46,33 @@ def reset_optimizer(
*,
reset_params: List[str], # where str is the key to a torch.nn.Parameter
optimizer_state_keys: List[str],
prune_ratio: float = 0.9,
optimizer_magnitude_pruning: float = 0.9,
):
pruning_fn = partial(magnitude_pruning_, prune_ratio=prune_ratio)
# pylint:disable=unused-argument
pruning_fn = partial(magnitude_pruning_, prune_ratio=optimizer_magnitude_pruning)
n_zeros = 0
n_total = 0

optimizer_state = optimizer.state
if isinstance(optimizer, ZeroRedundancyOptimizer):
optimizer_state = optimizer.optim.state

for param in reset_params:
param_state = optimizer_state[param]
if len(param_state) == 0: # no state for this param, happens for ZeRo optimizer
continue
for key in optimizer_state_keys:
pruning_fn(
param_state[key]
) # pruning fn has to be inplace to keep the same keys in the dict
n_total += param_state[key].numel()
n_zeros += torch.sum(param_state[key] == 0).item()
for group in optimizer.param_groups:
for param in group["params"]:
state = optimizer_state[param]
for key, value in state.items():
if key not in optimizer_state_keys:
continue
if torch.is_tensor(value):
try:
pruning_fn(value)
n_total += value.numel()
n_zeros += torch.sum(value == 0).item()
except RuntimeError as exc:
if "quantile() input tensor is too large" in str(exc):
pass
else:
raise exc

_zeroed = n_zeros / (1e-7 + n_total) * 100
LOG.info(f"Percent of optimizer states zeroed: {_zeroed:.2f}")
@@ -129,6 +136,9 @@ def on_step_begin(

if "adam" in args.optim.lower():
optimizer_state_keys = ["exp_avg", "exp_avg_sq"]
if "8bit" in args.optim.lower():
optimizer_state_keys.append("state1")
optimizer_state_keys.append("state2")
else:
raise ValueError(f"Optimizer {args.optim} not supported with ReLoRA")

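The extra keys cover bitsandbytes-style 8-bit Adam, which stores its (quantized) moment estimates as `state1`/`state2` rather than `exp_avg`/`exp_avg_sq`. As a quick, hypothetical way to confirm which keys a given optimizer actually populates, take one step and inspect `optimizer.state`:

```python
import torch
import torch.nn as nn

model = nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

model(torch.randn(2, 8)).sum().backward()
optimizer.step()

# AdamW prints something like ['exp_avg', 'exp_avg_sq', 'step'];
# an 8-bit optimizer would expose 'state1'/'state2' instead.
for state in optimizer.state.values():
    print(sorted(state.keys()))
```
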
@@ -160,7 +170,7 @@ def on_step_begin(
optimizer,
reset_params=lora_params,
optimizer_state_keys=optimizer_state_keys,
prune_ratio=args.relora_prune_ratio,
optimizer_magnitude_pruning=args.relora_prune_ratio,
)

if self.quantized:
56 changes: 40 additions & 16 deletions tests/e2e/test_relora_llama.py
@@ -7,13 +7,15 @@
import unittest
from pathlib import Path

from tbparse import SummaryReader

from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault

from .utils import with_temp_dir
from .utils import most_recent_subdir, with_temp_dir

LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -29,41 +31,63 @@ def test_relora(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,
"base_model": "HuggingFaceTB/SmolLM2-135M",
"sequence_len": 2048,
"sample_packing": True,
"pad_to_sequence_len": True,
"flash_attention": True,
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 32,
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_modules": ["q_proj", "v_proj"],
"relora_steps": 25,
"relora_warmup_steps": 5,
"relora_anneal_steps": 5,
"relora_steps": 100,
"relora_warmup_steps": 20,
"relora_anneal_steps": 10,
"relora_prune_ratio": 0.9,
"relora_cpu_offload": True,
"val_set_size": 0.0,
"special_tokens": {},
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"chat_template": "chatml",
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
"path": "mlabonne/FineTome-100k",
"type": "chat_template",
"split": "train[:10%]",
"field_messages": "conversations",
"message_field_role": "from",
"message_field_content": "value",
},
],
"warmup_steps": 15,
"warmup_steps": 20,
"num_epochs": 2,
"max_steps": 51, # at least 2x relora_steps
"micro_batch_size": 4,
"max_steps": 205, # at least 2x relora_steps
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"save_safetensors": True,
"use_tensorboard": True,
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "model.safetensors").exists()
assert (
Path(temp_dir) / "checkpoint-100/adapter/adapter_model.safetensors"
).exists()
assert (Path(temp_dir) / "checkpoint-100/relora/model.safetensors").exists()

tb_log_path = most_recent_subdir(temp_dir + "/runs")
event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
reader = SummaryReader(event_file)
df = reader.scalars # pylint: disable=invalid-name
df = df[(df.tag == "train/grad_norm")] # pylint: disable=invalid-name
assert df.value.values[-1] < 0.2, "grad_norm is too high"
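
For reference, the same TensorBoard check can be run by hand against any finished run; a small sketch using tbparse (the paths below are placeholders, not values from this PR):

```python
import os

from tbparse import SummaryReader

runs_dir = "outputs/relora-test/runs"  # placeholder for output_dir + "/runs"
log_dir = os.path.join(runs_dir, sorted(os.listdir(runs_dir))[-1])
event_file = os.path.join(log_dir, sorted(os.listdir(log_dir))[0])

scalars = SummaryReader(event_file).scalars
grad_norm = scalars[scalars.tag == "train/grad_norm"]
# A healthy ReLoRA run should keep grad_norm bounded after each reset.
print(grad_norm[["step", "value"]].tail())
```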