Skip to content

Commit

Permalink
FSDP wrap refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
vchiley committed Jan 18, 2024
1 parent 2e4f4b2 commit cae6511
Showing 1 changed file with 10 additions and 8 deletions.
18 changes: 10 additions & 8 deletions llmfoundry/models/mpt/modeling_mpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -726,10 +726,6 @@ def param_init_fn(self, module: nn.Module) -> None:
**self.config.init_config,
)

# FSDP Wrap function
def fsdp_wrap_fn(self, module: nn.Module) -> bool:
    """FSDP wrap predicate: report whether ``module`` is an ``MPTBlock``.

    Args:
        module: The submodule being inspected.

    Returns:
        ``True`` when ``module`` is an instance of ``MPTBlock``, otherwise
        ``False``.
    """
    is_mpt_block = isinstance(module, MPTBlock)
    return is_mpt_block

# Activation Checkpointing
def activation_checkpointing_fn(self, module: nn.Module) -> bool:
    """Activation-checkpointing predicate: match ``MPTBlock`` instances.

    Args:
        module: The submodule being inspected.

    Returns:
        ``True`` when ``module`` is an instance of ``MPTBlock``, otherwise
        ``False``.
    """
    is_mpt_block = isinstance(module, MPTBlock)
    return is_mpt_block
Expand Down Expand Up @@ -887,10 +883,6 @@ def param_init_fn(self, module: nn.Module) -> None:
**self.config.init_config,
)

# FSDP Wrap function
def fsdp_wrap_fn(self, module: nn.Module) -> bool:
    """FSDP wrap predicate: report whether ``module`` is an ``MPTBlock``.

    Args:
        module: The submodule being inspected.

    Returns:
        ``True`` when ``module`` is an instance of ``MPTBlock``, otherwise
        ``False``.
    """
    matches_block = isinstance(module, MPTBlock)
    return matches_block

# Activation Checkpointing
def activation_checkpointing_fn(self, module: nn.Module) -> bool:
act_ckpt_list = getattr(self.config, 'activation_checkpointing_target',
Expand Down Expand Up @@ -995,6 +987,16 @@ def _reorder_cache(
return reordered_past


def _fsdp_wrap_fn(self: Union[MPTModel, MPTForCausalLM],
                  module: nn.Module) -> bool:
    """FSDP wrap function for MPT models.

    Defined at module level so the same callable can be attached to more
    than one model class.

    Args:
        self: The model instance the function is bound to; not read here.
        module: The submodule being inspected.

    Returns:
        ``True`` when ``module`` is an instance of ``MPTBlock``, otherwise
        ``False``.
    """
    is_mpt_block = isinstance(module, MPTBlock)
    return is_mpt_block


# Attach the shared module-level predicate to both model classes as their
# ``fsdp_wrap_fn`` attribute, so a single definition serves both.
MPTModel.fsdp_wrap_fn = _fsdp_wrap_fn
MPTForCausalLM.fsdp_wrap_fn = _fsdp_wrap_fn


class ComposerMPTCausalLM(HuggingFaceModel):

def __init__(
Expand Down

0 comments on commit cae6511

Please sign in to comment.