Add fully configurable activation checkpointing (#951)
* add fully configurable activation checkpointing

* fix format

* fix format

* add docstring to activation_checkpointing_fn

* add block id range option in act ckpt

* resolve conflict

* add a check for blocks ids overlap in mapping

* fix typo

* update docstring

* refactor

* fix test

* Apply suggestions from code review

Co-authored-by: Mihir Patel <[email protected]>

* address comments

* add build mapping as a helper func

* fix format

---------

Co-authored-by: Mihir Patel <[email protected]>
cli99 and mvpatel2000 authored Feb 8, 2024
1 parent 60cdd0b commit 2e59620
Showing 3 changed files with 230 additions and 48 deletions.
102 changes: 62 additions & 40 deletions llmfoundry/models/mpt/modeling_mpt.py
@@ -28,6 +28,7 @@

from llmfoundry.models.layers.attention import (is_flash_v1_installed,
is_flash_v2_installed)
from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY

if is_flash_v2_installed():
try: # This try...except is needed because transformers requires it despite the 'if' statement above
@@ -55,17 +56,11 @@
from transformers.models.llama.modeling_llama import \
LlamaRotaryEmbedding as HFRotaryEmbedding

from llmfoundry.models.layers.attention import (ATTN_CLASS_REGISTRY,
attn_bias_shape,
from llmfoundry.models.layers.attention import (attn_bias_shape,
build_attn_bias, gen_slopes)
from llmfoundry.models.layers.blocks import MPTBlock
from llmfoundry.models.layers.custom_embedding import SharedEmbedding
from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY as FC_CLASS_REGISTRY
from llmfoundry.models.layers.ffn import \
FFN_CLASS_REGISTRY as FFN_CLASS_REGISTRY
from llmfoundry.models.layers.ffn import MPTMLP as MPTMLP
from llmfoundry.models.layers.ffn import build_ffn as build_ffn
from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY
from llmfoundry.models.mpt.configuration_mpt import MPTConfig

# NOTE: All utils are imported directly even if unused so that
@@ -87,6 +82,10 @@
MODEL_INIT_REGISTRY,
)

from llmfoundry.models.utils.act_ckpt import (pass_on_block_idx,
build_act_ckpt_mod_to_blocks,
check_mapping_blocks_overlap)

try:
from llmfoundry.models.layers.flash_attn_triton import flash_attn_func as flash_attn_func
except:
@@ -352,6 +351,13 @@ def __init__(self, config: MPTConfig):
**config.to_dict(),
) for _ in range(config.n_layers)
])

# Tag all modules in the transformer blocks with the corresponding block_idx and max_block_idx
for i, block in enumerate(self.blocks):
block.block_idx = i
block.max_block_idx = config.n_layers - 1
pass_on_block_idx(block)

self.norm_f = norm_class(config.d_model, device=config.init_device)

self.rope = config.attn_config['rope']
@@ -908,41 +914,57 @@ def fsdp_wrap_fn(self, module: nn.Module) -> bool:

# Activation Checkpointing
def activation_checkpointing_fn(self, module: nn.Module) -> bool:
act_ckpt_list = getattr(self.config, 'activation_checkpointing_target',
None) or ['MPTBlock']
if isinstance(act_ckpt_list, str):
act_ckpt_list = [act_ckpt_list]
elif not isinstance(act_ckpt_list, list):
raise ValueError(
f'activation_checkpointing_target must be either a single string or a list, but got {type(act_ckpt_list)}'
"""The MPT activation checkpointing (act ckpt) function.
When `activation_checkpointing` in fsdp_config is set to true, this function is called on every module in the FSDP-wrapped model to determine whether that module should be activation checkpointed. It checks the checkpointing target (`activation_checkpointing_target` in `model`), which can be specified as follows:
1. null (or no such field): The whole MPTBlock will be activation checkpointed on all layers
2. a list of modules to act ckpt on all layers, e.g.,
activation_checkpointing_target:
- grouped_query_attention
- mptmlp
3. a dictionary of module name with target_blocks, e.g.,
activation_checkpointing_target:
{
"mptblock": target_blocks_1,
"grouped_query_attention": target_blocks_2
}
target_blocks (target_blocks_1, target_blocks_2 above) can be:
- a single integer n: the first n transformer blocks will be activation checkpointed
- a string of first-n, middle-m, last-k, range-i-j: the first n, the middle m, the last k, or the range [i, j) layers will be activation checkpointed. E.g., 'first-2, last-2' means the first 2 and last 2 transformer blocks will be activation checkpointed
middle-m is range [start, end) where ``start = max(max_block_idx // 2 - m // 2, 0), end = min(start + m, max_block_idx + 1)``
- a list of integers corresponding to transformer block ids, e.g., [2] means the block with id 2 will be activation checkpointed, and [2, 3] means blocks 2 and 3 will be activation checkpointed
- a list of mixed integers and strings of first-n, middle-m, last-k, range-i-j
An example in yaml config file:
fsdp_config:
activation_checkpointing: true
model:
activation_checkpointing_target:
{
"mptblock": 'first-5',
"grouped_query_attention": 'last-35'
}
"""
if not hasattr(module, 'block_idx'):
log.debug(
f'{module.__class__.__name__} cannot be activation checkpointed. Only transformer blocks and their submodules are eligible for activation checkpointing.'
)
return False

if 'MPTBlock' in act_ckpt_list or 'mptblock' in act_ckpt_list:
if len(act_ckpt_list) > 1:
log.info(
'Activation checkpointing MPTBlock only (ignoring other sub-block modules specified in activation_checkpointing_target).'
)
return isinstance(module, MPTBlock)

mod_types = ()
for mod_name in act_ckpt_list:
if mod_name.lower() == 'mptblock':
mod_types += (MPTBlock,)
elif mod_name in ATTN_CLASS_REGISTRY:
mod_types += (ATTN_CLASS_REGISTRY[mod_name],)
elif mod_name in FFN_CLASS_REGISTRY:
mod_types += (FFN_CLASS_REGISTRY[mod_name],)
elif mod_name in NORM_CLASS_REGISTRY:
mod_types += (NORM_CLASS_REGISTRY[mod_name],)
else:
msg = ', '.join(
list(ATTN_CLASS_REGISTRY.keys()) +
list(FFN_CLASS_REGISTRY.keys()) +
list(NORM_CLASS_REGISTRY.keys()) + ['MPTBlock'])
raise ValueError(
f'{mod_name} (specified in activation_checkpointing_target) is not a recognized option out of available options {msg}.'
)
return isinstance(module, mod_types)
act_ckpt_target = getattr(self.config,
'activation_checkpointing_target', None)
act_ckpt_mod_to_blocks = build_act_ckpt_mod_to_blocks(
act_ckpt_target, MPTBlock, module.max_block_idx)

check_mapping_blocks_overlap(act_ckpt_mod_to_blocks,
module.max_block_idx)

for k in act_ckpt_mod_to_blocks.keys():
if isinstance(module, k):
blocks = act_ckpt_mod_to_blocks[k]
return True if blocks == -1 else module.block_idx in blocks

return False

def prepare_inputs_for_generation(
self,
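
To make the new control flow concrete, here is a minimal standalone sketch of the decision activation_checkpointing_fn now makes for each module: resolve the configured target into a mapping from module type to block ids, then consult the block_idx tag set in __init__. Block, Attention, and should_act_ckpt below are hypothetical stand-ins for illustration, not the actual llm-foundry classes.

class Block:
    pass

class Attention:
    pass

def should_act_ckpt(module, mod_to_blocks):
    """Return True if `module` (tagged with a block_idx) should be checkpointed."""
    if not hasattr(module, 'block_idx'):
        return False  # only tagged transformer-block (sub)modules are eligible
    for mod_type, blocks in mod_to_blocks.items():
        if isinstance(module, mod_type):
            # -1 means "all layers"; otherwise check the configured block ids
            return True if blocks == -1 else module.block_idx in blocks
    return False

# Example: checkpoint whole blocks 0-1, plus attention in the last two of 32 layers.
mapping = {Block: [0, 1], Attention: [30, 31]}
attn = Attention()
attn.block_idx = 31
print(should_act_ckpt(attn, mapping))  # True
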
147 changes: 147 additions & 0 deletions llmfoundry/models/utils/act_ckpt.py
@@ -0,0 +1,147 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

from typing import Any

import torch

from llmfoundry.models.layers.attention import ATTN_CLASS_REGISTRY
from llmfoundry.models.layers.blocks import MPTBlock
from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY
from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY


def pass_on_block_idx(parent: torch.nn.Module):
if not hasattr(parent, 'block_idx') or not hasattr(parent, 'max_block_idx'):
return
for child in parent.children():
child.block_idx = parent.block_idx
child.max_block_idx = parent.max_block_idx
if list(child.children()):  # children() is a generator, so materialize it to test for grandchildren
pass_on_block_idx(child)


def get_act_ckpt_module(mod_name: str) -> Any:
"""Get the module type from the module name."""
if mod_name.lower() == 'mptblock':
mod_type = MPTBlock
elif mod_name in ATTN_CLASS_REGISTRY:
mod_type = ATTN_CLASS_REGISTRY[mod_name]
elif mod_name in FFN_CLASS_REGISTRY:
mod_type = FFN_CLASS_REGISTRY[mod_name]
elif mod_name in NORM_CLASS_REGISTRY:
mod_type = NORM_CLASS_REGISTRY[mod_name]
else:
msg = ', '.join(
list(ATTN_CLASS_REGISTRY.keys()) + list(FFN_CLASS_REGISTRY.keys()) +
list(NORM_CLASS_REGISTRY.keys()) + ['MPTBlock'])
raise ValueError(
f'{mod_name} (specified in activation_checkpointing_target) is not a recognized option out of available options {msg}.'
)
return mod_type


def parse_ele_str(ele: str, max_block_idx: int) -> list:
"""Parse a string in target_blocks and return a list of block ids to add.
Supported formats are: first-n, middle-m, last-k, range-i-j which correspond
to the first n, the middle m, the last k, and the range [i, j).
"""
to_add = None
if ele.startswith('first-'):
assert ele[6:].isdigit(), f'Invalid target_blocks element {ele}'
to_add = list(range(min(int(ele[6:]), max_block_idx + 1)))
elif ele.startswith('last-'):
assert ele[5:].isdigit(), f'Invalid target_blocks element {ele}'
to_add = list(
range(max(max_block_idx - int(ele[5:]) + 1, 0), max_block_idx + 1))
elif ele.startswith('middle-'):
assert ele[7:].isdigit(), f'Invalid target_blocks element {ele}'
num = int(ele[7:])
start = max(max_block_idx // 2 - num // 2, 0)
end = min(start + num, max_block_idx + 1)
to_add = list(range(start, end))
elif ele.startswith('range-'):
r = ele[6:].split('-')
assert len(r) == 2, f'Invalid target_blocks element {ele}'
start, end = int(r[0]), int(r[1])
start = max(start, 0)
end = min(end, max_block_idx + 1)
to_add = list(range(start, end))
else:
raise ValueError(f'Invalid target_blocks element {ele}')
return to_add


def get_target_block_list(target_blocks: Any, max_block_idx: int) -> list:
"""Parse the user input and return a list of block ids."""
candidate_block_ids = []
if isinstance(target_blocks, int):
candidate_block_ids = list(range(target_blocks))
elif isinstance(target_blocks, list):
for ele in target_blocks:
if isinstance(ele, int):
candidate_block_ids.append(ele)
elif isinstance(ele, str):
to_add = parse_ele_str(ele, max_block_idx)
candidate_block_ids.extend(to_add)
else:
raise ValueError(
f'target_blocks must be a list of integers or "first-n", "middle-m", "last-k", or "range-i-j" where n, m, k, i, j are integers, but got {target_blocks}'
)
elif isinstance(target_blocks, str):
target_blocks = target_blocks.replace(' ', '')
for ele in target_blocks.split(','):
to_add = parse_ele_str(ele, max_block_idx)
candidate_block_ids.extend(to_add)
else:
raise ValueError(
f'target_blocks must be either a single integer, a list of integers, a comma-separated string made of "first-n", "middle-m", "last-k", and "range-i-j" elements, or a list of mixed integers and such strings, but got {type(target_blocks)}'
)

candidate_block_ids = list(set(candidate_block_ids))
return candidate_block_ids


def check_mapping_blocks_overlap(mapping: dict, max_block_idx: int) -> None:
"""Check if the block ids in the mapping overlap with each other."""
all_blocks = [None] * (max_block_idx + 1)
for k, v in mapping.items():
if v == -1:
v = list(range(max_block_idx + 1))
for vv in v:
if vv < 0 or vv > max_block_idx:
continue
else:
if all_blocks[vv] is not None:
raise ValueError(
f'Block {vv} is assigned to both {k} and {all_blocks[vv]}.'
)
else:
all_blocks[vv] = k


def build_act_ckpt_mod_to_blocks(act_ckpt_target: Any, top_module: Any,
max_block_idx: int) -> dict:
act_ckpt_mod_to_blocks = {}
if act_ckpt_target is None or act_ckpt_target == []:
mod = top_module
act_ckpt_mod_to_blocks[mod] = -1
elif isinstance(act_ckpt_target, str):
mod = get_act_ckpt_module(act_ckpt_target)
act_ckpt_mod_to_blocks[mod] = -1
elif isinstance(act_ckpt_target, list):
for target in act_ckpt_target:
mod = get_act_ckpt_module(target)
act_ckpt_mod_to_blocks[mod] = -1
elif isinstance(act_ckpt_target, dict):
for k, v in act_ckpt_target.items():
mod = get_act_ckpt_module(k)
block_ids = get_target_block_list(v, max_block_idx)
act_ckpt_mod_to_blocks[mod] = block_ids
else:
raise ValueError(
f'activation_checkpointing_target must be either a single string or a list or a dict, but got {type(act_ckpt_target)}'
)

return act_ckpt_mod_to_blocks
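
A brief usage sketch of the helpers above, assuming llm-foundry at this commit is importable; the 32-layer model size is an arbitrary assumption.

from llmfoundry.models.layers.blocks import MPTBlock
from llmfoundry.models.utils.act_ckpt import (build_act_ckpt_mod_to_blocks,
                                              check_mapping_blocks_overlap,
                                              get_target_block_list)

max_block_idx = 31  # hypothetical 32-layer model

# String targets expand to concrete block ids (de-duplicated; order not guaranteed).
print(sorted(get_target_block_list('first-2, last-2', max_block_idx)))  # [0, 1, 30, 31]
print(sorted(get_target_block_list('middle-4', max_block_idx)))         # [13, 14, 15, 16]

# A dict target maps each module name to its own block list; overlapping
# assignments raise a ValueError in check_mapping_blocks_overlap.
mapping = build_act_ckpt_mod_to_blocks(
    {'mptblock': 'first-2', 'grouped_query_attention': 'last-2'},
    MPTBlock, max_block_idx)
check_mapping_blocks_overlap(mapping, max_block_idx)  # no overlap, passes
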
29 changes: 21 additions & 8 deletions tests/models/test_fsdp_act_checkpoint.py
@@ -17,17 +17,20 @@
@pytest.mark.gpu
@pytest.mark.parametrize('activation_checkpointing', [True, False])
@pytest.mark.parametrize('activation_checkpointing_target', [
'grouped_query_attention', [], ['grouped_query_attention'],
['mptblock', 'grouped_query_attention']
'grouped_query_attention', [], ['grouped_query_attention'], {
'mptblock': [1],
'grouped_query_attention': 'first-1, last-1'
}
])
def test_fsdp_act_checkpoint(activation_checkpointing: bool,
activation_checkpointing_target: Union[list, str]):
activation_checkpointing_target: Union[list, str,
dict]):
device = get_device('gpu')
model_cfg = {
'name': 'mpt_causal_lm',
'd_model': 128,
'n_heads': 4,
'n_layers': 2,
'n_layers': 3,
'expansion_ratio': 1,
'max_seq_len': 16,
'vocab_size': 50368,
@@ -59,10 +62,7 @@ def test_fsdp_act_checkpoint(activation_checkpointing: bool,
assert not isinstance(
trainer.state.model.model._fsdp_wrapped_module.transformer.
blocks[0], CheckpointWrapper)
elif (not activation_checkpointing_target
) or activation_checkpointing_target == [
'mptblock', 'grouped_query_attention'
]:
elif (not activation_checkpointing_target):
module = trainer.state.model.model._fsdp_wrapped_module.transformer.blocks[
0]._fsdp_wrapped_module
assert isinstance(module, CheckpointWrapper)
@@ -72,6 +72,19 @@
assert isinstance(
trainer.state.model.model._fsdp_wrapped_module.transformer.
blocks[0]._fsdp_wrapped_module.attn, CheckpointWrapper)
elif activation_checkpointing_target == {
'mptblock': [1],
'grouped_query_attention': 'first-1, last-1'
}:
assert isinstance(
trainer.state.model.model._fsdp_wrapped_module.transformer.
blocks[0]._fsdp_wrapped_module.attn, CheckpointWrapper)
assert isinstance(
trainer.state.model.model._fsdp_wrapped_module.transformer.
blocks[1]._fsdp_wrapped_module, CheckpointWrapper)
assert isinstance(
trainer.state.model.model._fsdp_wrapped_module.transformer.
blocks[2]._fsdp_wrapped_module.attn, CheckpointWrapper)
else:
raise ValueError(
f'Unknown activation_checkpointing_target: {activation_checkpointing_target}'
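
For reference, the dict parametrization above resolves as follows for the 3-layer test model (a sketch, assuming llm-foundry at this commit is importable):

from llmfoundry.models.layers.blocks import MPTBlock
from llmfoundry.models.utils.act_ckpt import build_act_ckpt_mod_to_blocks

mapping = build_act_ckpt_mod_to_blocks(
    {'mptblock': [1], 'grouped_query_attention': 'first-1, last-1'},
    MPTBlock, 2)  # max_block_idx = n_layers - 1 = 2
# -> {MPTBlock: [1], <grouped-query attention class>: [0, 2]}: block 1 is
#    checkpointed as a whole, while only the attention submodule is
#    checkpointed in blocks 0 and 2, matching the assertions above.
print(mapping)
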
