diff --git a/llmfoundry/models/utils/tp_strategy.py b/llmfoundry/models/utils/tp_strategy.py index 2e2b253c87..68fd73762b 100644 --- a/llmfoundry/models/utils/tp_strategy.py +++ b/llmfoundry/models/utils/tp_strategy.py @@ -26,7 +26,6 @@ def ffn_tp_strategy(model: ComposerModel) -> dict[str, ParallelStyle]: elif name.split('.')[-2:] == ['ffn', 'down_proj']: layer_plan[name] = RowwiseParallel( input_layouts = Shard(-1), - # output_layouts = Replicate(), output_layouts = Shard(0), ) elif name.split('.')[-1] == 'ffn':