Add layers_to_transform for lora_config (#1118)
xzuyn authored Jan 16, 2024
1 parent f6fea8e commit 5299ad4
Showing 4 changed files with 23 additions and 1 deletion.
3 changes: 2 additions & 1 deletion README.md
@@ -677,7 +677,8 @@ lora_target_modules:
 # - gate_proj
 # - down_proj
 # - up_proj
-lora_target_linear: # If true, will target all linear layers
+lora_target_linear: # If true, will target all linear modules
+peft_layers_to_transform: # The layer indices to transform, otherwise, apply to all layers
 
 # If you added new tokens to the tokenizer, you may need to save some LoRA modules because they need to know the new tokens.
 # For LLaMA and Mistral, you need to save `embed_tokens` and `lm_head`. It may vary for other models.
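The README change above documents the new `peft_layers_to_transform` option. A minimal sketch of how it might be exercised programmatically follows; the import path for `DictDefault` is assumed from the repository layout (it mirrors the usage in tests/test_validation.py below), and all values are illustrative.

from axolotl.utils.config import validate_config  # module changed in this commit
from axolotl.utils.dict import DictDefault  # assumed module path, mirroring the tests

cfg = DictDefault(
    {
        "adapter": "lora",
        "lora_r": 8,
        "lora_alpha": 16,
        "lora_dropout": 0.05,
        "lora_target_linear": True,
        # Apply LoRA only to the first four transformer blocks.
        "peft_layers_to_transform": [0, 1, 2, 3],
    }
)

# The guard added in this commit does not fire here, since `unfrozen_parameters` is unset.
validate_config(cfg)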
5 changes: 5 additions & 0 deletions src/axolotl/utils/config.py
@@ -257,6 +257,11 @@ def validate_config(cfg):
     if cfg.adapter == "lora" and (cfg.flash_attn_fuse_qkv or cfg.flash_attn_fuse_mlp):
         raise ValueError("Fused modules are not supported with LoRA")
 
+    if cfg.adapter and cfg.peft_layers_to_transform and cfg.unfrozen_parameters:
+        raise ValueError(
+            "`unfrozen_parameters` used with `peft_layers_to_transform` can have unexpected behavior."
+        )
+
     if cfg.relora_steps:
         if cfg.adapter not in ("lora", "qlora"):
             raise ValueError("cfg.adapter must be lora or qlora to use ReLoRA")
1 change: 1 addition & 0 deletions src/axolotl/utils/models.py
@@ -733,6 +733,7 @@ def load_lora(model, cfg, inference=False):
         r=cfg.lora_r,
         lora_alpha=cfg.lora_alpha,
         target_modules=lora_target_modules,
+        layers_to_transform=cfg.peft_layers_to_transform,
         lora_dropout=cfg.lora_dropout,
         fan_in_fan_out=cfg.lora_fan_in_fan_out,
         modules_to_save=cfg.lora_modules_to_save if cfg.lora_modules_to_save else None,
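For context, `cfg.peft_layers_to_transform` is passed straight through to PEFT's `LoraConfig(layers_to_transform=...)`, which limits LoRA injection to the listed layer indices instead of every matched module. A standalone sketch against the PEFT API, independent of axolotl; the checkpoint name and hyperparameters are illustrative only.

from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

# Any causal LM works here; the checkpoint name is just an example.
model = AutoModelForCausalLM.from_pretrained("NousResearch/Llama-2-7b-hf")

lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "v_proj"],
    layers_to_transform=[0, 1],  # only decoder layers 0 and 1 receive LoRA adapters
    lora_dropout=0.05,
    task_type="CAUSAL_LM",
)

peft_model = get_peft_model(model, lora_config)
peft_model.print_trainable_parameters()  # far fewer trainable params than transforming all layers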
15 changes: 15 additions & 0 deletions tests/test_validation.py
@@ -694,6 +694,21 @@ def test_warmup_step_no_conflict(self):
 
         validate_config(cfg)
 
+    def test_unfrozen_parameters_w_peft_layers_to_transform(self):
+        cfg = DictDefault(
+            {
+                "adapter": "lora",
+                "unfrozen_parameters": ["model.layers.2[0-9]+.block_sparse_moe.gate.*"],
+                "peft_layers_to_transform": [0, 1],
+            }
+        )
+
+        with pytest.raises(
+            ValueError,
+            match=r".*can have unexpected behavior*",
+        ):
+            validate_config(cfg)
+
 
 class ValidationCheckModelConfig(BaseValidation):
     """
