Commit

fix test
cli99 committed Feb 7, 2024
1 parent 95b560c commit dbd72b8
Showing 3 changed files with 29 additions and 8 deletions.
18 changes: 17 additions & 1 deletion llmfoundry/models/mpt/modeling_mpt.py
@@ -351,6 +351,22 @@ def __init__(self, config: MPTConfig):
                 **config.to_dict(),
             ) for _ in range(config.n_layers)
         ])
+
+        def pass_on_block_idx(parent: nn.Module):
+            if not hasattr(parent, 'block_idx') or not hasattr(
+                    parent, 'max_block_idx'):
+                return
+            for child in parent.children():
+                child.block_idx = parent.block_idx
+                child.max_block_idx = parent.max_block_idx
+                if child.children():
+                    pass_on_block_idx(child)
+
+        for i, block in enumerate(self.blocks):
+            block.block_idx = i
+            block.max_block_idx = config.n_layers - 1
+            pass_on_block_idx(block)
+
         self.norm_f = norm_class(config.d_model, device=config.init_device)
 
         self.rope = config.attn_config['rope']
@@ -949,7 +965,7 @@ def activation_checkpointing_fn(self, module: nn.Module) -> bool:
         act_ckpt_target = getattr(self.config,
                                   'activation_checkpointing_target', None)
         act_ckpt_mod_to_blocks = {}
-        if act_ckpt_target is None:
+        if act_ckpt_target is None or act_ckpt_target == []:
             mod = MPTBlock
             act_ckpt_mod_to_blocks[mod] = -1
         elif isinstance(act_ckpt_target, str):
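
The tagging added above can be illustrated with a minimal, self-contained sketch (ToyBlock and its attn/ffn fields below are hypothetical stand-ins, not llmfoundry classes): each top-level block records its own index and the largest index, and pass_on_block_idx copies both attributes onto every descendant module, which is what lets activation_checkpointing_fn later resolve per-block targets.

import torch.nn as nn

def pass_on_block_idx(parent: nn.Module):
    # Same logic as the hunk above: skip modules that were never tagged.
    if not hasattr(parent, 'block_idx') or not hasattr(parent, 'max_block_idx'):
        return
    for child in parent.children():
        child.block_idx = parent.block_idx
        child.max_block_idx = parent.max_block_idx
        if child.children():
            pass_on_block_idx(child)

class ToyBlock(nn.Module):
    # Hypothetical stand-in for MPTBlock with two submodules.
    def __init__(self) -> None:
        super().__init__()
        self.attn = nn.Linear(4, 4)
        self.ffn = nn.Linear(4, 4)

blocks = nn.ModuleList([ToyBlock() for _ in range(3)])
for i, block in enumerate(blocks):
    block.block_idx = i
    block.max_block_idx = len(blocks) - 1
    pass_on_block_idx(block)

# Every submodule now knows which block it belongs to.
assert blocks[2].attn.block_idx == 2
assert blocks[0].ffn.max_block_idx == 2
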
2 changes: 2 additions & 0 deletions llmfoundry/models/utils/act_ckpt.py
@@ -95,6 +95,8 @@ def check_mapping_blocks_overlap(mapping: dict, max_block_idx: int) -> None:
"""Check if the block ids in the mapping overlap with each other."""
all_blocks = [None] * (max_block_idx + 1)
for k, v in mapping.items():
if v == -1:
v = list(range(max_block_idx + 1))
for vv in v:
if vv < 0 or vv > max_block_idx:
continue
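
To see why the -1 sentinel needs expanding, here is a hedged sketch of the overlap check with the new branch in place. The loop body shown in the diff is real, but the string keys and the ValueError at the end are illustrative assumptions, not the exact llmfoundry implementation (the real mapping keys are module classes such as MPTBlock).

def check_mapping_blocks_overlap(mapping: dict, max_block_idx: int) -> None:
    """Check if the block ids in the mapping overlap with each other."""
    all_blocks = [None] * (max_block_idx + 1)
    for k, v in mapping.items():
        if v == -1:
            # -1 means "every block", e.g. when activation_checkpointing_target
            # is None or []; expand it so the membership loop below still works.
            v = list(range(max_block_idx + 1))
        for vv in v:
            if vv < 0 or vv > max_block_idx:
                continue
            if all_blocks[vv] is not None:
                # Assumed error handling; the real function may word this differently.
                raise ValueError(
                    f'Block {vv} is targeted by both {all_blocks[vv]} and {k}.')
            all_blocks[vv] = k

# Before this change, a full-model mapping crashed on `for vv in -1`.
check_mapping_blocks_overlap({'mptblock': -1}, max_block_idx=2)
check_mapping_blocks_overlap({'mptblock': [1], 'attn': [0, 2]}, max_block_idx=2)
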
17 changes: 10 additions & 7 deletions tests/models/test_fsdp_act_checkpoint.py
@@ -18,8 +18,8 @@
 @pytest.mark.parametrize('activation_checkpointing', [True, False])
 @pytest.mark.parametrize('activation_checkpointing_target', [
     'grouped_query_attention', [], ['grouped_query_attention'], {
-        'mptblock': 0,
-        'grouped_query_attention': 1
+        'mptblock': [1],
+        'grouped_query_attention': 'first-1, last-1'
     }
 ])
 def test_fsdp_act_checkpoint(activation_checkpointing: bool,
@@ -30,7 +30,7 @@ def test_fsdp_act_checkpoint(activation_checkpointing: bool,
         'name': 'mpt_causal_lm',
         'd_model': 128,
         'n_heads': 4,
-        'n_layers': 2,
+        'n_layers': 3,
         'expansion_ratio': 1,
         'max_seq_len': 16,
         'vocab_size': 50368,
@@ -73,15 +73,18 @@ def test_fsdp_act_checkpoint(activation_checkpointing: bool,
             trainer.state.model.model._fsdp_wrapped_module.transformer.
             blocks[0]._fsdp_wrapped_module.attn, CheckpointWrapper)
     elif activation_checkpointing_target == {
-            'mptblock': 0,
-            'grouped_query_attention': 1
+            'mptblock': [1],
+            'grouped_query_attention': 'first-1, last-1'
     }:
         assert isinstance(
             trainer.state.model.model._fsdp_wrapped_module.transformer.
-            blocks[0]._fsdp_wrapped_module, CheckpointWrapper)
+            blocks[0]._fsdp_wrapped_module.attn, CheckpointWrapper)
         assert isinstance(
             trainer.state.model.model._fsdp_wrapped_module.transformer.
-            blocks[1]._fsdp_wrapped_module.attn, CheckpointWrapper)
+            blocks[1]._fsdp_wrapped_module, CheckpointWrapper)
+        assert isinstance(
+            trainer.state.model.model._fsdp_wrapped_module.transformer.
+            blocks[2]._fsdp_wrapped_module.attn, CheckpointWrapper)
     else:
         raise ValueError(
             f'Unknown activation_checkpointing_target: {activation_checkpointing_target}'
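
The updated parametrization exercises both value shapes the target mapping now accepts: a list of explicit block indices and a 'first-N, last-N' string. A hedged example of how such a target might be spelled in a model config follows; the keys mirror the test values above, but the exact placement inside a full training config may differ.

model_cfg = {
    'name': 'mpt_causal_lm',
    'd_model': 128,
    'n_heads': 4,
    'n_layers': 3,
    'expansion_ratio': 1,
    'max_seq_len': 16,
    'vocab_size': 50368,
    # Checkpoint all of block 1, plus only the attention module of the
    # first and last blocks (blocks 0 and 2 in a 3-layer model).
    'activation_checkpointing_target': {
        'mptblock': [1],
        'grouped_query_attention': 'first-1, last-1',
    },
}
# Minimal FSDP settings; the test's actual fsdp_config may set more fields.
fsdp_config = {'activation_checkpointing': True}
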
