Add fully configurable activation checkpointing #951

Merged · 19 commits · Feb 8, 2024
102 changes: 62 additions & 40 deletions llmfoundry/models/mpt/modeling_mpt.py
@@ -28,6 +28,7 @@

from llmfoundry.models.layers.attention import (is_flash_v1_installed,
is_flash_v2_installed)
from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY

if is_flash_v2_installed():
try: # This try...except is needed because transformers requires it despite the 'if' statement above
@@ -55,17 +56,11 @@
from transformers.models.llama.modeling_llama import \
LlamaRotaryEmbedding as HFRotaryEmbedding

from llmfoundry.models.layers.attention import (ATTN_CLASS_REGISTRY,
attn_bias_shape,
from llmfoundry.models.layers.attention import (attn_bias_shape,
build_attn_bias, gen_slopes)
from llmfoundry.models.layers.blocks import MPTBlock
from llmfoundry.models.layers.custom_embedding import SharedEmbedding
from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY as FC_CLASS_REGISTRY
from llmfoundry.models.layers.ffn import \
FFN_CLASS_REGISTRY as FFN_CLASS_REGISTRY
from llmfoundry.models.layers.ffn import MPTMLP as MPTMLP
from llmfoundry.models.layers.ffn import build_ffn as build_ffn
from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY
from llmfoundry.models.mpt.configuration_mpt import MPTConfig

# NOTE: All utils are imported directly even if unused so that
@@ -87,6 +82,10 @@
MODEL_INIT_REGISTRY,
)

from llmfoundry.models.utils.act_ckpt import (pass_on_block_idx,
build_act_ckpt_mod_to_blocks,
check_mapping_blocks_overlap)

try:
from llmfoundry.models.layers.flash_attn_triton import flash_attn_func as flash_attn_func
except:
@@ -352,6 +351,13 @@ def __init__(self, config: MPTConfig):
**config.to_dict(),
) for _ in range(config.n_layers)
])

# Tag all modules in the transformer blocks with the corresponding block_idx and max_block_idx
for i, block in enumerate(self.blocks):
block.block_idx = i
block.max_block_idx = config.n_layers - 1
pass_on_block_idx(block)

self.norm_f = norm_class(config.d_model, device=config.init_device)

self.rope = config.attn_config['rope']
@@ -908,41 +914,57 @@ def fsdp_wrap_fn(self, module: nn.Module) -> bool:

# Activation Checkpointing
def activation_checkpointing_fn(self, module: nn.Module) -> bool:
act_ckpt_list = getattr(self.config, 'activation_checkpointing_target',
None) or ['MPTBlock']
if isinstance(act_ckpt_list, str):
act_ckpt_list = [act_ckpt_list]
elif not isinstance(act_ckpt_list, list):
raise ValueError(
f'activation_checkpointing_target must be either a single string or a list, but got {type(act_ckpt_list)}'
"""The MPT activation checkpointing (act ckpt) function.

When `activation_checkpointing` in fsdp_config is set to true, this function is called on every module in the FSDP-wrapped model to determine whether that module should be activation checkpointed. It checks the checkpointing target (`activation_checkpointing_target` in `model`), which can be specified as follows:
1. null (or no such field): The whole MPTBlock will be activation checkpointed on all layers
2. a list of modules to act ckpt on all layers, e.g.,
activation_checkpointing_target:
- grouped_query_attention
- mptmlp
3. a dictionary mapping module names to target_blocks, e.g.,
activation_checkpointing_target:
{
"mptblock": target_blocks_1,
"grouped_query_attention": target_blocks_2
}
target_blocks (target_blocks_1, target_blocks_2 above) can be:
- a single integer n: the first n transformer blocks will be activation checkpointed
- a string of first-n, middle-m, last-k, range-i-j: the first n, the middle m, the last k, or the range [i, j) layers will be activation checkpointed. E.g., 'first-2, last-2' means the first 2 and last 2 transformer blocks will be activation checkpointed
middle-m is the range [start, end) where ``start = max(max_block_idx // 2 - m // 2, 0), end = min(start + m, max_block_idx + 1)``
- a list of integers corresponding to transformer block ids, e.g., [2] means the transformer block with id 2 will be activation checkpointed, and [2, 3] means blocks 2 and 3 will be activation checkpointed
- a list of mixed integers and strings of first-n, middle-m, last-k, range-i-j

An example in a YAML config file:
fsdp_config:
activation_checkpointing: true
model:
activation_checkpointing_target:
{
"mptblock": 'first-5',
"grouped_query_attention": 'last-35'
}
"""
if not hasattr(module, 'block_idx'):
log.debug(
f'{module.__class__.__name__} cannot be activation checkpointed. Only transformer blocks or their submodules are eligible for activation checkpointing.'
)
return False

if 'MPTBlock' in act_ckpt_list or 'mptblock' in act_ckpt_list:
if len(act_ckpt_list) > 1:
log.info(
'Activation checkpointing MPTBlock only (ignoring other sub-block modules specified in activation_checkpointing_target).'
)
return isinstance(module, MPTBlock)

mod_types = ()
for mod_name in act_ckpt_list:
if mod_name.lower() == 'mptblock':
mod_types += (MPTBlock,)
elif mod_name in ATTN_CLASS_REGISTRY:
mod_types += (ATTN_CLASS_REGISTRY[mod_name],)
elif mod_name in FFN_CLASS_REGISTRY:
mod_types += (FFN_CLASS_REGISTRY[mod_name],)
elif mod_name in NORM_CLASS_REGISTRY:
mod_types += (NORM_CLASS_REGISTRY[mod_name],)
else:
msg = ', '.join(
list(ATTN_CLASS_REGISTRY.keys()) +
list(FFN_CLASS_REGISTRY.keys()) +
list(NORM_CLASS_REGISTRY.keys()) + ['MPTBlock'])
raise ValueError(
f'{mod_name} (specified in activation_checkpointing_target) is not a recognized option out of available options {msg}.'
)
return isinstance(module, mod_types)
act_ckpt_target = getattr(self.config,
'activation_checkpointing_target', None)
act_ckpt_mod_to_blocks = build_act_ckpt_mod_to_blocks(
act_ckpt_target, MPTBlock, module.max_block_idx)

check_mapping_blocks_overlap(act_ckpt_mod_to_blocks,
module.max_block_idx)

for k in act_ckpt_mod_to_blocks.keys():
if isinstance(module, k):
blocks = act_ckpt_mod_to_blocks[k]
return True if blocks == -1 else module.block_idx in blocks

return False

def prepare_inputs_for_generation(
self,
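To illustrate the block tagging added in `MPTModel.__init__` above: each transformer block (and, via `pass_on_block_idx` from the new `act_ckpt` module below, every one of its submodules) is stamped with `block_idx` and `max_block_idx`, which is what lets `activation_checkpointing_fn` match a (module type, block id) pair against the configured target. A minimal sketch of that propagation, using a plain `nn.Sequential` as a stand-in for an MPTBlock (assumes llm-foundry with this change is installed):

import torch

from llmfoundry.models.utils.act_ckpt import pass_on_block_idx

# A plain Sequential stands in for an MPTBlock here; in the real model these
# attributes are set on each block in MPTModel.__init__.
block = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.LayerNorm(8))
block.block_idx, block.max_block_idx = 1, 3
pass_on_block_idx(block)

# Every submodule now carries its block's index and the model's max block index,
# so activation_checkpointing_fn can decide per (module type, block id) pair.
assert all(m.block_idx == 1 and m.max_block_idx == 3 for m in block.children())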
147 changes: 147 additions & 0 deletions llmfoundry/models/utils/act_ckpt.py
@@ -0,0 +1,147 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

from typing import Any

import torch

from llmfoundry.models.layers.attention import ATTN_CLASS_REGISTRY
from llmfoundry.models.layers.blocks import MPTBlock
from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY
from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY


def pass_on_block_idx(parent: torch.nn.Module):
if not hasattr(parent, 'block_idx') or not hasattr(parent, 'max_block_idx'):
return
for child in parent.children():
child.block_idx = parent.block_idx
child.max_block_idx = parent.max_block_idx
if child.children():
pass_on_block_idx(child)


def get_act_ckpt_module(mod_name: str) -> Any:
"""Get the module type from the module name."""
if mod_name.lower() == 'mptblock':
mod_type = MPTBlock
elif mod_name in ATTN_CLASS_REGISTRY:
mod_type = ATTN_CLASS_REGISTRY[mod_name]
elif mod_name in FFN_CLASS_REGISTRY:
mod_type = FFN_CLASS_REGISTRY[mod_name]
elif mod_name in NORM_CLASS_REGISTRY:
mod_type = NORM_CLASS_REGISTRY[mod_name]
else:
msg = ', '.join(
list(ATTN_CLASS_REGISTRY.keys()) + list(FFN_CLASS_REGISTRY.keys()) +
list(NORM_CLASS_REGISTRY.keys()) + ['MPTBlock'])
raise ValueError(
f'{mod_name} (specified in activation_checkpointing_target) is not a recognized option out of available options {msg}.'
)
return mod_type


def parse_ele_str(ele: str, max_block_idx: int) -> list:
"""Parse a string in target_blocks and return a list of block ids to add.

Supported formats are: first-n, middle-m, last-k, range-i-j which correspond
to the first n, the middle m, the last k, and the range [i, j).
"""
to_add = None
if ele.startswith('first-'):
assert ele[6:].isdigit(), f'Invalid target_blocks element {ele}'
to_add = list(range(min(int(ele[6:]), max_block_idx + 1)))
elif ele.startswith('last-'):
assert ele[5:].isdigit(), f'Invalid target_blocks element {ele}'
to_add = list(
range(max(max_block_idx - int(ele[5:]) + 1, 0), max_block_idx + 1))
elif ele.startswith('middle-'):
assert ele[7:].isdigit(), f'Invalid target_blocks element {ele}'
num = int(ele[7:])
start = max(max_block_idx // 2 - num // 2, 0)
end = min(start + num, max_block_idx + 1)
to_add = list(range(start, end))
elif ele.startswith('range-'):
r = ele[6:].split('-')
assert len(r) == 2, f'Invalid target_blocks element {ele}'
start, end = int(r[0]), int(r[1])
start = max(start, 0)
end = min(end, max_block_idx + 1)
to_add = list(range(start, end))
else:
raise ValueError(f'Invalid target_blocks element {ele}')
return to_add


def get_target_block_list(target_blocks: Any, max_block_idx: int) -> list:
"""Parse the user input and return a list of block ids."""
candidate_block_ids = []
if isinstance(target_blocks, int):
candidate_block_ids = list(range(target_blocks))
elif isinstance(target_blocks, list):
for ele in target_blocks:
if isinstance(ele, int):
candidate_block_ids.append(ele)
elif isinstance(ele, str):
to_add = parse_ele_str(ele, max_block_idx)
candidate_block_ids.extend(to_add)
else:
raise ValueError(
f'target_blocks must be a list of integers or "first-n", "middle-m", "last-k", or "range-i-j" where n, m, k, i, j are integers, but got {target_blocks}'
)
elif isinstance(target_blocks, str):
target_blocks = target_blocks.replace(' ', '')
for ele in target_blocks.split(','):
to_add = parse_ele_str(ele, max_block_idx)
candidate_block_ids.extend(to_add)
else:
raise ValueError(
f'target_blocks must be either a single integer, a list of integers, or a comma-separated string made of "first-n", "middle-m", "last-k", "range-i-j", or a list of mixed integers and such strings, but got {type(target_blocks)}'
)

candidate_block_ids = list(set(candidate_block_ids))
return candidate_block_ids


def check_mapping_blocks_overlap(mapping: dict, max_block_idx: int) -> None:
"""Check if the block ids in the mapping overlap with each other."""
all_blocks = [None] * (max_block_idx + 1)
for k, v in mapping.items():
if v == -1:
v = list(range(max_block_idx + 1))
for vv in v:
if vv < 0 or vv > max_block_idx:
continue
else:
if all_blocks[vv] is not None:
raise ValueError(
f'Block {vv} is assigned to both {k} and {all_blocks[vv]}.'
)
else:
all_blocks[vv] = k


def build_act_ckpt_mod_to_blocks(act_ckpt_target: Any, top_module: Any,
max_block_idx: int) -> dict:
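"""Map module types to the block ids they are checkpointed in (-1 means all blocks)."""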
act_ckpt_mod_to_blocks = {}
if act_ckpt_target is None or act_ckpt_target == []:
mod = top_module
act_ckpt_mod_to_blocks[mod] = -1
elif isinstance(act_ckpt_target, str):
mod = get_act_ckpt_module(act_ckpt_target)
act_ckpt_mod_to_blocks[mod] = -1
elif isinstance(act_ckpt_target, list):
for target in act_ckpt_target:
mod = get_act_ckpt_module(target)
act_ckpt_mod_to_blocks[mod] = -1
elif isinstance(act_ckpt_target, dict):
for k, v in act_ckpt_target.items():
mod = get_act_ckpt_module(k)
block_ids = get_target_block_list(v, max_block_idx)
act_ckpt_mod_to_blocks[mod] = block_ids
else:
raise ValueError(
f'activation_checkpointing_target must be either a single string or a list or a dict, but got {type(act_ckpt_target)}'
)

return act_ckpt_mod_to_blocks
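As a quick sanity check of the parsing rules above, here is a hedged usage sketch of the new helpers (assumes llm-foundry with this PR is importable; the string keys passed to `check_mapping_blocks_overlap` stand in for the module classes that `modeling_mpt.py` actually uses):

from llmfoundry.models.utils.act_ckpt import (check_mapping_blocks_overlap,
                                              get_target_block_list,
                                              parse_ele_str)

max_block_idx = 7  # an 8-layer model, blocks 0..7

print(parse_ele_str('first-2', max_block_idx))    # [0, 1]
print(parse_ele_str('middle-2', max_block_idx))   # [2, 3] (start = max(7 // 2 - 1, 0) = 2)
print(parse_ele_str('last-2', max_block_idx))     # [6, 7]
print(parse_ele_str('range-2-5', max_block_idx))  # [2, 3, 4] (range is [i, j))

print(sorted(get_target_block_list('first-2, last-2', max_block_idx)))  # [0, 1, 6, 7]

# Overlapping assignments are rejected: block 1 is claimed by both entries.
try:
    check_mapping_blocks_overlap({'a': [0, 1], 'b': [1, 2]}, max_block_idx)
except ValueError as err:
    print(err)  # Block 1 is assigned to both b and a.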
29 changes: 21 additions & 8 deletions tests/models/test_fsdp_act_checkpoint.py
@@ -17,17 +17,20 @@
@pytest.mark.gpu
@pytest.mark.parametrize('activation_checkpointing', [True, False])
@pytest.mark.parametrize('activation_checkpointing_target', [
'grouped_query_attention', [], ['grouped_query_attention'],
['mptblock', 'grouped_query_attention']
'grouped_query_attention', [], ['grouped_query_attention'], {
'mptblock': [1],
'grouped_query_attention': 'first-1, last-1'
}
])
def test_fsdp_act_checkpoint(activation_checkpointing: bool,
activation_checkpointing_target: Union[list, str]):
activation_checkpointing_target: Union[list, str,
dict]):
device = get_device('gpu')
model_cfg = {
'name': 'mpt_causal_lm',
'd_model': 128,
'n_heads': 4,
'n_layers': 2,
'n_layers': 3,
'expansion_ratio': 1,
'max_seq_len': 16,
'vocab_size': 50368,
@@ -59,10 +62,7 @@ def test_fsdp_act_checkpoint(activation_checkpointing: bool,
assert not isinstance(
trainer.state.model.model._fsdp_wrapped_module.transformer.
blocks[0], CheckpointWrapper)
elif (not activation_checkpointing_target
) or activation_checkpointing_target == [
'mptblock', 'grouped_query_attention'
]:
elif (not activation_checkpointing_target):
module = trainer.state.model.model._fsdp_wrapped_module.transformer.blocks[
0]._fsdp_wrapped_module
assert isinstance(module, CheckpointWrapper)
@@ -72,6 +72,19 @@
assert isinstance(
trainer.state.model.model._fsdp_wrapped_module.transformer.
blocks[0]._fsdp_wrapped_module.attn, CheckpointWrapper)
elif activation_checkpointing_target == {
'mptblock': [1],
'grouped_query_attention': 'first-1, last-1'
}:
assert isinstance(
trainer.state.model.model._fsdp_wrapped_module.transformer.
blocks[0]._fsdp_wrapped_module.attn, CheckpointWrapper)
assert isinstance(
trainer.state.model.model._fsdp_wrapped_module.transformer.
blocks[1]._fsdp_wrapped_module, CheckpointWrapper)
assert isinstance(
trainer.state.model.model._fsdp_wrapped_module.transformer.
blocks[2]._fsdp_wrapped_module.attn, CheckpointWrapper)
else:
raise ValueError(
f'Unknown activation_checkpointing_target: {activation_checkpointing_target}'
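Putting it together, here is a hedged sketch of the model/FSDP configuration that the new dict-form test case exercises (field names mirror the visible parts of the test; the remainder of the test's model_cfg is truncated in the diff above). Per the assertions, with n_layers: 3 this wraps all of block 1 and only the attention submodule of blocks 0 and 2:

# Sketch of the dict-form configuration from the test above (not a new API);
# additional model_cfg fields used by the test are elided in the truncated diff.
model_cfg = {
    'name': 'mpt_causal_lm',
    'd_model': 128,
    'n_heads': 4,
    'n_layers': 3,
    'expansion_ratio': 1,
    'max_seq_len': 16,
    'vocab_size': 50368,
    'activation_checkpointing_target': {
        'mptblock': [1],  # checkpoint all of block 1
        'grouped_query_attention': 'first-1, last-1',  # checkpoint only attn of blocks 0 and 2
    },
}
fsdp_config = {'activation_checkpointing': True}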