fix hf_causal_lm
irenedea committed Jul 21, 2024
Parent: 82b3ae8 · Commit: 49004fa
Showing 1 changed file with 5 additions and 8 deletions.
13 changes (5 additions & 8 deletions) in llmfoundry/models/hf/hf_causal_lm.py:

@@ -103,10 +103,11 @@ def __init__(
             config_overrides=config_overrides,
             load_in_8bit=load_in_8bit,
             pretrained=pretrained,
-            prepare_for_fsdp=True,
         )
 
-        self.transform_model(model)
+        model = self.transform_model(model)
+
+        model = ComposerHFCausalLM.prepare_inner_model(model, init_device)
 
         train_metrics, eval_metrics = ComposerHFCausalLM.build_metrics(
             use_train_metrics=use_train_metrics,
@@ -121,7 +122,7 @@
         peft_config_object = None
         if peft_config is not None:
-            peft_config_object = self._get_peft_config(peft_config)
+            peft_config_object = self.get_peft_config(peft_config)
 
         # Set up config args for the model construction and base classes
         super().__init__(
@@ -190,7 +191,6 @@ def build_inner_model(
        config_overrides: Dict[str, Any],
        load_in_8bit: bool,
        pretrained: bool,
-        prepare_for_fsdp: bool = False,
    ) -> Union[PreTrainedModel, 'PeftModel']:
        """Builds the inner model for the ComposerHFCausalLM.
@@ -361,12 +361,9 @@ def _autoset_attn_implementation_monkeypatch(
                pretrained_lora_id_or_path,
            )
 
-        if prepare_for_fsdp:
-            ComposerHFCausalLM.prepare_inner_model(model, init_device)
        return model
 
-    @staticmethod
-    def _get_peft_config(peft_config_dict: Dict[str, Any]) -> 'PeftConfig':
+    def get_peft_config(self, peft_config_dict: Dict[str, Any]) -> 'PeftConfig':
        if peft_installed:
            from peft import LoraConfig
            peft_type = peft_config_dict.get('peft_type', '')
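In short, the fix moves FSDP preparation out of build_inner_model (dropping the prepare_for_fsdp flag) and into __init__, which now calls prepare_inner_model explicitly on the value returned by transform_model; _get_peft_config also becomes a public instance method, get_peft_config. One consequence is that transform_model's return value is now what gets wrapped and trained, so an override must return the model. Below is a minimal sketch of that contract; the subclass and the use_cache tweak are hypothetical illustrations, not part of this commit.

from transformers import PreTrainedModel

from llmfoundry.models.hf.hf_causal_lm import ComposerHFCausalLM


class MyCausalLM(ComposerHFCausalLM):
    """Hypothetical subclass showing the post-commit transform_model contract."""

    def transform_model(self, model: PreTrainedModel) -> PreTrainedModel:
        # __init__ now rebinds the hook's result
        # (`model = self.transform_model(model)`), so an override must
        # return the model even if it only mutates it in place.
        model.config.use_cache = False  # illustrative tweak only
        return model

Likewise, any external caller of build_inner_model that previously relied on prepare_for_fsdp=True now needs its own prepare_inner_model call, and callers of the old private _get_peft_config should switch to get_peft_config.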
