diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py
index 10c042d27c..274c1b76e5 100644
--- a/llmfoundry/models/mpt/modeling_mpt.py
+++ b/llmfoundry/models/mpt/modeling_mpt.py
@@ -45,7 +45,9 @@
 from transformers.models.llama.modeling_llama import \
     LlamaRotaryEmbedding as HFRotaryEmbedding
 
-from llmfoundry.models.layers.attention import attn_bias_shape, build_attn_bias
+from llmfoundry.models.layers.attention import (ATTN_CLASS_REGISTRY,
+                                                attn_bias_shape,
+                                                build_attn_bias)
 from llmfoundry.models.layers.blocks import MPTBlock
 from llmfoundry.models.layers.custom_embedding import SharedEmbedding
 from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY as FC_CLASS_REGISTRY
@@ -733,7 +735,35 @@ def fsdp_wrap_fn(self, module: nn.Module) -> bool:
 
     # Activation Checkpointing
     def activation_checkpointing_fn(self, module: nn.Module) -> bool:
-        return isinstance(module, MPTBlock)
+        act_ckpt_list = getattr(self.config, 'activation_checkpointing_target',
+                                None) or ['MPTBlock']
+
+        if 'MPTBlock' in act_ckpt_list or 'mptblock' in act_ckpt_list:
+            if len(act_ckpt_list) > 1:
+                log.info(
+                    'Activation checkpointing MPTBlock only (ignoring other sub-block modules specified in activation_checkpointing_target).'
+                )
+            return isinstance(module, MPTBlock)
+
+        mod_types = ()
+        for mod_name in act_ckpt_list:
+            if mod_name.lower() == 'mptblock':
+                mod_types += (MPTBlock,)
+            elif mod_name in ATTN_CLASS_REGISTRY:
+                mod_types += (ATTN_CLASS_REGISTRY[mod_name],)
+            elif mod_name in FFN_CLASS_REGISTRY:
+                mod_types += (FFN_CLASS_REGISTRY[mod_name],)
+            elif mod_name in NORM_CLASS_REGISTRY:
+                mod_types += (NORM_CLASS_REGISTRY[mod_name],)
+            else:
+                msg = ', '.join(
+                    list(ATTN_CLASS_REGISTRY.keys()) +
+                    list(FFN_CLASS_REGISTRY.keys()) +
+                    list(NORM_CLASS_REGISTRY.keys()) + ['MPTBlock'])
+                raise ValueError(
+                    f'{mod_name} (specified in activation_checkpointing_target) is not a recognized option out of available options {msg}.'
+                )
+        return isinstance(module, mod_types)
 
     def prepare_inputs_for_generation(
         self,
diff --git a/tests/test_fsdp_act_checkpoint.py b/tests/test_fsdp_act_checkpoint.py
new file mode 100644
index 0000000000..1a46fcbccd
--- /dev/null
+++ b/tests/test_fsdp_act_checkpoint.py
@@ -0,0 +1,73 @@
+# Copyright 2022 MosaicML LLM Foundry authors
+# SPDX-License-Identifier: Apache-2.0
+
+import pytest
+from composer import Trainer
+from composer.utils import get_device
+from omegaconf import OmegaConf as om
+from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import \
+    CheckpointWrapper
+
+from llmfoundry.models.mpt.modeling_mpt import ComposerMPTCausalLM
+
+
+@pytest.mark.world_size(2)
+@pytest.mark.gpu
+@pytest.mark.parametrize('activation_checkpointing', [True, False])
+@pytest.mark.parametrize(
+    'activation_checkpointing_target',
+    [[], ['grouped_query_attention'], ['mptblock', 'grouped_query_attention']])
+def test_fsdp_act_checkpoint(activation_checkpointing: bool,
+                             activation_checkpointing_target: list):
+    device = get_device('gpu')
+    model_cfg = {
+        'name': 'mpt_causal_lm',
+        'd_model': 128,
+        'n_heads': 4,
+        'n_layers': 2,
+        'expansion_ratio': 1,
+        'max_seq_len': 16,
+        'vocab_size': 50368,
+        'attn_config': {
+            'attn_type': 'grouped_query_attention',
+            'kv_n_heads': 2,
+        },
+        'activation_checkpointing_target': activation_checkpointing_target
+    }
+    model_cfg = om.create(model_cfg)
+
+    fsdp_config = {
+        'activation_checkpointing': activation_checkpointing,
+        'activation_checkpointing_reentrant': False,
+        'activation_cpu_offload': False,
+    }
+
+    model = ComposerMPTCausalLM(model_cfg)
+    model = device.module_to_device(model)
+
+    trainer = Trainer(
+        model=model,
+        device='gpu',
+        fsdp_config=fsdp_config,
+    )
+
+    assert trainer.state.fsdp_enabled
+    if not activation_checkpointing:
+        assert not isinstance(
+            trainer.state.model.model._fsdp_wrapped_module.transformer.
+            blocks[0], CheckpointWrapper)
+    elif (not activation_checkpointing_target
+         ) or activation_checkpointing_target == [
+             'mptblock', 'grouped_query_attention'
+         ]:
+        assert isinstance(
+            trainer.state.model.model._fsdp_wrapped_module.transformer.
+            blocks[0]._fsdp_wrapped_module, CheckpointWrapper)
+    elif activation_checkpointing_target == ['grouped_query_attention']:
+        assert isinstance(
+            trainer.state.model.model._fsdp_wrapped_module.transformer.
+            blocks[0]._fsdp_wrapped_module.attn, CheckpointWrapper)
+    else:
+        raise ValueError(
+            f'Unknown activation_checkpointing_target: {activation_checkpointing_target}'
+        )
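
Usage sketch (not part of the diff): the snippet below shows how the new `activation_checkpointing_target` option could be exercised directly, reusing the small MPT config from the test above. It assumes the checkpointing predicate is reachable as `model.model.activation_checkpointing_fn` (the test's `trainer.state.model.model` access suggests the inner model sits at `.model`), and it swaps in `attn_impl: 'torch'` so it builds without GPU kernels; treat both as illustrative assumptions rather than the PR's prescribed usage.

```python
from omegaconf import OmegaConf as om

from llmfoundry.models.mpt.modeling_mpt import ComposerMPTCausalLM

# Same small MPT config as in the test, but with a fixed target list that
# checkpoints only the attention modules instead of whole MPTBlocks.
model_cfg = om.create({
    'name': 'mpt_causal_lm',
    'd_model': 128,
    'n_heads': 4,
    'n_layers': 2,
    'expansion_ratio': 1,
    'max_seq_len': 16,
    'vocab_size': 50368,
    'attn_config': {
        'attn_type': 'grouped_query_attention',
        'kv_n_heads': 2,
        'attn_impl': 'torch',  # assumption: torch impl so this runs on CPU
    },
    # Omitting this key (or listing 'mptblock') falls back to checkpointing
    # entire MPTBlocks, matching the pre-PR behavior.
    'activation_checkpointing_target': ['grouped_query_attention'],
})

model = ComposerMPTCausalLM(model_cfg)

# Assumed for illustration: call the predicate directly. In a real run,
# Composer applies it during FSDP wrapping when
# fsdp_config['activation_checkpointing'] is True.
block = model.model.transformer.blocks[0]
assert model.model.activation_checkpointing_fn(block.attn)  # attention is targeted
assert not model.model.activation_checkpointing_fn(block)   # whole block is not
```

Note that, per the new logic, any target list containing 'mptblock'/'MPTBlock' takes precedence: whole blocks are checkpointed and an info message notes that the other listed sub-modules are ignored.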