
add docstring to activation_checkpointing_fn
cli99 committed Feb 6, 2024
1 parent 05b4c0d commit 4895ee0
Showing 1 changed file with 34 additions and 1 deletion.
35 changes: 34 additions & 1 deletion llmfoundry/models/mpt/modeling_mpt.py
@@ -908,13 +908,46 @@ def fsdp_wrap_fn(self, module: nn.Module) -> bool:

# Activation Checkpointing
def activation_checkpointing_fn(self, module: nn.Module) -> bool:
"""The MPT activation checkpointing (act ckpt) function.
When `activation_checkpointing` in fsdp_config is set to true, this function will be called on all the modules in the FSDP wrapped model and determine whether a given module should be activation checkpointed. It checks the checkpointing target (`activation_checkpointing_target` in `model`) which can be specified as below:
1. null (or no such field), then ack ckpt the whole MPTBlock on all layers
2. a list of modules to act ckpt, e.g.,
activation_checkpointing_target:
- grouped_query_attention
- mptmlp
, then ack ckpt all the modules in the list on all layers
3. a dictionary of module name with target_blocks, e.g.,
activation_checkpointing_target:
{
"mptblock": target_blocks_1,
"grouped_query_attention": target_blocks_2
}
target_blocks (target_blocks_1, target_blocks_2 above) can be:
- a single integer n means the first n transformer block will be candidates for act ckpt
- a string of first-m, middle-n, last-k means the first m, middle n and the last k layers are candidates for act ckpt. E.g, 'first-2, last-2' means the first 2 and last 2 transformer blocks are candidates for act ckpt.
middle-n is range [start, end) where ``start = max(max_block_idx // 2 - n // 2, 0), end = min(start + n, max_block_idx + 1)``
- a list of integers corresponds to the list of transformer block ids, e.g., [2] means the second transformer block is the candidate for act ckpt. [2, 3] means the second and third transformer block are candidates.
a list of mixed integers and strings of first-m, middle-n, last-k
An example in yaml config file:,
fsdp_config:
activation_checkpointing: true
model:
activation_checkpointing_target:
{
"mptblock": 'first-5',
"grouped_query_attention": 'last-35'
}
does full act ckpt on the first 5 layers and then ack ckpt the grouped_query_attention on the last 35 layers
"""
if not hasattr(module, 'block_idx'):
log.debug(
f'No activation checkpointing for {module.__class__.__name__}; only transformer blocks or their submodules are eligible for activation checkpointing.'
)
return False

-    def get_act_ckpt_module(mod_name: str) -> nn.Module:
+    def get_act_ckpt_module(mod_name: str) -> Any:
if mod_name.lower() == 'mptblock':
mod_type = MPTBlock
elif mod_name in ATTN_CLASS_REGISTRY:
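To make the target_blocks rules in the new docstring concrete, here is a minimal, hypothetical sketch of how a spec could be expanded into candidate block indices, including the middle-n formula quoted above. This is not the llm-foundry implementation; the helper name `get_target_block_ids` and its signature are assumptions.

```python
# Hypothetical sketch of the target_blocks rules described in the docstring;
# not the actual llm-foundry code. `get_target_block_ids` is an assumed name.
def get_target_block_ids(target_blocks, max_block_idx: int) -> list:
    """Expand an int, string, or list spec into sorted transformer block ids."""
    # Rule: a bare integer n means the first n transformer blocks.
    if isinstance(target_blocks, int):
        return list(range(target_blocks))

    specs = target_blocks if isinstance(target_blocks, list) else [target_blocks]
    candidates = set()
    for spec in specs:
        if isinstance(spec, int):
            candidates.add(spec)  # inside a list, an integer is a block id
            continue
        # Strings like 'first-2, last-2' may combine several ranges.
        for part in str(spec).replace(' ', '').split(','):
            name, _, num = part.partition('-')
            n = int(num)
            if name == 'first':
                candidates.update(range(n))
            elif name == 'last':
                candidates.update(range(max_block_idx - n + 1, max_block_idx + 1))
            elif name == 'middle':
                # middle-n covers [start, end) per the docstring formula.
                start = max(max_block_idx // 2 - n // 2, 0)
                end = min(start + n, max_block_idx + 1)
                candidates.update(range(start, end))
            else:
                raise ValueError(f'Unknown target_blocks spec: {part}')
    return sorted(candidates)

# Example: with 40 transformer blocks (max_block_idx = 39),
# 'first-2, last-2' -> [0, 1, 38, 39] and 'middle-4' -> [17, 18, 19, 20].
```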

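A correspondingly hypothetical predicate shows how a checkpointing function in the docstring's spirit could combine a module's `block_idx` (the same attribute checked in the diff above) with the expanded candidate ids; all names here are assumptions for illustration.

```python
# Hypothetical usage (assumed names): decide whether one module should be
# activation checkpointed, mirroring the block_idx check shown in the diff.
def should_act_ckpt(module, target_blocks, max_block_idx: int) -> bool:
    if not hasattr(module, 'block_idx'):
        return False  # only transformer blocks / their submodules are eligible
    return module.block_idx in get_target_block_ids(target_blocks, max_block_idx)
```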