From 2e59620a312f900d37a1e17ad2ff9e1cbb6b09ad Mon Sep 17 00:00:00 2001 From: Cheng Li Date: Thu, 8 Feb 2024 14:40:53 -0800 Subject: [PATCH] Add fully configurable activation checkpointing (#951) * add fully configurable activation checkpointing * fix format * fix format * add docstring to activation_checkpointing_fn * add block id range option in act ckpt * resolve conflict * add a check for blocks ids overlap in mapping * fix typo * update docstring * refactor * fix test * Apply suggestions from code review Co-authored-by: Mihir Patel * address comments * add build mapping as a helper func * fix format --------- Co-authored-by: Mihir Patel --- llmfoundry/models/mpt/modeling_mpt.py | 102 ++++++++++------ llmfoundry/models/utils/act_ckpt.py | 147 +++++++++++++++++++++++ tests/models/test_fsdp_act_checkpoint.py | 29 +++-- 3 files changed, 230 insertions(+), 48 deletions(-) create mode 100644 llmfoundry/models/utils/act_ckpt.py diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 79dc8c7f25..e9ad8054e2 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -28,6 +28,7 @@ from llmfoundry.models.layers.attention import (is_flash_v1_installed, is_flash_v2_installed) +from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY if is_flash_v2_installed(): try: # This try...except is needed because transformers requires it despite the 'if' statement above @@ -55,17 +56,11 @@ from transformers.models.llama.modeling_llama import \ LlamaRotaryEmbedding as HFRotaryEmbedding -from llmfoundry.models.layers.attention import (ATTN_CLASS_REGISTRY, - attn_bias_shape, +from llmfoundry.models.layers.attention import (attn_bias_shape, build_attn_bias, gen_slopes) from llmfoundry.models.layers.blocks import MPTBlock from llmfoundry.models.layers.custom_embedding import SharedEmbedding -from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY as FC_CLASS_REGISTRY -from llmfoundry.models.layers.ffn import \ - FFN_CLASS_REGISTRY as FFN_CLASS_REGISTRY -from llmfoundry.models.layers.ffn import MPTMLP as MPTMLP from llmfoundry.models.layers.ffn import build_ffn as build_ffn -from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY from llmfoundry.models.mpt.configuration_mpt import MPTConfig # NOTE: All utils are imported directly even if unused so that @@ -87,6 +82,10 @@ MODEL_INIT_REGISTRY, ) +from llmfoundry.models.utils.act_ckpt import (pass_on_block_idx, + build_act_ckpt_mod_to_blocks, + check_mapping_blocks_overlap) + try: from llmfoundry.models.layers.flash_attn_triton import flash_attn_func as flash_attn_func except: @@ -352,6 +351,13 @@ def __init__(self, config: MPTConfig): **config.to_dict(), ) for _ in range(config.n_layers) ]) + + # Tag all modules in the transformer blocks with the corresponding block_idx and max_block_idx + for i, block in enumerate(self.blocks): + block.block_idx = i + block.max_block_idx = config.n_layers - 1 + pass_on_block_idx(block) + self.norm_f = norm_class(config.d_model, device=config.init_device) self.rope = config.attn_config['rope'] @@ -908,41 +914,57 @@ def fsdp_wrap_fn(self, module: nn.Module) -> bool: # Activation Checkpointing def activation_checkpointing_fn(self, module: nn.Module) -> bool: - act_ckpt_list = getattr(self.config, 'activation_checkpointing_target', - None) or ['MPTBlock'] - if isinstance(act_ckpt_list, str): - act_ckpt_list = [act_ckpt_list] - elif not isinstance(act_ckpt_list, list): - raise ValueError( - f'activation_checkpointing_target must be either a 
single string or a list, but got {type(act_ckpt_list)}' + """The MPT activation checkpointing (act ckpt) function. + + When `activation_checkpointing` in fsdp_config is set to true, this function will be called on all the modules in the FSDP wrapped model and determine whether a given module should be activation checkpointed. It checks the checkpointing target (`activation_checkpointing_target` in `model`) which can be specified as below: + 1. null (or no such field): The whole MPTBlock will be activation checkpointed on all layers + 2. a list of modules to act ckpt on all layers, e.g., + activation_checkpointing_target: + - grouped_query_attention + - mptmlp + 3. a dictionary of module name with target_blocks, e.g., + activation_checkpointing_target: + { + "mptblock": target_blocks_1, + "grouped_query_attention": target_blocks_2 + } + target_blocks (target_blocks_1, target_blocks_2 above) can be: + - a single integer n: the first n transformer block will be activation checkpointed + - a string of first-n, middle-m, last-k, range-i-j: the first n, the middle m, the last k, or the range [i, j) layers will be activation checkpointed. E.g, 'first-2, last-2' means the first 2 and last 2 transformer blocks will be activation checkpointed + middle-m is range [start, end) where ``start = max(max_block_idx // 2 - m // 2, 0), end = min(start + m, max_block_idx + 1)`` + - a list of integers corresponds to the list of transformer block ids, e.g., [2] means the second transformer block will be activation checkpointed. [2, 3] means the second and third transformer blocks will be activation checkpointed + - a list of mixed integers and strings of first-n, middle-m, last-k, range-i-j + + An example in yaml config file: + fsdp_config: + activation_checkpointing: true + model: + activation_checkpointing_target: + { + "mptblock": 'first-5', + "grouped_query_attention": 'last-35' + } + """ + if not hasattr(module, 'block_idx'): + log.debug( + f'{module.__class__.__name__} cannot be activation checkpointed. Only transformer block or its submodules are eligible for activation checkpointing.' ) + return False - if 'MPTBlock' in act_ckpt_list or 'mptblock' in act_ckpt_list: - if len(act_ckpt_list) > 1: - log.info( - 'Activation checkpointing MPTBlock only (ignoring other sub-block modules specified in activation_checkpointing_target).' - ) - return isinstance(module, MPTBlock) - - mod_types = () - for mod_name in act_ckpt_list: - if mod_name.lower() == 'mptblock': - mod_types += (MPTBlock,) - elif mod_name in ATTN_CLASS_REGISTRY: - mod_types += (ATTN_CLASS_REGISTRY[mod_name],) - elif mod_name in FFN_CLASS_REGISTRY: - mod_types += (FFN_CLASS_REGISTRY[mod_name],) - elif mod_name in NORM_CLASS_REGISTRY: - mod_types += (NORM_CLASS_REGISTRY[mod_name],) - else: - msg = ', '.join( - list(ATTN_CLASS_REGISTRY.keys()) + - list(FFN_CLASS_REGISTRY.keys()) + - list(NORM_CLASS_REGISTRY.keys()) + ['MPTBlock']) - raise ValueError( - f'{mod_name} (specified in activation_checkpointing_target) is not a recognized option out of available options {msg}.' 
- ) - return isinstance(module, mod_types) + act_ckpt_target = getattr(self.config, + 'activation_checkpointing_target', None) + act_ckpt_mod_to_blocks = build_act_ckpt_mod_to_blocks( + act_ckpt_target, MPTBlock, module.max_block_idx) + + check_mapping_blocks_overlap(act_ckpt_mod_to_blocks, + module.max_block_idx) + + for k in act_ckpt_mod_to_blocks.keys(): + if isinstance(module, k): + blocks = act_ckpt_mod_to_blocks[k] + return True if blocks == -1 else module.block_idx in blocks + + return False def prepare_inputs_for_generation( self, diff --git a/llmfoundry/models/utils/act_ckpt.py b/llmfoundry/models/utils/act_ckpt.py new file mode 100644 index 0000000000..08b718929a --- /dev/null +++ b/llmfoundry/models/utils/act_ckpt.py @@ -0,0 +1,147 @@ +# Copyright 2024 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any + +import torch + +from llmfoundry.models.layers.attention import ATTN_CLASS_REGISTRY +from llmfoundry.models.layers.blocks import MPTBlock +from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY +from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY + + +def pass_on_block_idx(parent: torch.nn.Module): + if not hasattr(parent, 'block_idx') or not hasattr(parent, 'max_block_idx'): + return + for child in parent.children(): + child.block_idx = parent.block_idx + child.max_block_idx = parent.max_block_idx + if child.children(): + pass_on_block_idx(child) + + +def get_act_ckpt_module(mod_name: str) -> Any: + """Get the module type from the module name.""" + if mod_name.lower() == 'mptblock': + mod_type = MPTBlock + elif mod_name in ATTN_CLASS_REGISTRY: + mod_type = ATTN_CLASS_REGISTRY[mod_name] + elif mod_name in FFN_CLASS_REGISTRY: + mod_type = FFN_CLASS_REGISTRY[mod_name] + elif mod_name in NORM_CLASS_REGISTRY: + mod_type = NORM_CLASS_REGISTRY[mod_name] + else: + msg = ', '.join( + list(ATTN_CLASS_REGISTRY.keys()) + list(FFN_CLASS_REGISTRY.keys()) + + list(NORM_CLASS_REGISTRY.keys()) + ['MPTBlock']) + raise ValueError( + f'{mod_name} (specified in activation_checkpointing_target) is not a recognized option out of available options {msg}.' + ) + return mod_type + + +def parse_ele_str(ele: str, max_block_idx: int) -> list: + """Parse a string in target_blocks and return a list of block ids to add. + + Supported formats are: first-n, middle-m, last-k, range-i-j which correspond + to the first n, the middle m, the last k, and the range [i, j). 
+ """ + to_add = None + if ele.startswith('first-'): + assert ele[6:].isdigit(), f'Invalid target_blocks element {ele}' + to_add = list(range(min(int(ele[6:]), max_block_idx + 1))) + elif ele.startswith('last-'): + assert ele[5:].isdigit(), f'Invalid target_blocks element {ele}' + to_add = list( + range(max(max_block_idx - int(ele[5:]) + 1, 0), max_block_idx + 1)) + elif ele.startswith('middle-'): + assert ele[7:].isdigit(), f'Invalid target_blocks element {ele}' + num = int(ele[7:]) + start = max(max_block_idx // 2 - num // 2, 0) + end = min(start + num, max_block_idx + 1) + to_add = list(range(start, end)) + elif ele.startswith('range-'): + r = ele[6:].split('-') + assert len(r) == 2, f'Invalid target_blocks element {ele}' + start, end = int(r[0]), int(r[1]) + start = max(start, 0) + end = min(end, max_block_idx + 1) + to_add = list(range(start, end)) + else: + raise ValueError(f'Invalid target_blocks element {ele}') + return to_add + + +def get_target_block_list(target_blocks: Any, max_block_idx: int) -> list: + """Parse the user input and return a list of block ids.""" + candidate_block_ids = [] + if isinstance(target_blocks, int): + candidate_block_ids = list(range(target_blocks)) + elif isinstance(target_blocks, list): + for ele in target_blocks: + if isinstance(ele, int): + candidate_block_ids.append(ele) + elif isinstance(ele, str): + to_add = parse_ele_str(ele, max_block_idx) + candidate_block_ids.extend(to_add) + else: + raise ValueError( + f'target_blocks must be a list of integers or "first-n", "middle-m", "last-k", or "range-i-j" where n, m, k, i, j are integers, but got {target_blocks}' + ) + elif isinstance(target_blocks, str): + target_blocks = target_blocks.replace(' ', '') + for ele in target_blocks.split(','): + to_add = parse_ele_str(ele, max_block_idx) + candidate_block_ids.extend(to_add) + else: + raise ValueError( + f'target_blocks must be either a single intege, or a list of integers, or a comma separated string made of "first-n", "last-m", "middle-k", "range-i-j", or a list of mixed integers and before-mentioned strings, but got {type(target_blocks)}' + ) + + candidate_block_ids = list(set(candidate_block_ids)) + return candidate_block_ids + + +def check_mapping_blocks_overlap(mapping: dict, max_block_idx: int) -> None: + """Check if the block ids in the mapping overlap with each other.""" + all_blocks = [None] * (max_block_idx + 1) + for k, v in mapping.items(): + if v == -1: + v = list(range(max_block_idx + 1)) + for vv in v: + if vv < 0 or vv > max_block_idx: + continue + else: + if all_blocks[vv] is not None: + raise ValueError( + f'Block {vv} is assigned to both {k} and {all_blocks[vv]}.' 
+ ) + else: + all_blocks[vv] = k + + +def build_act_ckpt_mod_to_blocks(act_ckpt_target: Any, top_module: Any, + max_block_idx: int) -> dict: + act_ckpt_mod_to_blocks = {} + if act_ckpt_target is None or act_ckpt_target == []: + mod = top_module + act_ckpt_mod_to_blocks[mod] = -1 + elif isinstance(act_ckpt_target, str): + mod = get_act_ckpt_module(act_ckpt_target) + act_ckpt_mod_to_blocks[mod] = -1 + elif isinstance(act_ckpt_target, list): + for target in act_ckpt_target: + mod = get_act_ckpt_module(target) + act_ckpt_mod_to_blocks[mod] = -1 + elif isinstance(act_ckpt_target, dict): + for k, v in act_ckpt_target.items(): + mod = get_act_ckpt_module(k) + block_ids = get_target_block_list(v, max_block_idx) + act_ckpt_mod_to_blocks[mod] = block_ids + else: + raise ValueError( + f'activation_checkpointing_target must be either a single string or a list or a dict, but got {type(act_ckpt_target)}' + ) + + return act_ckpt_mod_to_blocks diff --git a/tests/models/test_fsdp_act_checkpoint.py b/tests/models/test_fsdp_act_checkpoint.py index 987ea5f2a7..97063b25c4 100644 --- a/tests/models/test_fsdp_act_checkpoint.py +++ b/tests/models/test_fsdp_act_checkpoint.py @@ -17,17 +17,20 @@ @pytest.mark.gpu @pytest.mark.parametrize('activation_checkpointing', [True, False]) @pytest.mark.parametrize('activation_checkpointing_target', [ - 'grouped_query_attention', [], ['grouped_query_attention'], - ['mptblock', 'grouped_query_attention'] + 'grouped_query_attention', [], ['grouped_query_attention'], { + 'mptblock': [1], + 'grouped_query_attention': 'first-1, last-1' + } ]) def test_fsdp_act_checkpoint(activation_checkpointing: bool, - activation_checkpointing_target: Union[list, str]): + activation_checkpointing_target: Union[list, str, + dict]): device = get_device('gpu') model_cfg = { 'name': 'mpt_causal_lm', 'd_model': 128, 'n_heads': 4, - 'n_layers': 2, + 'n_layers': 3, 'expansion_ratio': 1, 'max_seq_len': 16, 'vocab_size': 50368, @@ -59,10 +62,7 @@ def test_fsdp_act_checkpoint(activation_checkpointing: bool, assert not isinstance( trainer.state.model.model._fsdp_wrapped_module.transformer. blocks[0], CheckpointWrapper) - elif (not activation_checkpointing_target - ) or activation_checkpointing_target == [ - 'mptblock', 'grouped_query_attention' - ]: + elif (not activation_checkpointing_target): module = trainer.state.model.model._fsdp_wrapped_module.transformer.blocks[ 0]._fsdp_wrapped_module assert isinstance(module, CheckpointWrapper) @@ -72,6 +72,19 @@ def test_fsdp_act_checkpoint(activation_checkpointing: bool, assert isinstance( trainer.state.model.model._fsdp_wrapped_module.transformer. blocks[0]._fsdp_wrapped_module.attn, CheckpointWrapper) + elif activation_checkpointing_target == { + 'mptblock': [1], + 'grouped_query_attention': 'first-1, last-1' + }: + assert isinstance( + trainer.state.model.model._fsdp_wrapped_module.transformer. + blocks[0]._fsdp_wrapped_module.attn, CheckpointWrapper) + assert isinstance( + trainer.state.model.model._fsdp_wrapped_module.transformer. + blocks[1]._fsdp_wrapped_module, CheckpointWrapper) + assert isinstance( + trainer.state.model.model._fsdp_wrapped_module.transformer. + blocks[2]._fsdp_wrapped_module.attn, CheckpointWrapper) else: raise ValueError( f'Unknown activation_checkpointing_target: {activation_checkpointing_target}'