From 3d6506bec4f90247245b77869eb92f1e0b1606a3 Mon Sep 17 00:00:00 2001
From: Cheng Li
Date: Fri, 1 Dec 2023 20:52:44 +0000
Subject: [PATCH 1/2] add single value support

---
 llmfoundry/models/mpt/modeling_mpt.py |  3 +++
 tests/test_fsdp_act_checkpoint.py     | 15 ++++++++++-----
 2 files changed, 13 insertions(+), 5 deletions(-)

diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py
index d6b23c04d0..3f9089862a 100644
--- a/llmfoundry/models/mpt/modeling_mpt.py
+++ b/llmfoundry/models/mpt/modeling_mpt.py
@@ -739,6 +739,9 @@ def fsdp_wrap_fn(self, module: nn.Module) -> bool:
     def activation_checkpointing_fn(self, module: nn.Module) -> bool:
         act_ckpt_list = getattr(self.config, 'activation_checkpointing_target',
                                 None) or ['MPTBlock']
+        if not isinstance(act_ckpt_list, list):
+            # `activation_checkpointing_target` is a single value
+            act_ckpt_list = [act_ckpt_list]
 
         if 'MPTBlock' in act_ckpt_list or 'mptblock' in act_ckpt_list:
             if len(act_ckpt_list) > 1:
diff --git a/tests/test_fsdp_act_checkpoint.py b/tests/test_fsdp_act_checkpoint.py
index 3b9a746708..a7e41a3fc2 100644
--- a/tests/test_fsdp_act_checkpoint.py
+++ b/tests/test_fsdp_act_checkpoint.py
@@ -1,6 +1,8 @@
 # Copyright 2022 MosaicML LLM Foundry authors
 # SPDX-License-Identifier: Apache-2.0
 
+from typing import Union
+
 import pytest
 from composer import Trainer
 from composer.utils import get_device, using_torch_2
@@ -14,11 +16,12 @@
 @pytest.mark.world_size(2)
 @pytest.mark.gpu
 @pytest.mark.parametrize('activation_checkpointing', [True, False])
-@pytest.mark.parametrize(
-    'activation_checkpointing_target',
-    [[], ['grouped_query_attention'], ['mptblock', 'grouped_query_attention']])
+@pytest.mark.parametrize('activation_checkpointing_target', [
+    'grouped_query_attention', [], ['grouped_query_attention'],
+    ['mptblock', 'grouped_query_attention']
+])
 def test_fsdp_act_checkpoint(activation_checkpointing: bool,
-                             activation_checkpointing_target: list):
+                             activation_checkpointing_target: Union[list, str]):
     device = get_device('gpu')
     model_cfg = {
         'name': 'mpt_causal_lm',
@@ -66,7 +69,9 @@ def test_fsdp_act_checkpoint(activation_checkpointing: bool,
         module = trainer.state.model.model._fsdp_wrapped_module.transformer.blocks[
             0]._fsdp_wrapped_module._fpw_module
         assert isinstance(module, CheckpointWrapper)
-    elif activation_checkpointing_target == ['grouped_query_attention']:
+    elif activation_checkpointing_target == [
+            'grouped_query_attention'
+    ] or activation_checkpointing_target == 'grouped_query_attention':
         assert isinstance(
             trainer.state.model.model._fsdp_wrapped_module.transformer.
             blocks[0]._fsdp_wrapped_module.attn, CheckpointWrapper)

From eddccc225be95e0adbdd83c9bb6ce33e0e7544c4 Mon Sep 17 00:00:00 2001
From: Cheng Li
Date: Fri, 1 Dec 2023 21:07:30 +0000
Subject: [PATCH 2/2] check str or list dtype

---
 llmfoundry/models/mpt/modeling_mpt.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py
index 3f9089862a..34b8992d3e 100644
--- a/llmfoundry/models/mpt/modeling_mpt.py
+++ b/llmfoundry/models/mpt/modeling_mpt.py
@@ -739,9 +739,12 @@ def fsdp_wrap_fn(self, module: nn.Module) -> bool:
     def activation_checkpointing_fn(self, module: nn.Module) -> bool:
         act_ckpt_list = getattr(self.config, 'activation_checkpointing_target',
                                 None) or ['MPTBlock']
-        if not isinstance(act_ckpt_list, list):
-            # `activation_checkpointing_target` is a single value
+        if isinstance(act_ckpt_list, str):
             act_ckpt_list = [act_ckpt_list]
+        elif not isinstance(act_ckpt_list, list):
+            raise ValueError(
+                f'activation_checkpointing_target must be either a single string or a list, but got {type(act_ckpt_list)}'
+            )
 
         if 'MPTBlock' in act_ckpt_list or 'mptblock' in act_ckpt_list:
             if len(act_ckpt_list) > 1:
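
Reviewer note (not part of the patch series): a minimal, self-contained sketch of the
target-normalization behavior the second patch settles on. The helper name
`normalize_target` is hypothetical and used only for illustration; in the actual change
this logic lives inline in `activation_checkpointing_fn`.

    from typing import Union

    def normalize_target(target: Union[str, list, None]) -> list:
        # None or an empty list falls back to checkpointing whole MPTBlocks.
        act_ckpt_list = target or ['MPTBlock']
        # A single string becomes a one-element list (what PATCH 1 enables).
        if isinstance(act_ckpt_list, str):
            act_ckpt_list = [act_ckpt_list]
        # Anything that is neither a string nor a list is rejected (what PATCH 2 adds).
        elif not isinstance(act_ckpt_list, list):
            raise ValueError(
                f'activation_checkpointing_target must be either a single string '
                f'or a list, but got {type(act_ckpt_list)}')
        return act_ckpt_list

    assert normalize_target('grouped_query_attention') == ['grouped_query_attention']
    assert normalize_target(['mptblock', 'grouped_query_attention']) == [
        'mptblock', 'grouped_query_attention'
    ]
    assert normalize_target(None) == ['MPTBlock']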