diff --git a/llmfoundry/models/utils/tp_strategy.py b/llmfoundry/models/utils/tp_strategy.py
index 7c8754994d..c75f59d356 100644
--- a/llmfoundry/models/utils/tp_strategy.py
+++ b/llmfoundry/models/utils/tp_strategy.py
@@ -21,27 +21,36 @@ def ffn_tp_strategy(model: ComposerModel) -> dict[str, ParallelStyle]:
     }
     if tp_layers_in_model != TP_LAYERS:
         raise RuntimeError(
-            f'The FFN tensor parallelism strategy requires `model` to have layers {TP_LAYERS}. But `model` is missing layers {TP_LAYERS - tp_layers_in_model}.'
+            f'The FFN tensor parallelism strategy requires `model` to have layers {TP_LAYERS}. But `model` is missing layers {TP_LAYERS - tp_layers_in_model}.',
         )

     # Generate layer plan
     layer_plan: dict[str, ParallelStyle] = {}
     for name, _ in model.named_modules():
-        if name.split('.')[-2:] == ['ffn', 'up_proj']:
+        # Before the ffn layer starts, distribute the input data for proper TP use
+        # Inputs are currently sharded across the batch dimension (dim 0) as is done in standard DDP
+        # Inputs will be replicated across hidden dimension (dim 1) via allgather
+        if name.split('.')[-1] == 'ffn':
+            layer_plan[name] = PrepareModuleInput(
+                input_layouts=Shard(0),
+                desired_input_layouts=Replicate(),
+                use_local_output=True,
+            )
+        # Shard the ffn.up_proj weight matrix across its columns
+        # Inputs are already replicated across each TP group
+        # Outputs will be sharded along the hidden dimension (dim 1) via allgather
+        elif name.split('.')[-2:] == ['ffn', 'up_proj']:
             layer_plan[name] = ColwiseParallel(
                 input_layouts=Replicate(),
                 output_layouts=Shard(-1),
             )
+        # Shard the ffn.down_proj weight matrix across its rows
+        # Inputs are sharded along the hidden dimension (dim 1)
+        # Outputs will be sharded along batch dimension (dim 0) via allreduce
         elif name.split('.')[-2:] == ['ffn', 'down_proj']:
             layer_plan[name] = RowwiseParallel(
                 input_layouts=Shard(-1),
                 output_layouts=Shard(0),
             )
-        elif name.split('.')[-1] == 'ffn':
-            layer_plan[name] = PrepareModuleInput(
-                input_layouts=Shard(0),
-                desired_input_layouts=Replicate(),
-                use_local_output=True,
-            )

     return layer_plan
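
For context only (not part of the patch above): a minimal sketch of how a layer plan shaped like the one `ffn_tp_strategy` returns could be applied with PyTorch's `parallelize_module`. The toy `FFN`/`Block` modules, dimensions, and launch command are assumptions for illustration; the `Shard`/`Replicate` import path can differ across PyTorch versions. Launch with e.g. `torchrun --nproc_per_node=2 tp_sketch.py`.

```python
# Hypothetical sketch: apply an ffn_tp_strategy-style plan to a toy model.
import torch
import torch.nn as nn
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    PrepareModuleInput,
    RowwiseParallel,
    parallelize_module,
)


class FFN(nn.Module):  # stand-in for the real model's FFN block
    def __init__(self, d_model: int = 64, d_hidden: int = 256):
        super().__init__()
        self.up_proj = nn.Linear(d_model, d_hidden, bias=False)
        self.down_proj = nn.Linear(d_hidden, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down_proj(torch.relu(self.up_proj(x)))


class Block(nn.Module):  # stand-in for one transformer block
    def __init__(self):
        super().__init__()
        self.ffn = FFN()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.ffn(x)


if __name__ == '__main__':
    # One-dimensional device mesh over all ranks in the TP group.
    mesh = init_device_mesh('cuda', (torch.cuda.device_count(),))
    model = Block().cuda()

    # Mirrors the plan built in the patch: gather the batch-sharded input
    # before the FFN, column-shard up_proj, row-shard down_proj, and hand
    # back an output sharded on the batch dimension again.
    layer_plan = {
        'ffn': PrepareModuleInput(
            input_layouts=Shard(0),
            desired_input_layouts=Replicate(),
            use_local_output=True,
        ),
        'ffn.up_proj': ColwiseParallel(
            input_layouts=Replicate(),
            output_layouts=Shard(-1),
        ),
        'ffn.down_proj': RowwiseParallel(
            input_layouts=Shard(-1),
            output_layouts=Shard(0),
        ),
    }
    parallelize_module(model, mesh, layer_plan)

    # Each rank feeds its local shard of the global batch.
    x = torch.randn(4, 64, device='cuda')
    y = model(x)
    print(y.shape)
```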