add act checkpoint at sub layer level (#720)
* add act checkpoint at sub layer level

* Update llmfoundry/models/mpt/modeling_mpt.py

Co-authored-by: Mihir Patel <[email protected]>

* address comments

* address comments

* add log info

* fix pyright

* refactor

* better log info and error msg

* add test

* Update llmfoundry/models/mpt/modeling_mpt.py

Co-authored-by: Mihir Patel <[email protected]>

* remove unneeded comments

---------

Co-authored-by: Mihir Patel <[email protected]>
Co-authored-by: Daniel King <[email protected]>
3 people authored Nov 13, 2023
1 parent 7899178 commit 8ba697c
Showing 2 changed files with 105 additions and 2 deletions.
34 changes: 32 additions & 2 deletions llmfoundry/models/mpt/modeling_mpt.py
@@ -45,7 +45,9 @@
 from transformers.models.llama.modeling_llama import \
     LlamaRotaryEmbedding as HFRotaryEmbedding
 
-from llmfoundry.models.layers.attention import attn_bias_shape, build_attn_bias
+from llmfoundry.models.layers.attention import (ATTN_CLASS_REGISTRY,
+                                                attn_bias_shape,
+                                                build_attn_bias)
 from llmfoundry.models.layers.blocks import MPTBlock
 from llmfoundry.models.layers.custom_embedding import SharedEmbedding
 from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY as FC_CLASS_REGISTRY
@@ -733,7 +735,35 @@ def fsdp_wrap_fn(self, module: nn.Module) -> bool:
 
     # Activation Checkpointing
     def activation_checkpointing_fn(self, module: nn.Module) -> bool:
-        return isinstance(module, MPTBlock)
+        act_ckpt_list = getattr(self.config, 'activation_checkpointing_target',
+                                None) or ['MPTBlock']
+
+        if 'MPTBlock' in act_ckpt_list or 'mptblock' in act_ckpt_list:
+            if len(act_ckpt_list) > 1:
+                log.info(
+                    'Activation checkpointing MPTBlock only (ignoring other sub-block modules specified in activation_checkpointing_target).'
+                )
+            return isinstance(module, MPTBlock)
+
+        mod_types = ()
+        for mod_name in act_ckpt_list:
+            if mod_name.lower() == 'mptblock':
+                mod_types += (MPTBlock,)
+            elif mod_name in ATTN_CLASS_REGISTRY:
+                mod_types += (ATTN_CLASS_REGISTRY[mod_name],)
+            elif mod_name in FFN_CLASS_REGISTRY:
+                mod_types += (FFN_CLASS_REGISTRY[mod_name],)
+            elif mod_name in NORM_CLASS_REGISTRY:
+                mod_types += (NORM_CLASS_REGISTRY[mod_name],)
+            else:
+                msg = ', '.join(
+                    list(ATTN_CLASS_REGISTRY.keys()) +
+                    list(FFN_CLASS_REGISTRY.keys()) +
+                    list(NORM_CLASS_REGISTRY.keys()) + ['MPTBlock'])
+                raise ValueError(
+                    f'{mod_name} (specified in activation_checkpointing_target) is not a recognized option out of available options {msg}.'
+                )
+        return isinstance(module, mod_types)
 
     def prepare_inputs_for_generation(
         self,
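Note (not part of the commit): the new `activation_checkpointing_target` list is read off the model config when FSDP decides what to wrap. Below is a minimal sketch of how a user might enable sub-block checkpointing, mirroring the model and FSDP settings exercised by the new test further down; the tiny hyperparameters and the choice of targeting only `grouped_query_attention` are illustrative assumptions, and a GPU device is assumed as in the test.

from composer import Trainer
from omegaconf import OmegaConf as om

from llmfoundry.models.mpt.modeling_mpt import ComposerMPTCausalLM

# Model config mirroring the new test; 'activation_checkpointing_target'
# selects which sub-modules get checkpointed instead of whole MPTBlocks.
model_cfg = om.create({
    'name': 'mpt_causal_lm',
    'd_model': 128,
    'n_heads': 4,
    'n_layers': 2,
    'expansion_ratio': 1,
    'max_seq_len': 16,
    'vocab_size': 50368,
    'attn_config': {
        'attn_type': 'grouped_query_attention',
        'kv_n_heads': 2,
    },
    'activation_checkpointing_target': ['grouped_query_attention'],
})

trainer = Trainer(
    model=ComposerMPTCausalLM(model_cfg),
    device='gpu',
    fsdp_config={
        'activation_checkpointing': True,
        'activation_checkpointing_reentrant': False,
        'activation_cpu_offload': False,
    },
)

With this target list, only each block's attention sub-module gets wrapped in a CheckpointWrapper; including 'mptblock' in the list (alone or alongside other names) falls back to checkpointing the full block, as the log message in the diff above indicates.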
73 changes: 73 additions & 0 deletions tests/test_fsdp_act_checkpoint.py
@@ -0,0 +1,73 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

import pytest
from composer import Trainer
from composer.utils import get_device
from omegaconf import OmegaConf as om
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import \
    CheckpointWrapper

from llmfoundry.models.mpt.modeling_mpt import ComposerMPTCausalLM


@pytest.mark.world_size(2)
@pytest.mark.gpu
@pytest.mark.parametrize('activation_checkpointing', [True, False])
@pytest.mark.parametrize(
    'activation_checkpointing_target',
    [[], ['grouped_query_attention'], ['mptblock', 'grouped_query_attention']])
def test_fsdp_act_checkpoint(activation_checkpointing: bool,
                             activation_checkpointing_target: list):
    device = get_device('gpu')
    model_cfg = {
        'name': 'mpt_causal_lm',
        'd_model': 128,
        'n_heads': 4,
        'n_layers': 2,
        'expansion_ratio': 1,
        'max_seq_len': 16,
        'vocab_size': 50368,
        'attn_config': {
            'attn_type': 'grouped_query_attention',
            'kv_n_heads': 2,
        },
        'activation_checkpointing_target': activation_checkpointing_target
    }
    model_cfg = om.create(model_cfg)

    fsdp_config = {
        'activation_checkpointing': activation_checkpointing,
        'activation_checkpointing_reentrant': False,
        'activation_cpu_offload': False,
    }

    model = ComposerMPTCausalLM(model_cfg)
    model = device.module_to_device(model)

    trainer = Trainer(
        model=model,
        device='gpu',
        fsdp_config=fsdp_config,
    )

    assert trainer.state.fsdp_enabled
    if not activation_checkpointing:
        assert not isinstance(
            trainer.state.model.model._fsdp_wrapped_module.transformer.
            blocks[0], CheckpointWrapper)
    elif (not activation_checkpointing_target
         ) or activation_checkpointing_target == [
             'mptblock', 'grouped_query_attention'
         ]:
        assert isinstance(
            trainer.state.model.model._fsdp_wrapped_module.transformer.
            blocks[0]._fsdp_wrapped_module, CheckpointWrapper)
    elif activation_checkpointing_target == ['grouped_query_attention']:
        assert isinstance(
            trainer.state.model.model._fsdp_wrapped_module.transformer.
            blocks[0]._fsdp_wrapped_module.attn, CheckpointWrapper)
    else:
        raise ValueError(
            f'Unknown activation_checkpointing_target: {activation_checkpointing_target}'
        )

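Note (not part of the commit): the selection logic itself can also be inspected without FSDP or a GPU. The sketch below is based on a few assumptions beyond this diff: that MPTConfig and MPTForCausalLM are importable from llmfoundry.models.mpt, that setting attn_impl='torch' keeps construction CPU-only, and that assigning activation_checkpointing_target directly onto the config object is enough for the new method's getattr lookup to see it.

from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM

# Small, CPU-friendly config; attn_impl='torch' avoids needing flash/triton
# kernels just to see which modules would be marked for checkpointing.
config = MPTConfig(
    d_model=128,
    n_heads=4,
    n_layers=2,
    expansion_ratio=1,
    max_seq_len=16,
    vocab_size=50368,
    attn_config={
        'attn_type': 'grouped_query_attention',
        'kv_n_heads': 2,
        'attn_impl': 'torch',
    },
)
# The new method reads this attribute off the config via getattr.
config.activation_checkpointing_target = ['grouped_query_attention']

model = MPTForCausalLM(config)
targeted = [
    name for name, module in model.named_modules()
    if model.activation_checkpointing_fn(module)
]
print(targeted)  # expect one '.attn' entry per transformer block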