Add fully configurable activation checkpointing #951

Merged · 19 commits · Feb 8, 2024

Changes from 14 commits
132 changes: 93 additions & 39 deletions llmfoundry/models/mpt/modeling_mpt.py
@@ -28,6 +28,7 @@

from llmfoundry.models.layers.attention import (is_flash_v1_installed,
is_flash_v2_installed)
from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY

if is_flash_v2_installed():
try: # This try...except is needed because transformers requires it despite the 'if' statement above
@@ -55,17 +56,11 @@
from transformers.models.llama.modeling_llama import \
LlamaRotaryEmbedding as HFRotaryEmbedding

from llmfoundry.models.layers.attention import (ATTN_CLASS_REGISTRY,
attn_bias_shape,
from llmfoundry.models.layers.attention import (attn_bias_shape,
build_attn_bias, gen_slopes)
from llmfoundry.models.layers.blocks import MPTBlock
from llmfoundry.models.layers.custom_embedding import SharedEmbedding
from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY as FC_CLASS_REGISTRY
from llmfoundry.models.layers.ffn import \
FFN_CLASS_REGISTRY as FFN_CLASS_REGISTRY
from llmfoundry.models.layers.ffn import MPTMLP as MPTMLP
from llmfoundry.models.layers.ffn import build_ffn as build_ffn
from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY
from llmfoundry.models.mpt.configuration_mpt import MPTConfig

# NOTE: All utils are imported directly even if unused so that
@@ -87,6 +82,10 @@
MODEL_INIT_REGISTRY,
)

from llmfoundry.models.utils.act_ckpt import (get_act_ckpt_module,
get_target_block_list,
check_mapping_blocks_overlap)

try:
from llmfoundry.models.layers.flash_attn_triton import flash_attn_func as flash_attn_func
except:
@@ -352,6 +351,22 @@ def __init__(self, config: MPTConfig):
**config.to_dict(),
) for _ in range(config.n_layers)
])

def pass_on_block_idx(parent: nn.Module):
if not hasattr(parent, 'block_idx') or not hasattr(
parent, 'max_block_idx'):
return
for child in parent.children():
child.block_idx = parent.block_idx
child.max_block_idx = parent.max_block_idx
if child.children():
pass_on_block_idx(child)

for i, block in enumerate(self.blocks):
block.block_idx = i
block.max_block_idx = config.n_layers - 1
pass_on_block_idx(block)

self.norm_f = norm_class(config.d_model, device=config.init_device)

self.rope = config.attn_config['rope']
@@ -908,41 +923,80 @@ def fsdp_wrap_fn(self, module: nn.Module) -> bool:

# Activation Checkpointing
def activation_checkpointing_fn(self, module: nn.Module) -> bool:
act_ckpt_list = getattr(self.config, 'activation_checkpointing_target',
None) or ['MPTBlock']
if isinstance(act_ckpt_list, str):
act_ckpt_list = [act_ckpt_list]
elif not isinstance(act_ckpt_list, list):
raise ValueError(
f'activation_checkpointing_target must be either a single string or a list, but got {type(act_ckpt_list)}'
"""The MPT activation checkpointing (act ckpt) function.

When `activation_checkpointing` in fsdp_config is set to true, this function is called on all modules in the FSDP-wrapped model to determine whether a given module should be activation checkpointed. It checks the checkpointing target (`activation_checkpointing_target` in `model`), which can be specified as follows:
1. null (or no such field), then act ckpt the whole MPTBlock on all layers
2. a list of modules to act ckpt, e.g.,
activation_checkpointing_target:
- grouped_query_attention
- mptmlp
, then act ckpt all the modules in the list on all layers
3. a dictionary mapping module names to target_blocks, e.g.,
activation_checkpointing_target:
{
"mptblock": target_blocks_1,
"grouped_query_attention": target_blocks_2
}
target_blocks (target_blocks_1, target_blocks_2 above) can be:
- a single integer n means the first n transformer blocks will be candidates for act ckpt
- a string of first-n, middle-m, last-k, range-i-j means the first n, the middle m, the last k, or the range [i, j) layers are candidates for act ckpt. E.g., 'first-2, last-2' means the first 2 and last 2 transformer blocks are candidates for act ckpt.
middle-m is range [start, end) where ``start = max(max_block_idx // 2 - m // 2, 0), end = min(start + m, max_block_idx + 1)``
- a list of integers corresponds to transformer block ids, e.g., [2] means block id 2 is a candidate for act ckpt, and [2, 3] means block ids 2 and 3 are candidates.
- a list of mixed integers and strings of first-n, middle-m, last-k, range-i-j

An example in a yaml config file:
fsdp_config:
activation_checkpointing: true
model:
activation_checkpointing_target:
{
"mptblock": 'first-5',
"grouped_query_attention": 'last-35'
}
does full act ckpt on the first 5 layers and then act ckpts the grouped_query_attention on the last 35 layers
"""
if not hasattr(module, 'block_idx'):
log.debug(
f'No activation checkpointing for {module.__class__.__name__}, only transformer block or its submodules are eligible for activation checkpointing.'
)

if 'MPTBlock' in act_ckpt_list or 'mptblock' in act_ckpt_list:
if len(act_ckpt_list) > 1:
return False

act_ckpt_target = getattr(self.config,
'activation_checkpointing_target', None)
act_ckpt_mod_to_blocks = {}
if act_ckpt_target is None or act_ckpt_target == []:
mod = MPTBlock
act_ckpt_mod_to_blocks[mod] = -1
elif isinstance(act_ckpt_target, str):
mod = get_act_ckpt_module(act_ckpt_target)
act_ckpt_mod_to_blocks[mod] = -1
elif isinstance(act_ckpt_target, list):
for target in act_ckpt_target:
mod = get_act_ckpt_module(target)
act_ckpt_mod_to_blocks[mod] = -1
elif isinstance(act_ckpt_target, dict):
for k, v in act_ckpt_target.items():
mod = get_act_ckpt_module(k)
block_ids = get_target_block_list(v, module.max_block_idx)
act_ckpt_mod_to_blocks[mod] = block_ids
log.info(
'Activation checkpointing MPTBlock only (ignoring other sub-block modules specified in activation_checkpointing_target).'
f'for module {mod.__name__}, target_blocks is set as {v}, activation checkpointing is applied to {block_ids} blocks.'
)
return isinstance(module, MPTBlock)

mod_types = ()
for mod_name in act_ckpt_list:
if mod_name.lower() == 'mptblock':
mod_types += (MPTBlock,)
elif mod_name in ATTN_CLASS_REGISTRY:
mod_types += (ATTN_CLASS_REGISTRY[mod_name],)
elif mod_name in FFN_CLASS_REGISTRY:
mod_types += (FFN_CLASS_REGISTRY[mod_name],)
elif mod_name in NORM_CLASS_REGISTRY:
mod_types += (NORM_CLASS_REGISTRY[mod_name],)
else:
msg = ', '.join(
list(ATTN_CLASS_REGISTRY.keys()) +
list(FFN_CLASS_REGISTRY.keys()) +
list(NORM_CLASS_REGISTRY.keys()) + ['MPTBlock'])
raise ValueError(
f'{mod_name} (specified in activation_checkpointing_target) is not a recognized option out of available options {msg}.'
)
return isinstance(module, mod_types)
else:
raise ValueError(
f'activation_checkpointing_target must be either a single string or a list or a dict, but got {type(act_ckpt_target)}'
)

check_mapping_blocks_overlap(act_ckpt_mod_to_blocks,
module.max_block_idx)

for k in act_ckpt_mod_to_blocks.keys():
if isinstance(module, k):
blocks = act_ckpt_mod_to_blocks[k]
return True if blocks == -1 else module.block_idx in blocks

return False

def prepare_inputs_for_generation(
self,
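To make the new resolution path easier to follow outside the diff, here is a minimal, self-contained sketch of how `activation_checkpointing_target` is turned into a module-to-blocks mapping and then consulted per module. The `Block`/`Attention` classes and `MODULE_REGISTRY` are hypothetical stand-ins for `MPTBlock` and the real registries; in the PR the lookups and string parsing are done by `get_act_ckpt_module` and `get_target_block_list`.

```python
# Sketch of the target-resolution flow, with hypothetical stand-in classes.
class Block:      # stand-in for MPTBlock
    pass

class Attention:  # stand-in for a registered attention class
    pass

MODULE_REGISTRY = {'mptblock': Block, 'grouped_query_attention': Attention}


def resolve_targets(target) -> dict:
    """Map module types to block ids; -1 means 'all blocks'."""
    if not target:
        return {Block: -1}
    if isinstance(target, str):
        return {MODULE_REGISTRY[target]: -1}
    if isinstance(target, list):
        return {MODULE_REGISTRY[name]: -1 for name in target}
    if isinstance(target, dict):
        # In the PR, dict values first go through get_target_block_list,
        # so 'first-1, last-1' style strings become explicit block-id lists.
        return {MODULE_REGISTRY[name]: blocks for name, blocks in target.items()}
    raise ValueError(f'unsupported activation_checkpointing_target: {type(target)}')


def should_checkpoint(module, block_idx: int, mapping: dict) -> bool:
    """Mirror of the final isinstance / block-id check in activation_checkpointing_fn."""
    for mod_type, blocks in mapping.items():
        if isinstance(module, mod_type):
            return blocks == -1 or block_idx in blocks
    return False


mapping = resolve_targets({'mptblock': [0], 'grouped_query_attention': [2, 3]})
print(should_checkpoint(Block(), 0, mapping))      # True: block 0 gets full act ckpt
print(should_checkpoint(Attention(), 1, mapping))  # False: block 1 is not targeted
print(should_checkpoint(Attention(), 2, mapping))  # True: attention in block 2 is targeted
```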
109 changes: 109 additions & 0 deletions llmfoundry/models/utils/act_ckpt.py
@@ -0,0 +1,109 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

from typing import Any

from llmfoundry.models.layers.attention import ATTN_CLASS_REGISTRY
from llmfoundry.models.layers.blocks import MPTBlock
from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY
from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY


def get_act_ckpt_module(mod_name: str) -> Any:
"""Get the module type from the module name."""
if mod_name.lower() == 'mptblock':
mod_type = MPTBlock
elif mod_name in ATTN_CLASS_REGISTRY:
mod_type = ATTN_CLASS_REGISTRY[mod_name]
elif mod_name in FFN_CLASS_REGISTRY:
mod_type = FFN_CLASS_REGISTRY[mod_name]
elif mod_name in NORM_CLASS_REGISTRY:
mod_type = NORM_CLASS_REGISTRY[mod_name]
else:
msg = ', '.join(
list(ATTN_CLASS_REGISTRY.keys()) + list(FFN_CLASS_REGISTRY.keys()) +
list(NORM_CLASS_REGISTRY.keys()) + ['MPTBlock'])
raise ValueError(
f'{mod_name} (specified in activation_checkpointing_target) is not a recognized option out of available options {msg}.'
)
return mod_type


def parse_ele_str(ele: str, max_block_idx: int) -> list:
"""Parse a string in target_blocks and return a list of block ids to add.

Supported formats are: first-n, middle-m, last-k, range-i-j which correspond
to the first n, the middle m, the last k, and the range [i, j).
"""
to_add = None
if ele.startswith('first-'):
assert ele[6:].isdigit(), f'Invalid target_blocks element {ele}'
to_add = list(range(min(int(ele[6:]), max_block_idx + 1)))
elif ele.startswith('last-'):
assert ele[5:].isdigit(), f'Invalid target_blocks element {ele}'
to_add = list(
range(max(max_block_idx - int(ele[5:]) + 1, 0), max_block_idx + 1))
elif ele.startswith('middle-'):
assert ele[7:].isdigit(), f'Invalid target_blocks element {ele}'
num = int(ele[7:])
start = max(max_block_idx // 2 - num // 2, 0)
end = min(start + num, max_block_idx + 1)
to_add = list(range(start, end))
elif ele.startswith('range-'):
r = ele[6:].split('-')
assert len(r) == 2, f'Invalid target_blocks element {ele}'
start, end = int(r[0]), int(r[1])
start = max(start, 0)
end = min(end, max_block_idx + 1)
to_add = list(range(start, end))
else:
raise ValueError(f'Invalid target_blocks element {ele}')
return to_add


def get_target_block_list(target_blocks: Any, max_block_idx: int) -> list:
"""Parse the user input and return a list of block ids."""
candidate_block_ids = []
if isinstance(target_blocks, int):
candidate_block_ids = list(range(target_blocks))
elif isinstance(target_blocks, list):
for ele in target_blocks:
if isinstance(ele, int):
candidate_block_ids.append(ele)
elif isinstance(ele, str):
to_add = parse_ele_str(ele, max_block_idx)
candidate_block_ids.extend(to_add)
else:
raise ValueError(
f'target_blocks must be a list of integers or "first-n", "middle-m", "last-k", or "range-i-j" where n, m, k, i, j are integers, but got {target_blocks}'
)
elif isinstance(target_blocks, str):
target_blocks = target_blocks.replace(' ', '')
for ele in target_blocks.split(','):
to_add = parse_ele_str(ele, max_block_idx)
candidate_block_ids.extend(to_add)
else:
raise ValueError(
f'target_blocks must be either a single integer, or a list of integers, or a comma-separated string made of "first-n", "last-m", "middle-k", "range-i-j", or a list of mixed integers and the aforementioned strings, but got {type(target_blocks)}'
)

candidate_block_ids = list(set(candidate_block_ids))
return candidate_block_ids


def check_mapping_blocks_overlap(mapping: dict, max_block_idx: int) -> None:
"""Check if the block ids in the mapping overlap with each other."""
all_blocks = [None] * (max_block_idx + 1)
for k, v in mapping.items():
if v == -1:
v = list(range(max_block_idx + 1))
for vv in v:
if vv < 0 or vv > max_block_idx:
continue
else:
if all_blocks[vv] is not None:
raise ValueError(
f'Block {vv} is assigned to both {k} and {all_blocks[vv]}.'
)
else:
all_blocks[vv] = k
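To see what the block-selection grammar resolves to in practice, here is an illustrative sketch (assuming this PR's `llmfoundry.models.utils.act_ckpt` is importable and an 8-block model, i.e. `max_block_idx = 7`); the expected outputs follow from the parsing rules above. Note the helpers return de-duplicated but unordered lists, hence the `sorted` calls.

```python
# Illustrative use of the act_ckpt helpers from this PR; assumes an 8-block model.
from llmfoundry.models.utils.act_ckpt import (check_mapping_blocks_overlap,
                                              get_target_block_list)

max_block_idx = 7  # n_layers - 1

print(sorted(get_target_block_list(2, max_block_idx)))                   # [0, 1]
print(sorted(get_target_block_list('first-2, last-2', max_block_idx)))   # [0, 1, 6, 7]
print(sorted(get_target_block_list('middle-2', max_block_idx)))          # [2, 3]  (start = max(7//2 - 1, 0) = 2)
print(sorted(get_target_block_list('range-3-5', max_block_idx)))         # [3, 4]  (range is half-open)
print(sorted(get_target_block_list([0, 'last-1'], max_block_idx)))       # [0, 7]

# Overlap check: assigning a block to two targets raises. The model code uses
# module classes as keys; plain strings are used here only for illustration.
check_mapping_blocks_overlap({'mptblock': [0], 'attn': [1, 2]}, max_block_idx)   # ok
# check_mapping_blocks_overlap({'mptblock': [0], 'attn': [0]}, max_block_idx)    # ValueError
```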
29 changes: 21 additions & 8 deletions tests/models/test_fsdp_act_checkpoint.py
@@ -17,17 +17,20 @@
@pytest.mark.gpu
@pytest.mark.parametrize('activation_checkpointing', [True, False])
@pytest.mark.parametrize('activation_checkpointing_target', [
'grouped_query_attention', [], ['grouped_query_attention'],
['mptblock', 'grouped_query_attention']
'grouped_query_attention', [], ['grouped_query_attention'], {
'mptblock': [1],
'grouped_query_attention': 'first-1, last-1'
}
])
def test_fsdp_act_checkpoint(activation_checkpointing: bool,
activation_checkpointing_target: Union[list, str]):
activation_checkpointing_target: Union[list, str,
dict]):
device = get_device('gpu')
model_cfg = {
'name': 'mpt_causal_lm',
'd_model': 128,
'n_heads': 4,
'n_layers': 2,
'n_layers': 3,
'expansion_ratio': 1,
'max_seq_len': 16,
'vocab_size': 50368,
@@ -59,10 +62,7 @@ def test_fsdp_act_checkpoint(activation_checkpointing: bool,
assert not isinstance(
trainer.state.model.model._fsdp_wrapped_module.transformer.
blocks[0], CheckpointWrapper)
elif (not activation_checkpointing_target
) or activation_checkpointing_target == [
'mptblock', 'grouped_query_attention'
]:
elif (not activation_checkpointing_target):
module = trainer.state.model.model._fsdp_wrapped_module.transformer.blocks[
0]._fsdp_wrapped_module
assert isinstance(module, CheckpointWrapper)
@@ -72,6 +72,19 @@
assert isinstance(
trainer.state.model.model._fsdp_wrapped_module.transformer.
blocks[0]._fsdp_wrapped_module.attn, CheckpointWrapper)
elif activation_checkpointing_target == {
'mptblock': [1],
'grouped_query_attention': 'first-1, last-1'
}:
assert isinstance(
trainer.state.model.model._fsdp_wrapped_module.transformer.
blocks[0]._fsdp_wrapped_module.attn, CheckpointWrapper)
assert isinstance(
trainer.state.model.model._fsdp_wrapped_module.transformer.
blocks[1]._fsdp_wrapped_module, CheckpointWrapper)
assert isinstance(
trainer.state.model.model._fsdp_wrapped_module.transformer.
blocks[2]._fsdp_wrapped_module.attn, CheckpointWrapper)
else:
raise ValueError(
f'Unknown activation_checkpointing_target: {activation_checkpointing_target}'
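Beyond the parametrized assertions above, a quick way to see which submodules actually ended up activation checkpointed in an experiment is to walk the wrapped model and look for `CheckpointWrapper` instances. A rough sketch, assuming a Composer/FSDP-wrapped model like the one built in this test:

```python
# Rough inspection sketch, assuming an FSDP-wrapped model as in the test above.
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import \
    CheckpointWrapper


def report_checkpointed_modules(model) -> None:
    """Print the names of all modules wrapped for activation checkpointing."""
    for name, module in model.named_modules():
        if isinstance(module, CheckpointWrapper):
            print(f'activation checkpointed: {name}')


# e.g. report_checkpointed_modules(trainer.state.model)
```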