
Commit

Merge branch 'main' into shashank/seq_id_flash_attn
ShashankMosaicML authored Dec 1, 2023
2 parents f1251c4 + 6ac01ef commit 5fca723
Showing 2 changed files with 16 additions and 5 deletions.
llmfoundry/models/mpt/modeling_mpt.py (6 additions, 0 deletions)
@@ -831,6 +831,12 @@ def fsdp_wrap_fn(self, module: nn.Module) -> bool:
     def activation_checkpointing_fn(self, module: nn.Module) -> bool:
         act_ckpt_list = getattr(self.config, 'activation_checkpointing_target',
                                 None) or ['MPTBlock']
+        if isinstance(act_ckpt_list, str):
+            act_ckpt_list = [act_ckpt_list]
+        elif not isinstance(act_ckpt_list, list):
+            raise ValueError(
+                f'activation_checkpointing_target must be either a single string or a list, but got {type(act_ckpt_list)}'
+            )
 
         if 'MPTBlock' in act_ckpt_list or 'mptblock' in act_ckpt_list:
             if len(act_ckpt_list) > 1:
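
The six added lines normalize activation_checkpointing_target before the membership checks, so a config may now pass a single module name as a bare string instead of a one-element list. Below is a minimal, self-contained sketch of that normalization; the standalone helper name normalize_act_ckpt_target is illustrative and not part of the diff.

from typing import Union


def normalize_act_ckpt_target(target: Union[str, list, None]) -> list:
    # Mirror the logic added to activation_checkpointing_fn: fall back to
    # ['MPTBlock'], wrap a bare string in a list, and reject any other type.
    act_ckpt_list = target or ['MPTBlock']
    if isinstance(act_ckpt_list, str):
        act_ckpt_list = [act_ckpt_list]
    elif not isinstance(act_ckpt_list, list):
        raise ValueError(
            f'activation_checkpointing_target must be either a single string or a list, but got {type(act_ckpt_list)}'
        )
    return act_ckpt_list


assert normalize_act_ckpt_target('grouped_query_attention') == ['grouped_query_attention']
assert normalize_act_ckpt_target(None) == ['MPTBlock']
assert normalize_act_ckpt_target(['mptblock', 'grouped_query_attention']) == ['mptblock', 'grouped_query_attention']
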
tests/test_fsdp_act_checkpoint.py (10 additions, 5 deletions)
@@ -1,6 +1,8 @@
 # Copyright 2022 MosaicML LLM Foundry authors
 # SPDX-License-Identifier: Apache-2.0
 
+from typing import Union
+
 import pytest
 from composer import Trainer
 from composer.utils import get_device, using_torch_2
@@ -14,11 +16,12 @@
 @pytest.mark.world_size(2)
 @pytest.mark.gpu
 @pytest.mark.parametrize('activation_checkpointing', [True, False])
-@pytest.mark.parametrize(
-    'activation_checkpointing_target',
-    [[], ['grouped_query_attention'], ['mptblock', 'grouped_query_attention']])
+@pytest.mark.parametrize('activation_checkpointing_target', [
+    'grouped_query_attention', [], ['grouped_query_attention'],
+    ['mptblock', 'grouped_query_attention']
+])
 def test_fsdp_act_checkpoint(activation_checkpointing: bool,
-                             activation_checkpointing_target: list):
+                             activation_checkpointing_target: Union[list, str]):
     device = get_device('gpu')
     model_cfg = {
         'name': 'mpt_causal_lm',
@@ -66,7 +69,9 @@ def test_fsdp_act_checkpoint(activation_checkpointing: bool,
         module = trainer.state.model.model._fsdp_wrapped_module.transformer.blocks[
             0]._fsdp_wrapped_module._fpw_module
         assert isinstance(module, CheckpointWrapper)
-    elif activation_checkpointing_target == ['grouped_query_attention']:
+    elif activation_checkpointing_target == [
+            'grouped_query_attention'
+    ] or activation_checkpointing_target == 'grouped_query_attention':
         assert isinstance(
             trainer.state.model.model._fsdp_wrapped_module.transformer.
             blocks[0]._fsdp_wrapped_module.attn, CheckpointWrapper)
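
With both files updated, the new bare-string 'grouped_query_attention' case follows the same assertion branch as the one-element list. A hedged sketch of how a model config might pass the target as a string after this change; only the 'name' key and the string value come from the diff, while the placement of the target key and the fsdp_config value are assumptions for illustration.

# Sketch only: 'activation_checkpointing_target' given as a bare string.
# The 'name' key mirrors the test's model_cfg; placing the target key here
# and the fsdp_config value are illustrative assumptions, not from the diff.
model_cfg = {
    'name': 'mpt_causal_lm',
    'activation_checkpointing_target': 'grouped_query_attention',
}
fsdp_config = {'activation_checkpointing': True}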
