From 79cf8d67e3917b8ca8fbec3291bd31280cb6a5a1 Mon Sep 17 00:00:00 2001
From: Jose Javier <26491792+josejg@users.noreply.github.com>
Date: Tue, 24 Oct 2023 19:51:57 +0000
Subject: [PATCH] Proper import checking

---
 llmfoundry/models/hf/hf_causal_lm.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/llmfoundry/models/hf/hf_causal_lm.py b/llmfoundry/models/hf/hf_causal_lm.py
index d09f9a4419..96587e1398 100644
--- a/llmfoundry/models/hf/hf_causal_lm.py
+++ b/llmfoundry/models/hf/hf_causal_lm.py
@@ -32,8 +32,10 @@
 try:
     from peft import PeftModel, LoraConfig, get_peft_model
     model_types = PeftModel, transformers.PreTrainedModel
+    _peft_installed = True
 except ImportError:
+    _peft_installed = False
     model_types = transformers.PreTrainedModel,
 
 __all__ = ['ComposerHFCausalLM']
 
@@ -42,7 +44,7 @@
 
 def print_trainable_parameters(model: nn.Module) -> None:
     # Prints the number of trainable parameters in the model.
-    if PeftModel not in model_types:
+    if not _peft_installed:
         raise ImportError('PEFT not installed. Run pip install -e ".[gpu,peft]"')
     trainable_params = 0
     all_param = 0
@@ -241,7 +243,7 @@ def __init__(self, om_model_config: Union[DictConfig,
         # if om_model_config includes lora and peft is installed, add lora modules
         lora_cfg = om_model_config.get("lora", None)
         if lora_cfg is not None:
-            if PeftModel not in model_types:
+            if not _peft_installed:
                 raise ImportError(
                     'cfg.model.lora is given but PEFT not installed. Run pip install -e ".[gpu,peft]"'
                 )
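
Note on the fix: the old guard `if PeftModel not in model_types:` is not just
stylistically awkward; when peft is absent, the `try` block fails at the import,
so the name `PeftModel` is never bound and evaluating the guard raises NameError
rather than the intended ImportError. Recording availability in a module-level
boolean at import time sidesteps that entirely. Below is a minimal self-contained
sketch of the availability-flag pattern this patch adopts; the `require_peft`
helper is hypothetical and shown only for illustration, not part of the patch
or of llm-foundry's API:

    # Guarded optional import: record availability in a module-level
    # flag at import time, instead of probing names the failed import
    # may have left unbound.
    try:
        from peft import PeftModel  # optional dependency; may be missing
        _peft_installed = True
    except ImportError:
        _peft_installed = False

    def require_peft() -> None:
        # Hypothetical helper: fail fast with an actionable message
        # when the optional dependency is missing.
        if not _peft_installed:
            raise ImportError(
                'PEFT not installed. Run pip install -e ".[gpu,peft]"')

Checking a boolean set in the `except` branch works regardless of which names the
failed import left unbound, and it reads more directly than the membership test
against `model_types`.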