Proper import checking
josejg committed Oct 24, 2023
1 parent 5bc5240 commit 79cf8d6
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions llmfoundry/models/hf/hf_causal_lm.py
@@ -32,8 +32,10 @@
 try:
     from peft import PeftModel, LoraConfig, get_peft_model
     model_types = PeftModel, transformers.PreTrainedModel
+    _peft_installed = True
 
 except ImportError:
+    _peft_installed = False
     model_types = transformers.PreTrainedModel,
 
 __all__ = ['ComposerHFCausalLM']
@@ -42,7 +44,7 @@
 
 def print_trainable_parameters(model: nn.Module) -> None:
     # Prints the number of trainable parameters in the model.
-    if PeftModel not in model_types:
+    if not _peft_installed:
         raise ImportError('PEFT not installed. Run pip install -e ".[gpu,peft]"')
     trainable_params = 0
     all_param = 0
@@ -241,7 +243,7 @@ def __init__(self, om_model_config: Union[DictConfig,
         # if om_model_config includes lora and peft is installed, add lora modules
         lora_cfg = om_model_config.get("lora", None)
         if lora_cfg is not None:
-            if PeftModel not in model_types:
+            if not _peft_installed:
                 raise ImportError(
                     'cfg.model.lora is given but PEFT not installed. Run pip install -e ".[gpu,peft]"'
                 )
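
For readers unfamiliar with the pattern: the change replaces an indirect probe (checking whether PeftModel landed in the model_types tuple during import) with an explicit module-level flag set once at import time. A minimal self-contained sketch of the same pattern, using a hypothetical optional dependency somepkg:

try:
    import somepkg  # optional dependency; may be absent
    _somepkg_installed = True
except ImportError:
    _somepkg_installed = False

def feature_that_needs_somepkg() -> None:
    # An explicit flag states the intent directly and survives later
    # refactors of unrelated names such as model_types.
    if not _somepkg_installed:
        raise ImportError('somepkg not installed. Run: pip install somepkg')
    somepkg.do_something()  # hypothetical call, for illustration only

The flag is also cheaper and clearer at call sites than re-attempting the import or inspecting tuples whose contents may change for other reasons.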
