Commit 8d8c3bd: migrate lora_ to peft_

winglian committed Sep 28, 2023
1 parent 2b3eca1 commit 8d8c3bd

Showing 5 changed files with 107 additions and 50 deletions.
27 changes: 12 additions & 15 deletions README.md
@@ -367,10 +367,10 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
- lora
```yaml
adapter: lora # qlora or leave blank for full finetune
-lora_r: 8
-lora_alpha: 16
-lora_dropout: 0.05
-lora_target_modules:
+peft_r: 8
+peft_alpha: 16
+peft_dropout: 0.05
+peft_target_modules:
- q_proj
- v_proj
```
@@ -502,26 +502,23 @@ adapter: lora
# if you already have a lora model trained that you want to load, put that here
# lora hyperparameters
peft_model_dir:
-lora_r: 8
-lora_alpha: 16
-lora_dropout: 0.05
-lora_target_modules:
+peft_r: 8
+peft_alpha: 16
+peft_dropout: 0.05
+peft_target_modules:
- q_proj
- v_proj
# - k_proj
# - o_proj
# - gate_proj
# - down_proj
# - up_proj
-lora_target_linear: # if true, will target all linear layers
-lora_modules_to_save:
+peft_target_linear: # if true, will target all linear layers
+peft_modules_to_save:
# - embed_tokens
# - lm_head
-lora_out_dir:
-lora_fan_in_fan_out: false
-ia3_target_modules: # target modules for IA3, for llama, k, v, and down projections
-ia3_feedforward_modules: # ffn modules for IA3, for llama down projection
-ia3_fan_in_fan_out:
+peft_fan_in_fan_out: false
+peft_feedforward_modules: # ffn modules for IA3, for llama down projection
# ReLoRA configuration
# must use either 'lora' or 'qlora' adapter, and does not support fsdp or deepspeed
6 changes: 3 additions & 3 deletions examples/llama-2/ia3.yml
@@ -21,13 +21,13 @@ pad_to_sequence_len: true

adapter: ia3
peft_model_dir:
-ia3_target_modules:
+peft_target_modules:
- k_proj
- v_proj
- down_proj
-ia3_feedforward_modules:
+peft_feedforward_modules:
- down_proj
-ia3_fan_in_fan_out: false
+peft_fan_in_fan_out: false

wandb_project:
wandb_entity:
12 changes: 12 additions & 0 deletions src/axolotl/utils/config.py
@@ -88,6 +88,18 @@ def normalize_config(cfg):

    log_gpu_memory_usage(LOG, "baseline", cfg.device)

+    if cfg.adapter is not None:
+        for key in list(cfg.keys()):
+            if key.startswith("lora_"):
+                new_key = key.replace("lora_", "peft_")
+                LOG.warning(
+                    PendingDeprecationWarning(
+                        f"{key} soon to be deprecated. please use {new_key}"
+                    )
+                )
+                cfg[new_key] = cfg[key]
+                del cfg[key]
+

def validate_config(cfg):
    if is_torch_bf16_gpu_available():
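For readers skimming the diff, a minimal standalone sketch of the renaming behaviour this hunk adds to `normalize_config` (hypothetical keys and values, plain `dict` instead of `DictDefault`, warning omitted):

```python
# Sketch of the lora_ -> peft_ key migration added above, on a plain dict.
cfg = {"adapter": "lora", "lora_r": 8, "lora_alpha": 16, "lora_dropout": 0.05}

for key in list(cfg.keys()):
    if key.startswith("lora_"):
        new_key = key.replace("lora_", "peft_")
        # the real code also emits a PendingDeprecationWarning via LOG.warning
        cfg[new_key] = cfg.pop(key)

print(cfg)  # {'adapter': 'lora', 'peft_r': 8, 'peft_alpha': 16, 'peft_dropout': 0.05}
```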
64 changes: 32 additions & 32 deletions src/axolotl/utils/models.py
@@ -452,11 +452,11 @@ def load_llama_adapter(model, cfg):
task_type="CAUSAL_LM",
)

if cfg.peft_model_dir or cfg.lora_model_dir:
if cfg.peft_model_dir:
LOG.debug("Loading pretained PEFT - llama_adapter")
model = PeftModel.from_pretrained(
model,
cfg.peft_model_dir or cfg.lora_model_dir,
cfg.peft_model_dir,
torch_dtype=torch.float16,
)
else:
@@ -469,37 +469,37 @@ def load_llama_adapter(model, cfg):

def find_all_linear_names(model):
    cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear)
-    lora_module_names = set()
+    peft_module_names = set()
    for name, module in model.named_modules():
        if isinstance(module, cls):
            names = name.split(".")
-            lora_module_names.add(names[0] if len(names) == 1 else names[-1])
+            peft_module_names.add(names[0] if len(names) == 1 else names[-1])

-    if "lm_head" in lora_module_names: # needed for 16-bit
-        lora_module_names.remove("lm_head")
+    if "lm_head" in peft_module_names: # needed for 16-bit
+        peft_module_names.remove("lm_head")

-    return list(lora_module_names)
+    return list(peft_module_names)


def load_lora(model, cfg, inference=False):
    # type: (PreTrainedModel, DictDefault, bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]

    from peft import LoraConfig, PeftModel, get_peft_model

-    lora_target_modules = list(cfg.lora_target_modules or [])
+    peft_target_modules = list(cfg.peft_target_modules or [])

-    if cfg.lora_target_linear:
+    if cfg.peft_target_linear:
        linear_names = find_all_linear_names(model)
        LOG.info(f"found linear modules: {repr(linear_names)}")
-        lora_target_modules = list(set(lora_target_modules + linear_names))
-
-    lora_config = LoraConfig(
-        r=cfg.lora_r,
-        lora_alpha=cfg.lora_alpha,
-        target_modules=lora_target_modules,
-        lora_dropout=cfg.lora_dropout,
-        fan_in_fan_out=cfg.lora_fan_in_fan_out,
-        modules_to_save=cfg.lora_modules_to_save if cfg.lora_modules_to_save else None,
+        peft_target_modules = list(set(peft_target_modules + linear_names))
+
+    peft_config = LoraConfig(
+        r=cfg.peft_r,
+        lora_alpha=cfg.peft_alpha,
+        target_modules=peft_target_modules,
+        lora_dropout=cfg.peft_dropout,
+        fan_in_fan_out=cfg.peft_fan_in_fan_out,
+        modules_to_save=cfg.peft_modules_to_save if cfg.peft_modules_to_save else None,
        bias="none",
        task_type="CAUSAL_LM",
    )
@@ -512,30 +512,30 @@ def load_lora(model, cfg, inference=False):
            is_trainable=(not inference),
        )
    else:
-        model = get_peft_model(model, lora_config)
+        model = get_peft_model(model, peft_config)

    model.print_trainable_parameters()

-    return model, lora_config
+    return model, peft_config


def load_ia3(model, cfg, inference=False):
    # type: (PreTrainedModel, DictDefault, bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]

    from peft import IA3Config, PeftModel, get_peft_model

-    ia3_config_kwargs = {}
-    if cfg.ia3_init_ia3_weights is not None:
-        ia3_config_kwargs["init_ia3_weights"] = cfg.ia3_init_ia3_weights
-    if cfg.ia3_fan_in_fan_out is not None:
-        ia3_config_kwargs["fan_in_fan_out"] = cfg.ia3_fan_in_fan_out
+    peft_config_kwargs = {}
+    if cfg.peft_init_ia3_weights is not None:
+        peft_config_kwargs["init_ia3_weights"] = cfg.peft_init_ia3_weights
+    if cfg.peft_fan_in_fan_out is not None:
+        peft_config_kwargs["fan_in_fan_out"] = cfg.peft_fan_in_fan_out

-    ia3_config = IA3Config(
-        target_modules=cfg.ia3_target_modules,
-        feedforward_modules=cfg.ia3_feedforward_modules,
-        modules_to_save=cfg.ia3_modules_to_save,
+    peft_config = IA3Config(
+        target_modules=cfg.peft_target_modules,
+        feedforward_modules=cfg.peft_feedforward_modules,
+        modules_to_save=cfg.peft_modules_to_save,
        task_type="CAUSAL_LM",
-        **ia3_config_kwargs,
+        **peft_config_kwargs,
    )

    if cfg.peft_model_dir:
@@ -546,8 +546,8 @@ def load_ia3(model, cfg, inference=False):
            is_trainable=(not inference),
        )
    else:
-        model = get_peft_model(model, ia3_config)
+        model = get_peft_model(model, peft_config)

    model.print_trainable_parameters()

-    return model, ia3_config
+    return model, peft_config
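Taken together, the renamed keys feed the same `peft` calls as before. A minimal standalone sketch of what `load_lora` now builds from the `cfg.peft_*` values (model name borrowed from the test below; all values illustrative, not part of this commit):

```python
# Illustrative only: constructs the kind of LoraConfig that load_lora builds
# from the renamed cfg.peft_* keys, then wraps a base model with it.
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("NousResearch/Llama-2-7b-hf")
peft_config = LoraConfig(
    r=8,                                  # cfg.peft_r
    lora_alpha=16,                        # cfg.peft_alpha
    lora_dropout=0.05,                    # cfg.peft_dropout
    target_modules=["q_proj", "v_proj"],  # cfg.peft_target_modules
    fan_in_fan_out=False,                 # cfg.peft_fan_in_fan_out
    bias="none",
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
```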
48 changes: 48 additions & 0 deletions tests/test_cfg_normalization.py
@@ -0,0 +1,48 @@
"""Module for testing the validation module"""

import logging
import unittest
from typing import Optional

import pytest

from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault


class NormalizationTest(unittest.TestCase):
"""
Test the cfg normalization module
"""

_caplog: Optional[pytest.LogCaptureFixture] = None

@pytest.fixture(autouse=True)
def inject_fixtures(self, caplog):
self._caplog = caplog

def test_lora_to_peft(self):
base_cfg = DictDefault(
{
"gradient_accumulation_steps": 1,
"micro_batch_size": 1,
"base_model": "NousResearch/Llama-2-7b-hf",
"base_model_config": "NousResearch/Llama-2-7b-hf",
}
)
cfg = base_cfg | DictDefault(
{
"adapter": "lora",
"lora_r": 128,
"lora_alpha": 64,
}
)
with self._caplog.at_level(logging.WARNING):
normalize_config(cfg)
assert any(
"soon to be deprecated. please use peft_" in record.message
for record in self._caplog.records
)

assert cfg.peft_r == 128
assert cfg.peft_alpha == 64
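To exercise just this migration path locally, one could run the new test module on its own (a sketch; assumes pytest is installed and the script is run from the repository root):

```python
# Hypothetical convenience runner, equivalent to:
#   pytest tests/test_cfg_normalization.py -q
import sys

import pytest

sys.exit(pytest.main(["tests/test_cfg_normalization.py", "-q"]))
```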
