diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py
index 6cfff12200..fd963b24ac 100644
--- a/llmfoundry/models/mpt/modeling_mpt.py
+++ b/llmfoundry/models/mpt/modeling_mpt.py
@@ -908,13 +908,46 @@ def fsdp_wrap_fn(self, module: nn.Module) -> bool:
 
     # Activation Checkpointing
     def activation_checkpointing_fn(self, module: nn.Module) -> bool:
+        """The MPT activation checkpointing (act ckpt) function.
+
+        When `activation_checkpointing` in fsdp_config is set to true, this function is called on all the modules in the FSDP-wrapped model to determine whether a given module should be activation checkpointed. It checks the checkpointing target (`activation_checkpointing_target` in `model`), which can be specified as follows:
+        1. null (or no such field): act ckpt the whole MPTBlock on all layers
+        2. a list of modules, e.g.,
+            activation_checkpointing_target:
+                - grouped_query_attention
+                - mptmlp
+           then act ckpt all the listed modules on all layers
+        3. a dictionary mapping module names to target_blocks, e.g.,
+            activation_checkpointing_target:
+            {
+                "mptblock": target_blocks_1,
+                "grouped_query_attention": target_blocks_2
+            }
+            target_blocks (target_blocks_1, target_blocks_2 above) can be:
+            - a single integer n: the first n transformer blocks are candidates for act ckpt
+            - a string of first-m, middle-n, last-k: the first m, middle n, and last k layers are candidates for act ckpt. E.g., 'first-2, last-2' means the first 2 and last 2 transformer blocks are candidates for act ckpt.
+              middle-n is the range [start, end) where ``start = max(max_block_idx // 2 - n // 2, 0), end = min(start + n, max_block_idx + 1)``
+            - a list of integers corresponding to transformer block ids, e.g., [2] means the second transformer block is a candidate for act ckpt, and [2, 3] means the second and third transformer blocks are candidates
+            - a list mixing integers and first-m, middle-n, last-k strings
+
+        An example in a yaml config file:
+            fsdp_config:
+                activation_checkpointing: true
+            model:
+                activation_checkpointing_target:
+                {
+                    "mptblock": 'first-5',
+                    "grouped_query_attention": 'last-35'
+                }
+        This does full act ckpt on the first 5 layers and then act ckpt of grouped_query_attention on the last 35 layers.
+        """
         if not hasattr(module, 'block_idx'):
             log.debug(
                 f'No activating checkpointing for {module.__class__.__name__}, only transformer block or its submodules are eligible for activation checkpointing.'
             )
             return False
 
-        def get_act_ckpt_module(mod_name: str) -> nn.Module:
+        def get_act_ckpt_module(mod_name: str) -> Any:
             if mod_name.lower() == 'mptblock':
                 mod_type = MPTBlock
             elif mod_name in ATTN_CLASS_REGISTRY:
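For illustration only (not part of this diff): a minimal sketch of how a target_blocks string such as 'first-2, last-2' could be resolved into candidate block indices under the first-m / middle-n / last-k semantics described in the docstring above. The helper name `resolve_target_blocks` and its parsing details are assumptions made for this sketch, not the PR's actual implementation; only the middle-n formula is taken verbatim from the docstring.

# Illustrative sketch only; `resolve_target_blocks` is a hypothetical helper,
# not a function from llm-foundry.
def resolve_target_blocks(spec: str, max_block_idx: int) -> list:
    """Resolve a 'first-m, middle-n, last-k' spec into sorted block indices."""
    candidates = set()
    for part in (p.strip() for p in spec.split(',')):
        kind, _, num = part.partition('-')
        n = int(num)
        if kind == 'first':
            # first-m: blocks 0 .. m-1
            candidates.update(range(min(n, max_block_idx + 1)))
        elif kind == 'last':
            # last-k: the last k blocks
            candidates.update(range(max(max_block_idx - n + 1, 0), max_block_idx + 1))
        elif kind == 'middle':
            # middle-n: [start, end) per the docstring formula
            start = max(max_block_idx // 2 - n // 2, 0)
            end = min(start + n, max_block_idx + 1)
            candidates.update(range(start, end))
    return sorted(candidates)

# With 40 transformer blocks (max_block_idx = 39), 'first-2, last-2'
# resolves to [0, 1, 38, 39].
print(resolve_target_blocks('first-2, last-2', max_block_idx=39))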