
Commit

Fix indent
josejg committed Oct 30, 2023
1 parent 79cf8d6 commit 7f72c25
Showing 1 changed file with 30 additions and 30 deletions.
60 changes: 30 additions & 30 deletions llmfoundry/models/hf/hf_causal_lm.py
@@ -240,40 +240,40 @@ def __init__(self, om_model_config: Union[DictConfig,

The change is whitespace-only: each of the 30 removed lines reappears with corrected indentation. The hunk after the fix:

            z_loss = om_model_config.get('z_loss', 0.0)

            # if om_model_config includes lora and peft is installed, add lora modules
            lora_cfg = om_model_config.get("lora", None)
            if lora_cfg is not None:
                if not _peft_installed:
                    raise ImportError(
                        'cfg.model.lora is given but PEFT not installed. Run pip install -e ".[gpu,peft]"'
                    )

                validate_lora_config(lora_cfg)

                print("Building Lora config...")
                lora_cfg = LoraConfig(**lora_cfg)
                print("Lora config built.")
                print("Adding Lora modules...")
                model = get_peft_model(model, lora_cfg)
                print("Lora modules added.")
                print_trainable_parameters(model)

            attention_patch_type = om_model_config.get('attention_patch_type',
                                                       None)
            if attention_patch_type is not None:
                if model.config.model_type != 'llama':
                    raise ValueError(
                        f'attention_patch_type is only supported for llama models, but got {model.config.model_type}'
                    )

                log.debug(
                    f'Patching llama attention with {attention_patch_type} attention'
                )
                from transformers.models.llama.modeling_llama import \
                    LlamaAttention
                LlamaAttention.forward = get_llama_attention_patch_fn(
                    attention_patch_type)
                model.config.use_cache = False

        # elif the model is either a PeftModel or a PreTrainedModel
        elif isinstance(om_model_config, model_types):
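For readers unfamiliar with the LoRA branch above: it follows the standard PEFT workflow of building a LoraConfig and wrapping the base model with get_peft_model. A minimal, self-contained sketch of that workflow, with illustrative hyperparameters that are not llm-foundry's values:

# Minimal sketch of the PEFT LoRA workflow; the model name, hyperparameters,
# and target_modules below are illustrative, not llm-foundry's defaults.
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

base_model = AutoModelForCausalLM.from_pretrained('gpt2')
peft_cfg = LoraConfig(
    r=8,                        # rank of the low-rank adapter matrices
    lora_alpha=16,              # scaling applied to the adapter output
    lora_dropout=0.05,
    target_modules=['c_attn'],  # GPT-2's fused attention projection
    task_type='CAUSAL_LM',
)
wrapped = get_peft_model(base_model, peft_cfg)  # returns a PeftModel
wrapped.print_trainable_parameters()            # only adapter params are trainable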
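The attention_patch_type branch works by monkey-patching LlamaAttention.forward on the transformers module. A hedged sketch of that pattern, where the replacement body is only a pass-through placeholder rather than the real patched attention returned by get_llama_attention_patch_fn:

# Sketch of the monkey-patching pattern behind attention_patch_type.
# The placeholder below simply delegates to the original forward; the real
# patch function returns a reimplementation of the attention computation.
from transformers.models.llama.modeling_llama import LlamaAttention

_original_forward = LlamaAttention.forward

def _patched_forward(self, *args, **kwargs):
    # A real patch would compute attention itself (and typically requires
    # use_cache=False, as set in the hunk above).
    return _original_forward(self, *args, **kwargs)

LlamaAttention.forward = _patched_forward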
