diff --git a/llmfoundry/models/hf/hf_causal_lm.py b/llmfoundry/models/hf/hf_causal_lm.py
index 73ab88ff53..1c0bb90ee8 100644
--- a/llmfoundry/models/hf/hf_causal_lm.py
+++ b/llmfoundry/models/hf/hf_causal_lm.py
@@ -103,10 +103,11 @@ def __init__(
             config_overrides=config_overrides,
             load_in_8bit=load_in_8bit,
             pretrained=pretrained,
-            prepare_for_fsdp=True,
         )
 
-        self.transform_model(model)
+        model = self.transform_model(model)
+
+        model = ComposerHFCausalLM.prepare_inner_model(model, init_device)
 
         train_metrics, eval_metrics = ComposerHFCausalLM.build_metrics(
             use_train_metrics=use_train_metrics,
@@ -121,7 +122,7 @@ def __init__(
 
         peft_config_object = None
         if peft_config is not None:
-            peft_config_object = self._get_peft_config(peft_config)
+            peft_config_object = self.get_peft_config(peft_config)
 
         # Set up config args for the model construction and base classes
         super().__init__(
@@ -190,7 +191,6 @@ def build_inner_model(
         config_overrides: Dict[str, Any],
         load_in_8bit: bool,
         pretrained: bool,
-        prepare_for_fsdp: bool = False,
     ) -> Union[PreTrainedModel, 'PeftModel']:
         """Builds the inner model for the ComposerHFCausalLM.
 
@@ -361,12 +361,9 @@ def _autoset_attn_implementation_monkeypatch(
                 pretrained_lora_id_or_path,
             )
 
-        if prepare_for_fsdp:
-            ComposerHFCausalLM.prepare_inner_model(model, init_device)
         return model
 
-    @staticmethod
-    def _get_peft_config(peft_config_dict: Dict[str, Any]) -> 'PeftConfig':
+    def get_peft_config(self, peft_config_dict: Dict[str, Any]) -> 'PeftConfig':
        if peft_installed:
            from peft import LoraConfig
            peft_type = peft_config_dict.get('peft_type', '')
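
The practical effect of this diff is a reordering: `build_inner_model` no longer performs FSDP preparation itself, so `transform_model` runs on the raw Hugging Face model and the explicit `prepare_inner_model` call in `__init__` happens afterwards. A minimal sketch of how a subclass might rely on that ordering is below; it is illustrative only, the subclass name and the embedding-freezing transform are hypothetical and not part of this PR.

```python
# Illustrative sketch, not code from this PR. Assumes llmfoundry is installed
# and that ComposerHFCausalLM exposes the transform_model hook shown in the diff.
from llmfoundry.models.hf.hf_causal_lm import ComposerHFCausalLM


class FrozenEmbeddingHFCausalLM(ComposerHFCausalLM):
    """Hypothetical subclass used only to show the new call order."""

    def transform_model(self, model):
        # After this change, transform_model sees the model as returned by
        # build_inner_model, before prepare_inner_model does FSDP/meta-device
        # setup in __init__, so structural edits here happen on the raw model.
        for param in model.get_input_embeddings().parameters():
            param.requires_grad = False
        return model  # __init__ now uses the returned model
```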