Skip to content

Commit

Permalink
Simplify LoRA validation logic
Browse files Browse the repository at this point in the history
  • Loading branch information
josejg committed Oct 23, 2023
1 parent 2f59377 commit 5bc5240
Showing 1 changed file with 37 additions and 36 deletions.
73 changes: 37 additions & 36 deletions llmfoundry/models/hf/hf_causal_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@

import logging
import os
from typing import Mapping, Union
from typing import Mapping, Union, Dict

# required for loading a python model into composer
import torch
import transformers
from composer.metrics.nlp import (InContextLearningCodeEvalAccuracy,
InContextLearningLMAccuracy,
Expand All @@ -29,11 +30,11 @@
from llmfoundry.models.utils import init_empty_weights

try:
from peft.peft_model import PeftModel, LoraConfig, get_peft_model
from peft import PeftModel, LoraConfig, get_peft_model
model_types = PeftModel, transformers.PreTrainedModel

except ImportError:
model_types = transformers.PreTrainedModel
model_types = transformers.PreTrainedModel,

__all__ = ['ComposerHFCausalLM']

Expand All @@ -53,41 +54,41 @@ def print_trainable_parameters(model: nn.Module) -> None:
f"trainable params: {trainable_params:,d} || all params: {all_param:,d} || trainable%: {100 * trainable_params / all_param:.3g}"
)

def validate_lora_config(cfg: DictConfig):
# Validate lora config within the same function
lora_config = cfg.model.get('lora', None)
if lora_config is not None and isinstance(lora_config, (dict, DictConfig)):
args = lora_config.get('args', None)
if args is not None and isinstance(args, (dict, DictConfig)):
r = args.get('r', None)
if r is None or not isinstance(r, int):
raise ValueError('lora r must be an integer')

lora_alpha = args.get('lora_alpha', None)
if lora_alpha is None or not isinstance(lora_alpha, (float, int)):
raise ValueError('lora_alpha must be a float/int')

target_modules = args.get('target_modules', None)
if target_modules is None or not isinstance(target_modules,
(list, ListConfig)):
raise ValueError('target_modules must be a list')
elif len(target_modules) == 0:
raise ValueError('target_modules is an empty list')
else:
for module in target_modules:
if not isinstance(module, str):
raise ValueError(
'target_modules must be a list of strings')
lora_dropout = args.get('lora_dropout', None)
if lora_dropout is None or not isinstance(lora_dropout, float):
raise ValueError('lora_dropout must be a float')
def validate_lora_config(lora_cfg: DictConfig):
for arg in ['r', 'lora_alpha', 'lora_dropout', 'target_modules', 'task_type']:
if arg not in lora_cfg:
raise ValueError(f'model.lora.{arg} must be specified')

r = lora_cfg['r']
if not isinstance(r, int) or r <= 0:
raise ValueError('LoRA rank (model.lora.r) must be a positive integer')

lora_alpha = lora_cfg['lora_alpha']
if not isinstance(lora_alpha, (float, int)):
raise ValueError('lora_alpha must be a float/int')

target_modules = lora_cfg['target_modules']
if not isinstance(target_modules, (list, ListConfig)):
raise ValueError('target_modules must be a list')
if len(target_modules) == 0:
raise ValueError('target_modules must be non-empty list')
if not all(isinstance(module, str) for module in target_modules):
raise ValueError('target_modules must be a list of strings')

lora_dropout = lora_cfg['lora_dropout']
if not isinstance(lora_dropout, float):
raise ValueError('lora_dropout must be a float')

task_type = lora_cfg['task_type']
if not isinstance(task_type, str):
raise ValueError('task_type must be a string')

task_type = args.get('task_type', None)
if task_type is None or not isinstance(task_type, str):
raise ValueError('task_type must be a string')
print('=' * 20 + 'LoRa is enabled!' + '=' * 20)
print('=' * 20 + 'LoRA is enabled!' + '=' * 20)


def lora_state_dict(model: nn.Module) -> Dict[str, torch.Tensor]:
return {param: tensor for param, tensor in model.state_dict().items()
if '.lora_' in param}


class ComposerHFCausalLM(HuggingFaceModelWithZLoss):
Expand Down Expand Up @@ -248,7 +249,7 @@ def __init__(self, om_model_config: Union[DictConfig,
validate_lora_config(lora_cfg)

print("Building Lora config...")
lora_cfg = LoraConfig(**lora_cfg.args)
lora_cfg = LoraConfig(**lora_cfg)
print("Lora config built.")
print("Adding Lora modules...")
model = get_peft_model(model, lora_cfg)
Expand Down

0 comments on commit 5bc5240

Please sign in to comment.