diff --git a/llmfoundry/models/utils/tp_strategy.py b/llmfoundry/models/utils/tp_strategy.py index 748ccf0c55..2c48160750 100644 --- a/llmfoundry/models/utils/tp_strategy.py +++ b/llmfoundry/models/utils/tp_strategy.py @@ -14,7 +14,7 @@ def ffn_tp_strategy(model: ComposerModel) -> dict[str, ParallelStyle]: TP_LAYERS = {'up_proj', 'down_proj'} - # validate that all TP_LAYERS are in model + # Validate that all TP_LAYERS are in model tp_layers_in_model = { layer for layer in TP_LAYERS for name, _ in model.named_modules() if layer in name