diff --git a/llmfoundry/command_utils/train.py b/llmfoundry/command_utils/train.py index 2eef4bdc30..ac7f0e7258 100644 --- a/llmfoundry/command_utils/train.py +++ b/llmfoundry/command_utils/train.py @@ -19,8 +19,11 @@ TraceHandler, cyclic_schedule, ) -from composer.utils import (FSDPConfig, ParallelismConfig, TPConfig, dist, - get_device, reproducibility,) +from composer.utils import ( + dist, + get_device, + reproducibility, +) from omegaconf import DictConfig from omegaconf import OmegaConf as om diff --git a/llmfoundry/models/utils/tp_strategy.py b/llmfoundry/models/utils/tp_strategy.py index 1d7a199efc..f7929686c6 100644 --- a/llmfoundry/models/utils/tp_strategy.py +++ b/llmfoundry/models/utils/tp_strategy.py @@ -1,12 +1,13 @@ - # Copyright 2024 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 from composer.models import ComposerModel from torch.distributed._tensor import Replicate, Shard -from torch.distributed.tensor.parallel import (ColwiseParallel, - PrepareModuleInput, - RowwiseParallel,) +from torch.distributed.tensor.parallel import ( + ColwiseParallel, + PrepareModuleInput, + RowwiseParallel, +) from torch.distributed.tensor.parallel.style import ParallelStyle @@ -14,10 +15,10 @@ def ffn_tp_strategy(model: ComposerModel) -> dict[str, ParallelStyle]: TP_LAYERS = {'up_proj', 'down_proj'} # validate that all TP_LAYERS are in model - tp_layers_in_model = set([ + tp_layers_in_model = { layer for layer in TP_LAYERS for name, _ in model.named_modules() if layer in name - ]) + } assert tp_layers_in_model == TP_LAYERS, f'The FFN tensor parallelism strategy requires `model` to have layers {TP_LAYERS}. But `model` is missing layers {TP_LAYERS - tp_layers_in_model}.' # generate layer plan diff --git a/llmfoundry/utils/config_utils.py b/llmfoundry/utils/config_utils.py index 3d2d152715..cc7415ba10 100644 --- a/llmfoundry/utils/config_utils.py +++ b/llmfoundry/utils/config_utils.py @@ -503,8 +503,9 @@ def update_batch_size_info(cfg: dict[str, Any]) -> dict[str, Any]: def process_init_device( - model_cfg: dict[str, Any], fsdp_config: Optional[dict], - tp_config: Optional[dict] + model_cfg: dict[str, Any], + fsdp_config: Optional[dict], + tp_config: Optional[dict], ): # Restrict model init_device to 'meta' and 'cpu', # using 'cuda' vs. 'cuda:id' is tricky and can lead to common user errors diff --git a/tests/models/utils/test_tp_strategy.py b/tests/models/utils/test_tp_strategy.py index 6893550e6a..cb015c53bc 100644 --- a/tests/models/utils/test_tp_strategy.py +++ b/tests/models/utils/test_tp_strategy.py @@ -2,9 +2,11 @@ # SPDX-License-Identifier: Apache-2.0 from torch.distributed._tensor import Replicate, Shard -from torch.distributed.tensor.parallel import (ColwiseParallel, - PrepareModuleInput, - RowwiseParallel,) +from torch.distributed.tensor.parallel import ( + ColwiseParallel, + PrepareModuleInput, + RowwiseParallel, +) from llmfoundry.models.mpt.modeling_mpt import ComposerMPTCausalLM from llmfoundry.utils.builders import build_tp_strategy @@ -58,7 +60,8 @@ def test_ffn_tp_strategy_layer_plan(): # Compare expected and actual layer plans for (n1, lp1), (n2, lp2) in zip( - sorted(expected_layer_plan.items()), sorted(layer_plan.items()) + sorted(expected_layer_plan.items()), + sorted(layer_plan.items()), ): assert n1 == n2 assert type(lp1) == type(lp2)