Commit

explain with comments
eitanturok committed Sep 26, 2024
1 parent 8c0135d commit 0156dd2
Showing 1 changed file with 17 additions and 8 deletions.
llmfoundry/models/utils/tp_strategy.py (25 changes: 17 additions & 8 deletions)
@@ -21,27 +21,36 @@ def ffn_tp_strategy(model: ComposerModel) -> dict[str, ParallelStyle]:
     }
     if tp_layers_in_model != TP_LAYERS:
         raise RuntimeError(
-            f'The FFN tensor parallelism strategy requires `model` to have layers {TP_LAYERS}. But `model` is missing layers {TP_LAYERS - tp_layers_in_model}.'
+            f'The FFN tensor parallelism strategy requires `model` to have layers {TP_LAYERS}. But `model` is missing layers {TP_LAYERS - tp_layers_in_model}.',
         )
 
     # Generate layer plan
     layer_plan: dict[str, ParallelStyle] = {}
     for name, _ in model.named_modules():
-        if name.split('.')[-2:] == ['ffn', 'up_proj']:
+        # Before the FFN block runs, redistribute its input for proper TP use.
+        # Inputs arrive sharded across the batch dimension (dim 0), as in standard DDP.
+        # They will be replicated across the TP group via all-gather.
+        if name.split('.')[-1] == 'ffn':
+            layer_plan[name] = PrepareModuleInput(
+                input_layouts=Shard(0),
+                desired_input_layouts=Replicate(),
+                use_local_output=True,
+            )
+        # Shard the ffn.up_proj weight matrix across its columns.
+        # Inputs are already replicated across each TP group.
+        # Outputs come out sharded along the hidden dimension (dim -1); no collective is needed.
+        elif name.split('.')[-2:] == ['ffn', 'up_proj']:
             layer_plan[name] = ColwiseParallel(
                 input_layouts=Replicate(),
                 output_layouts=Shard(-1),
             )
+        # Shard the ffn.down_proj weight matrix across its rows.
+        # Inputs are sharded along the hidden dimension (dim -1).
+        # Partial outputs are summed and sharded along the batch dimension (dim 0) via reduce-scatter.
         elif name.split('.')[-2:] == ['ffn', 'down_proj']:
             layer_plan[name] = RowwiseParallel(
                 input_layouts=Shard(-1),
                 output_layouts=Shard(0),
             )
-        elif name.split('.')[-1] == 'ffn':
-            layer_plan[name] = PrepareModuleInput(
-                input_layouts=Shard(0),
-                desired_input_layouts=Replicate(),
-                use_local_output=True,
-            )
 
     return layer_plan
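
Why this sharding pattern works: splitting up_proj column-wise and down_proj row-wise lets each rank compute an independent partial product, and summing the partials (the job of the final reduce collective) recovers the unsharded output; the elementwise activation between the two matmuls applies per-shard, so it commutes with the column split. A minimal single-process sketch in plain PyTorch (the shapes and the two-way split are illustrative assumptions, not values from this repo):

    import torch

    torch.manual_seed(0)
    x = torch.randn(4, 8)    # (batch, d_model)
    w1 = torch.randn(8, 32)  # up_proj weight: (d_model, d_ff)
    w2 = torch.randn(32, 8)  # down_proj weight: (d_ff, d_model)

    # Unsharded reference: up-project, activate, down-project.
    full = torch.relu(x @ w1) @ w2

    # "Two TP ranks": shard w1 by columns (dim 1) and w2 by rows (dim 0).
    w1_a, w1_b = w1.chunk(2, dim=1)
    w2_a, w2_b = w2.chunk(2, dim=0)

    # Each rank computes a partial output from its own shards; the
    # elementwise ReLU acts independently on each hidden-dim shard.
    partial_a = torch.relu(x @ w1_a) @ w2_a
    partial_b = torch.relu(x @ w1_b) @ w2_b

    # Summing the partials (what the reduce collective does) matches the reference.
    assert torch.allclose(full, partial_a + partial_b, atol=1e-5)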
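For context, a layer plan like the one this function returns could be consumed by PyTorch's tensor-parallel API. A hedged sketch; the mesh setup, process-group initialization, and `model` here are illustrative assumptions, not shown in this commit:

    import torch.distributed as dist
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor.parallel import parallelize_module

    # Assumes the process group is already initialized (e.g. via torchrun)
    # and `model` is a ComposerModel whose FFN blocks expose
    # `ffn.up_proj` / `ffn.down_proj` submodules.
    tp_mesh = init_device_mesh('cuda', (dist.get_world_size(),))
    layer_plan = ffn_tp_strategy(model)  # {fully.qualified.name: ParallelStyle}
    parallelize_module(model, tp_mesh, layer_plan)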
