diff --git a/llmfoundry/command_utils/train.py b/llmfoundry/command_utils/train.py
index 6749d3abdd..cc0ff353c5 100644
--- a/llmfoundry/command_utils/train.py
+++ b/llmfoundry/command_utils/train.py
@@ -18,7 +18,7 @@
     TraceHandler,
     cyclic_schedule,
 )
-from composer.utils import dist, get_device, reproducibility
+from composer.utils import TPConfig, dist, get_device, reproducibility
 from omegaconf import DictConfig
 from omegaconf import OmegaConf as om
@@ -518,10 +518,7 @@ def train(cfg: DictConfig) -> Trainer:
     if tp_config is not None:
         strategy = tp_config.pop('strategy', None)
         layer_plan = build_tp_strategies(strategy, model)
-        tp_config = {
-            'layer_plan': layer_plan,
-            'tensor_parallel_degree': tp_config['tensor_parallel_degree'],
-        }
+        tp_config = TPConfig(**tp_config, layer_plan=layer_plan)
 
     # Parallelism config
     parallelism_config = {'fsdp': fsdp_config, 'tp': tp_config}
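
For context, a minimal standalone sketch of the construction path this diff switches to. It assumes a tp_config dict shaped like the one llm-foundry parses from YAML; the degree value and the None layer plan below are illustrative placeholders, not taken from this PR (build_tp_strategies needs a real model, so it is stubbed out here):

    from composer.utils import TPConfig

    # Hypothetical YAML-derived dict, after 'strategy' has already been popped.
    tp_config = {'tensor_parallel_degree': 2}

    # Placeholder for build_tp_strategies(strategy, model), which requires a model.
    layer_plan = None

    # Remaining keys are forwarded to the TPConfig dataclass, so an unrecognized
    # key now raises a TypeError instead of being silently dropped, as it was by
    # the old hand-built dict.
    tp_config = TPConfig(**tp_config, layer_plan=layer_plan)
    assert tp_config.tensor_parallel_degree == 2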