diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py
index f49b1b88f8..075a36251b 100644
--- a/llmfoundry/models/mpt/modeling_mpt.py
+++ b/llmfoundry/models/mpt/modeling_mpt.py
@@ -726,10 +726,6 @@ def param_init_fn(self, module: nn.Module) -> None:
             **self.config.init_config,
         )
 
-    # FSDP Wrap function
-    def fsdp_wrap_fn(self, module: nn.Module) -> bool:
-        return isinstance(module, MPTBlock)
-
     # Activation Checkpointing
     def activation_checkpointing_fn(self, module: nn.Module) -> bool:
         return isinstance(module, MPTBlock)
@@ -887,10 +883,6 @@ def param_init_fn(self, module: nn.Module) -> None:
             **self.config.init_config,
         )
 
-    # FSDP Wrap function
-    def fsdp_wrap_fn(self, module: nn.Module) -> bool:
-        return isinstance(module, MPTBlock)
-
     # Activation Checkpointing
     def activation_checkpointing_fn(self, module: nn.Module) -> bool:
         act_ckpt_list = getattr(self.config, 'activation_checkpointing_target',
@@ -995,6 +987,16 @@ def _reorder_cache(
         return reordered_past
 
 
+def _fsdp_wrap_fn(self: Union[MPTModel, MPTForCausalLM],
+                  module: nn.Module) -> bool:
+    # FSDP Wrap function for MPT Models
+    return isinstance(module, MPTBlock)
+
+
+MPTModel.fsdp_wrap_fn = _fsdp_wrap_fn
+MPTForCausalLM.fsdp_wrap_fn = _fsdp_wrap_fn
+
+
 class ComposerMPTCausalLM(HuggingFaceModel):
 
     def __init__(
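
In words, this change deletes the two identical `fsdp_wrap_fn` methods on `MPTModel` and `MPTForCausalLM` and replaces them with a single module-level `_fsdp_wrap_fn`, which is then attached to both classes. Because the function is assigned as a class attribute, Python binds it like an ordinary method, so existing `model.fsdp_wrap_fn(module)` call sites keep working and still report which submodules (here, `MPTBlock`) should be wrapped by FSDP. Below is a minimal sketch of how a trainer-side auto-wrap policy could consume such a predicate; `build_auto_wrap_policy` and the three-argument `policy` signature are illustrative assumptions about the callable-policy convention used by PyTorch FSDP, not Composer or llm-foundry code.

```python
# Minimal sketch (assumed helper, not Composer/llm-foundry code) showing how a
# per-model `fsdp_wrap_fn` predicate, like the one attached above, could be
# turned into a callable FSDP auto-wrap policy.
from typing import Callable, Optional

import torch.nn as nn


def build_auto_wrap_policy(
        model: nn.Module) -> Optional[Callable[[nn.Module, bool, int], bool]]:
    wrap_fn = getattr(model, 'fsdp_wrap_fn', None)
    if wrap_fn is None:
        return None

    def policy(module: nn.Module, recurse: bool,
               nonwrapped_numel: int) -> bool:
        # Keep descending while FSDP is recursing; when it asks whether to
        # wrap a specific submodule, defer to the model's predicate
        # (isinstance(module, MPTBlock) for MPT models).
        if recurse:
            return True
        return bool(wrap_fn(module))

    return policy
```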