From 24246598ef8f13bb5b358fd78619867b161682b2 Mon Sep 17 00:00:00 2001 From: Cheng Li Date: Tue, 7 Nov 2023 22:05:47 +0000 Subject: [PATCH 01/11] add act checkpoint at sub layer level --- llmfoundry/models/mpt/modeling_mpt.py | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 0cb3ebd56c..1965166e07 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -45,7 +45,9 @@ from transformers.models.llama.modeling_llama import \ LlamaRotaryEmbedding as HFRotaryEmbedding -from llmfoundry.models.layers.attention import attn_bias_shape, build_attn_bias +from llmfoundry.models.layers.attention import (ATTN_CLASS_REGISTRY, + attn_bias_shape, + build_attn_bias) from llmfoundry.models.layers.blocks import MPTBlock from llmfoundry.models.layers.custom_embedding import SharedEmbedding from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY as FC_CLASS_REGISTRY @@ -705,7 +707,25 @@ def fsdp_wrap_fn(self, module: nn.Module) -> bool: # Activation Checkpointing def activation_checkpointing_fn(self, module: nn.Module) -> bool: - return isinstance(module, MPTBlock) + if not hasattr(self.config, 'activation_checkpointing_target'): + return isinstance(module, MPTBlock) + act_ckpt_str = self.config.activation_checkpointing_target + act_ckpt_lst = act_ckpt_str.replace(' ', '').split(',') + if act_ckpt_lst: + if 'MPTBlock' in act_ckpt_lst or 'mptblock' in act_ckpt_lst: + act_ckpt_lst = ['MPTBlock'] + for mod_name in act_ckpt_lst: + if mod_name in ['MPTBlock', 'mptblock']: + mod_type = MPTBlock + elif mod_name in ATTN_CLASS_REGISTRY: + mod_type = ATTN_CLASS_REGISTRY[mod_name] + elif mod_name in FFN_CLASS_REGISTRY: + mod_type = FFN_CLASS_REGISTRY[mod_name] + elif mod_name in NORM_CLASS_REGISTRY: + mod_type = NORM_CLASS_REGISTRY[mod_name] + else: + continue + return isinstance(module, mod_type) def prepare_inputs_for_generation( self, From 952b9a54f16d8c4800f85e034979408d1b6fa32e Mon Sep 17 00:00:00 2001 From: Cheng Li Date: Tue, 7 Nov 2023 14:16:52 -0800 Subject: [PATCH 02/11] Update llmfoundry/models/mpt/modeling_mpt.py Co-authored-by: Mihir Patel --- llmfoundry/models/mpt/modeling_mpt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 1965166e07..eb75315ca4 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -715,7 +715,7 @@ def activation_checkpointing_fn(self, module: nn.Module) -> bool: if 'MPTBlock' in act_ckpt_lst or 'mptblock' in act_ckpt_lst: act_ckpt_lst = ['MPTBlock'] for mod_name in act_ckpt_lst: - if mod_name in ['MPTBlock', 'mptblock']: + if mod_name.lower() == 'mptblock': mod_type = MPTBlock elif mod_name in ATTN_CLASS_REGISTRY: mod_type = ATTN_CLASS_REGISTRY[mod_name] From 4eb90294a88bbcaf3a7a166f21e08ee61e186d61 Mon Sep 17 00:00:00 2001 From: Cheng Li Date: Tue, 7 Nov 2023 22:22:59 +0000 Subject: [PATCH 03/11] address comments --- llmfoundry/models/mpt/modeling_mpt.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index eb75315ca4..7d51200f35 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -710,9 +710,9 @@ def activation_checkpointing_fn(self, module: nn.Module) -> bool: if not hasattr(self.config, 'activation_checkpointing_target'): 
return isinstance(module, MPTBlock) act_ckpt_str = self.config.activation_checkpointing_target - act_ckpt_lst = act_ckpt_str.replace(' ', '').split(',') - if act_ckpt_lst: - if 'MPTBlock' in act_ckpt_lst or 'mptblock' in act_ckpt_lst: + act_ckpt_list = act_ckpt_str.replace(' ', '').split(',') + if act_ckpt_list: + if 'MPTBlock' in act_ckpt_list or 'mptblock' in act_ckpt_list: act_ckpt_lst = ['MPTBlock'] for mod_name in act_ckpt_lst: if mod_name.lower() == 'mptblock': @@ -724,6 +724,9 @@ def activation_checkpointing_fn(self, module: nn.Module) -> bool: elif mod_name in NORM_CLASS_REGISTRY: mod_type = NORM_CLASS_REGISTRY[mod_name] else: + warnings.warn( + f'module name specified in activation_checkpointing_target ({mod_name}) not recognized, available options are names in ATTN_CLASS_REGISTRY, FFN_CLASS_REGISTRY, NORM_CLASS_REGISTRY, or MPTBlock.' + ) continue return isinstance(module, mod_type) From 4cb619ec95c2f604e77ca30eac10e61960c0f5ef Mon Sep 17 00:00:00 2001 From: Cheng Li Date: Wed, 8 Nov 2023 18:08:10 +0000 Subject: [PATCH 04/11] addess coments --- llmfoundry/models/mpt/modeling_mpt.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 7d51200f35..8523d6b9aa 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -709,26 +709,28 @@ def fsdp_wrap_fn(self, module: nn.Module) -> bool: def activation_checkpointing_fn(self, module: nn.Module) -> bool: if not hasattr(self.config, 'activation_checkpointing_target'): return isinstance(module, MPTBlock) - act_ckpt_str = self.config.activation_checkpointing_target - act_ckpt_list = act_ckpt_str.replace(' ', '').split(',') + act_ckpt_list = self.config.activation_checkpointing_target if act_ckpt_list: if 'MPTBlock' in act_ckpt_list or 'mptblock' in act_ckpt_list: - act_ckpt_lst = ['MPTBlock'] - for mod_name in act_ckpt_lst: + act_ckpt_list = ['MPTBlock'] + warnings.warn( + f'activation checkpointing MPTBlock, ignoring other sub-block modules if specified' + ) + mod_types = () + for mod_name in act_ckpt_list: if mod_name.lower() == 'mptblock': - mod_type = MPTBlock + mod_types += (MPTBlock,) elif mod_name in ATTN_CLASS_REGISTRY: - mod_type = ATTN_CLASS_REGISTRY[mod_name] + mod_types += (ATTN_CLASS_REGISTRY[mod_name],) elif mod_name in FFN_CLASS_REGISTRY: - mod_type = FFN_CLASS_REGISTRY[mod_name] + mod_types += (FFN_CLASS_REGISTRY[mod_name],) elif mod_name in NORM_CLASS_REGISTRY: - mod_type = NORM_CLASS_REGISTRY[mod_name] + mod_types += (NORM_CLASS_REGISTRY[mod_name],) else: warnings.warn( f'module name specified in activation_checkpointing_target ({mod_name}) not recognized, available options are names in ATTN_CLASS_REGISTRY, FFN_CLASS_REGISTRY, NORM_CLASS_REGISTRY, or MPTBlock.' 
) - continue - return isinstance(module, mod_type) + return isinstance(module, mod_types) def prepare_inputs_for_generation( self, From 30251b2691b29e2d17b222e084fd4eb85ef434ee Mon Sep 17 00:00:00 2001 From: Cheng Li Date: Wed, 8 Nov 2023 18:13:46 +0000 Subject: [PATCH 05/11] add log info --- llmfoundry/models/mpt/modeling_mpt.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 8523d6b9aa..4278624d14 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -708,6 +708,9 @@ def fsdp_wrap_fn(self, module: nn.Module) -> bool: # Activation Checkpointing def activation_checkpointing_fn(self, module: nn.Module) -> bool: if not hasattr(self.config, 'activation_checkpointing_target'): + log.info( + f'activation checkpointing MPTBlock as activation_checkpointing_target is not set in model_config' + ) return isinstance(module, MPTBlock) act_ckpt_list = self.config.activation_checkpointing_target if act_ckpt_list: From 51beeefc86a432c8b178bf3810d57aa1d197eb33 Mon Sep 17 00:00:00 2001 From: Cheng Li Date: Wed, 8 Nov 2023 18:32:38 +0000 Subject: [PATCH 06/11] fix pyright --- llmfoundry/models/mpt/modeling_mpt.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 4278624d14..3d40675e79 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -707,14 +707,15 @@ def fsdp_wrap_fn(self, module: nn.Module) -> bool: # Activation Checkpointing def activation_checkpointing_fn(self, module: nn.Module) -> bool: - if not hasattr(self.config, 'activation_checkpointing_target'): + if not hasattr(self.config, 'activation_checkpointing_target' + ) or self.config.activation_checkpointing_target is None: log.info( f'activation checkpointing MPTBlock as activation_checkpointing_target is not set in model_config' ) return isinstance(module, MPTBlock) - act_ckpt_list = self.config.activation_checkpointing_target - if act_ckpt_list: - if 'MPTBlock' in act_ckpt_list or 'mptblock' in act_ckpt_list: + else: + act_ckpt_list = self.config.activation_checkpointing_target + if 'MPTBlock' in act_ckpt_list: act_ckpt_list = ['MPTBlock'] warnings.warn( f'activation checkpointing MPTBlock, ignoring other sub-block modules if specified' From a22ce6d865736969c0555cfb1d9037124c7e0087 Mon Sep 17 00:00:00 2001 From: Cheng Li Date: Wed, 8 Nov 2023 22:59:38 +0000 Subject: [PATCH 07/11] refactor --- llmfoundry/models/mpt/modeling_mpt.py | 44 ++++++++++++--------------- 1 file changed, 20 insertions(+), 24 deletions(-) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 3d40675e79..fee0df3dd9 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -707,34 +707,30 @@ def fsdp_wrap_fn(self, module: nn.Module) -> bool: # Activation Checkpointing def activation_checkpointing_fn(self, module: nn.Module) -> bool: - if not hasattr(self.config, 'activation_checkpointing_target' - ) or self.config.activation_checkpointing_target is None: + act_ckpt_list = getattr(self.config, 'activation_checkpointing_target', + None) or ['MPTBlock'] + + if 'MPTBlock' in act_ckpt_list or 'mptblock' in act_ckpt_list: log.info( - f'activation checkpointing MPTBlock as activation_checkpointing_target is not set in model_config' + 'Activation checkpointing MPTBlock only (ignoring other sub-block modules if specified 
in activation_checkpointing_target).' ) return isinstance(module, MPTBlock) - else: - act_ckpt_list = self.config.activation_checkpointing_target - if 'MPTBlock' in act_ckpt_list: - act_ckpt_list = ['MPTBlock'] - warnings.warn( - f'activation checkpointing MPTBlock, ignoring other sub-block modules if specified' + + mod_types = () + for mod_name in act_ckpt_list: + if mod_name.lower() == 'mptblock': + mod_types += (MPTBlock,) + elif mod_name in ATTN_CLASS_REGISTRY: + mod_types += (ATTN_CLASS_REGISTRY[mod_name],) + elif mod_name in FFN_CLASS_REGISTRY: + mod_types += (FFN_CLASS_REGISTRY[mod_name],) + elif mod_name in NORM_CLASS_REGISTRY: + mod_types += (NORM_CLASS_REGISTRY[mod_name],) + else: + raise ValueError( + f'{mod_name=} (specified in activation_checkpointing_target) is not a recognized option, available options are names in ATTN_CLASS_REGISTRY, FFN_CLASS_REGISTRY, NORM_CLASS_REGISTRY, or MPTBlock.' ) - mod_types = () - for mod_name in act_ckpt_list: - if mod_name.lower() == 'mptblock': - mod_types += (MPTBlock,) - elif mod_name in ATTN_CLASS_REGISTRY: - mod_types += (ATTN_CLASS_REGISTRY[mod_name],) - elif mod_name in FFN_CLASS_REGISTRY: - mod_types += (FFN_CLASS_REGISTRY[mod_name],) - elif mod_name in NORM_CLASS_REGISTRY: - mod_types += (NORM_CLASS_REGISTRY[mod_name],) - else: - warnings.warn( - f'module name specified in activation_checkpointing_target ({mod_name}) not recognized, available options are names in ATTN_CLASS_REGISTRY, FFN_CLASS_REGISTRY, NORM_CLASS_REGISTRY, or MPTBlock.' - ) - return isinstance(module, mod_types) + return isinstance(module, mod_types) def prepare_inputs_for_generation( self, From 48e3c4f438265d40ecdbc2a0b4bb6be345ad2b9e Mon Sep 17 00:00:00 2001 From: Cheng Li Date: Thu, 9 Nov 2023 01:08:54 +0000 Subject: [PATCH 08/11] better log info and error msg --- llmfoundry/models/mpt/modeling_mpt.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index fee0df3dd9..aea03d4a9a 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -711,9 +711,10 @@ def activation_checkpointing_fn(self, module: nn.Module) -> bool: None) or ['MPTBlock'] if 'MPTBlock' in act_ckpt_list or 'mptblock' in act_ckpt_list: - log.info( - 'Activation checkpointing MPTBlock only (ignoring other sub-block modules if specified in activation_checkpointing_target).' - ) + if len(act_ckpt_list) > 1: + log.info( + 'Activation checkpointing MPTBlock only (ignoring other sub-block modules specified in activation_checkpointing_target).' + ) return isinstance(module, MPTBlock) mod_types = () @@ -727,8 +728,12 @@ def activation_checkpointing_fn(self, module: nn.Module) -> bool: elif mod_name in NORM_CLASS_REGISTRY: mod_types += (NORM_CLASS_REGISTRY[mod_name],) else: + msg = ', '.join( + list(ATTN_CLASS_REGISTRY.keys()) + + list(FFN_CLASS_REGISTRY.keys()) + + list(NORM_CLASS_REGISTRY.keys()) + ['MPTBlock']) raise ValueError( - f'{mod_name=} (specified in activation_checkpointing_target) is not a recognized option, available options are names in ATTN_CLASS_REGISTRY, FFN_CLASS_REGISTRY, NORM_CLASS_REGISTRY, or MPTBlock.' + f'{mod_name} (specified in activation_checkpointing_target) is not a recognized option, available options are {msg}.' 
) return isinstance(module, mod_types) From 7417d5a852aef5380f2b29c16c26554ac7c1acc6 Mon Sep 17 00:00:00 2001 From: Cheng Li Date: Thu, 9 Nov 2023 06:15:22 +0000 Subject: [PATCH 09/11] add test --- tests/test_fsdp_act_checkpoint.py | 74 +++++++++++++++++++++++++++++++ 1 file changed, 74 insertions(+) create mode 100644 tests/test_fsdp_act_checkpoint.py diff --git a/tests/test_fsdp_act_checkpoint.py b/tests/test_fsdp_act_checkpoint.py new file mode 100644 index 0000000000..592b609410 --- /dev/null +++ b/tests/test_fsdp_act_checkpoint.py @@ -0,0 +1,74 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +import pytest +from composer import Trainer +from composer.utils import get_device +from omegaconf import OmegaConf as om +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import \ + CheckpointWrapper + +from llmfoundry.models.mpt.modeling_mpt import ComposerMPTCausalLM + + +# python3 -m composer.cli.launcher -n 2 --master_port 26000 -m pytest -m gpu tests/test_fsdp_act_checkpoint.py::test_fsdp_act_checkpoint # noqa +@pytest.mark.world_size(2) +@pytest.mark.gpu +@pytest.mark.parametrize('activation_checkpointing', [True, False]) +@pytest.mark.parametrize( + 'activation_checkpointing_target', + [[], ['grouped_query_attention'], ['mptblock', 'grouped_query_attention']]) +def test_fsdp_act_checkpoint(activation_checkpointing: bool, + activation_checkpointing_target: list): + device = get_device('gpu') + model_cfg = { + 'name': 'mpt_causal_lm', + 'd_model': 128, + 'n_heads': 4, + 'n_layers': 2, + 'expansion_ratio': 1, + 'max_seq_len': 16, + 'vocab_size': 50368, + 'attn_config': { + 'attn_type': 'grouped_query_attention', + 'kv_n_heads': 2, + }, + 'activation_checkpointing_target': activation_checkpointing_target + } + model_cfg = om.create(model_cfg) + + fsdp_config = { + 'activation_checkpointing': activation_checkpointing, + 'activation_checkpointing_reentrant': False, + 'activation_cpu_offload': False, + } + + model = ComposerMPTCausalLM(model_cfg) + model = device.module_to_device(model) + + trainer = Trainer( + model=model, + device='gpu', + fsdp_config=fsdp_config, + ) + + assert trainer.state.fsdp_enabled + if not activation_checkpointing: + assert not isinstance( + trainer.state.model.model._fsdp_wrapped_module.transformer. + blocks[0], CheckpointWrapper) + elif (not activation_checkpointing_target + ) or activation_checkpointing_target == [ + 'mptblock', 'grouped_query_attention' + ]: + assert isinstance( + trainer.state.model.model._fsdp_wrapped_module.transformer. + blocks[0]._fsdp_wrapped_module, CheckpointWrapper) + elif activation_checkpointing_target == ['grouped_query_attention']: + assert isinstance( + trainer.state.model.model._fsdp_wrapped_module.transformer. 
+ blocks[0]._fsdp_wrapped_module.attn, CheckpointWrapper) + else: + raise ValueError( + f'Unknown activation_checkpointing_target: {activation_checkpointing_target}' + ) From d44fddefbafc8812ce1a26905475c328aafeab7e Mon Sep 17 00:00:00 2001 From: Cheng Li Date: Wed, 8 Nov 2023 22:16:26 -0800 Subject: [PATCH 10/11] Update llmfoundry/models/mpt/modeling_mpt.py Co-authored-by: Mihir Patel --- llmfoundry/models/mpt/modeling_mpt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index aea03d4a9a..d36ba9333e 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -733,7 +733,7 @@ def activation_checkpointing_fn(self, module: nn.Module) -> bool: list(FFN_CLASS_REGISTRY.keys()) + list(NORM_CLASS_REGISTRY.keys()) + ['MPTBlock']) raise ValueError( - f'{mod_name} (specified in activation_checkpointing_target) is not a recognized option, available options are {msg}.' + f'{mod_name} (specified in activation_checkpointing_target) is not a recognized option out of available options {msg}.' ) return isinstance(module, mod_types) From 0e4d729164bac439239c61fd26a5fa8125f8c16b Mon Sep 17 00:00:00 2001 From: Cheng Li Date: Mon, 13 Nov 2023 21:05:49 +0000 Subject: [PATCH 11/11] remove unneeded comments --- tests/test_fsdp_act_checkpoint.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_fsdp_act_checkpoint.py b/tests/test_fsdp_act_checkpoint.py index 592b609410..1a46fcbccd 100644 --- a/tests/test_fsdp_act_checkpoint.py +++ b/tests/test_fsdp_act_checkpoint.py @@ -11,7 +11,6 @@ from llmfoundry.models.mpt.modeling_mpt import ComposerMPTCausalLM -# python3 -m composer.cli.launcher -n 2 --master_port 26000 -m pytest -m gpu tests/test_fsdp_act_checkpoint.py::test_fsdp_act_checkpoint # noqa @pytest.mark.world_size(2) @pytest.mark.gpu @pytest.mark.parametrize('activation_checkpointing', [True, False])
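
Taken together, these patches make `activation_checkpointing_target` in the model config select which sub-modules are wrapped for activation checkpointing: names drawn from ATTN_CLASS_REGISTRY, FFN_CLASS_REGISTRY, or NORM_CLASS_REGISTRY checkpoint only those sub-modules, listing 'MPTBlock' (or leaving the option unset) falls back to checkpointing whole blocks, and an unrecognized name raises a ValueError. Below is a minimal sketch of how the option might be exercised end to end; it is modeled directly on tests/test_fsdp_act_checkpoint.py, so the config values and the 'grouped_query_attention' target name come from that test, and a GPU run under FSDP (the test launches two ranks via the composer launcher, per the comment removed in patch 11) is assumed rather than guaranteed to work in other setups.

    # A minimal sketch, assuming the same config shape as tests/test_fsdp_act_checkpoint.py.
    from composer import Trainer
    from composer.utils import get_device
    from omegaconf import OmegaConf as om

    from llmfoundry.models.mpt.modeling_mpt import ComposerMPTCausalLM

    model_cfg = om.create({
        'name': 'mpt_causal_lm',
        'd_model': 128,
        'n_heads': 4,
        'n_layers': 2,
        'expansion_ratio': 1,
        'max_seq_len': 16,
        'vocab_size': 50368,
        'attn_config': {
            'attn_type': 'grouped_query_attention',
            'kv_n_heads': 2,
        },
        # Checkpoint only the attention sub-modules. Valid names are keys of
        # ATTN_CLASS_REGISTRY, FFN_CLASS_REGISTRY, NORM_CLASS_REGISTRY, or
        # 'MPTBlock'; including 'MPTBlock' takes precedence and checkpoints
        # whole blocks, ignoring the other entries.
        'activation_checkpointing_target': ['grouped_query_attention'],
    })

    fsdp_config = {
        'activation_checkpointing': True,
        'activation_checkpointing_reentrant': False,
        'activation_cpu_offload': False,
    }

    device = get_device('gpu')
    model = device.module_to_device(ComposerMPTCausalLM(model_cfg))
    trainer = Trainer(model=model, device='gpu', fsdp_config=fsdp_config)

    # With this target, activation_checkpointing_fn returns True only for the
    # grouped-query-attention modules, so FSDP wraps transformer.blocks[i].attn
    # in a CheckpointWrapper instead of wrapping each whole MPTBlock.

The point of the finer granularity is tunability: checkpointing only selected sub-modules (for example, just the attention layers) recomputes less in the backward pass than checkpointing whole MPTBlocks, at the cost of saving less activation memory, and the target list lets that trade-off be chosen per module type rather than all-or-nothing.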