diff --git a/llmfoundry/__init__.py b/llmfoundry/__init__.py index b851aaa559..07e8f35747 100644 --- a/llmfoundry/__init__.py +++ b/llmfoundry/__init__.py @@ -48,6 +48,7 @@ models, optim, tokenizers, + tp, utils, ) from llmfoundry._version import __version__ @@ -87,5 +88,6 @@ 'models', 'optim', 'tokenizers', + 'tp', 'utils', ] diff --git a/llmfoundry/command_utils/train.py b/llmfoundry/command_utils/train.py index 14b7980d57..29878714f6 100644 --- a/llmfoundry/command_utils/train.py +++ b/llmfoundry/command_utils/train.py @@ -5,6 +5,7 @@ import os import time import warnings +from copy import deepcopy from typing import Any, Optional, Union import torch @@ -43,6 +44,7 @@ build_save_planner, build_scheduler, build_tokenizer, + build_tp_strategies, ) from llmfoundry.utils.config_utils import ( TRAIN_CONFIG_KEYS, @@ -329,16 +331,27 @@ def train(cfg: DictConfig) -> Trainer: changing autoresume default to True...', ) - # Warn if fsdp is enabled but user only has 1 GPU - if dist.get_world_size() == 1 and fsdp_config is not None: + # Optional tp config + tp_config: Optional[dict[str, Any]] = train_cfg.tp_config + + # Warn if FSDP or TP is enabled but user only has 1 GPU + if dist.get_world_size( + ) == 1 and (fsdp_config is not None or tp_config is not None): + parallelism = '' + if fsdp_config is not None: + parallelism += 'FSDP' + if tp_config is not None: + parallelism += '+TP' if fsdp_config is not None else 'TP' warnings.warn( - 'FSDP is not applicable for single-GPU training. Reverting to DDP.', + f'{parallelism} is not applicable for single-GPU training. Reverting to DDP.', ) fsdp_config = None + tp_config = None # Initialize context - init_context = process_init_device(model_config, fsdp_config) + init_context = process_init_device(model_config, fsdp_config, tp_config) logged_cfg.update({'fsdp_config': fsdp_config}, merge=True) + logged_cfg.update({'tp_config': deepcopy(tp_config)}, merge=True) # Build tokenizer log.info('Building tokenizer...') @@ -502,6 +515,15 @@ def train(cfg: DictConfig) -> Trainer: _log_num_params(model, logged_cfg) + # TP config + if tp_config is not None: + strategy = tp_config.pop('strategy', None) + assert isinstance(strategy, str), '`strategy` must be in `tp_config`.' + tp_config['layer_plan'] = build_tp_strategies(strategy, model) + + # Parallelism config + parallelism_config = {'fsdp': fsdp_config, 'tp': tp_config} + # Optimizer optimizer_name: str = train_cfg.optimizer.pop('name') optimizer_cfg = train_cfg.optimizer @@ -546,7 +568,7 @@ def train(cfg: DictConfig) -> Trainer: precision=train_cfg.precision, algorithms=algorithms, device_train_microbatch_size=train_cfg.device_train_microbatch_size, - parallelism_config={'fsdp': fsdp_config}, + parallelism_config=parallelism_config, save_folder=train_cfg.save_folder, save_filename=save_filename, save_latest_filename=save_latest_filename, diff --git a/llmfoundry/registry.py b/llmfoundry/registry.py index cb2455a760..850c4f3bbd 100644 --- a/llmfoundry/registry.py +++ b/llmfoundry/registry.py @@ -7,6 +7,7 @@ from composer.models import ComposerModel from composer.optim import ComposerScheduler from torch.distributed.checkpoint import LoadPlanner, SavePlanner +from torch.distributed.tensor.parallel.style import ParallelStyle from torch.optim import Optimizer from torch.utils.data import DataLoader as TorchDataloader from torch.utils.data import Dataset @@ -389,6 +390,26 @@ description=_save_planners_description, ) +_tp_strategies_description = ( + """The tp_strategies registry is used to register strategies for tensor parallelism. + + Args: + model (ComposerModel): The model. + + Returns: + layer_plan (Dict[str, ParallelStyle]): The plan used to parallelize the model. + model (ComposerModel): The model. + """ +) + +tp_strategies = create_registry( + 'llmfoundry', + 'tp_strategies', + generic_type=Callable[[ComposerModel], dict[str, ParallelStyle]], + entry_points=True, + description=_tp_strategies_description, +) + __all__ = [ 'loggers', 'callbacks', @@ -416,4 +437,5 @@ 'config_transforms', 'load_planners', 'save_planners', + 'tp_strategies', ] diff --git a/llmfoundry/tp/__init__.py b/llmfoundry/tp/__init__.py new file mode 100644 index 0000000000..323ae23727 --- /dev/null +++ b/llmfoundry/tp/__init__.py @@ -0,0 +1,11 @@ +# Copyright 2024 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +from llmfoundry.registry import tp_strategies +from llmfoundry.tp.ffn_tp_strategy import ffn_tp_strategy + +tp_strategies.register('ffn', func=ffn_tp_strategy) + +__all__ = [ + 'ffn_tp_strategy', +] diff --git a/llmfoundry/tp/ffn_tp_strategy.py b/llmfoundry/tp/ffn_tp_strategy.py new file mode 100644 index 0000000000..1de92ef6ae --- /dev/null +++ b/llmfoundry/tp/ffn_tp_strategy.py @@ -0,0 +1,56 @@ +# Copyright 2024 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +from composer.models import ComposerModel +from torch.distributed._tensor import Replicate, Shard +from torch.distributed.tensor.parallel import ( + ColwiseParallel, + PrepareModuleInput, + RowwiseParallel, +) +from torch.distributed.tensor.parallel.style import ParallelStyle + + +def ffn_tp_strategy(model: ComposerModel) -> dict[str, ParallelStyle]: + TP_LAYERS = {'ffn', 'ffn.up_proj', 'ffn.down_proj'} + + # Validate that all TP_LAYERS are in model + tp_layers_in_model = { + layer for layer in TP_LAYERS for name, _ in model.named_modules() + if layer in name + } + if tp_layers_in_model != TP_LAYERS: + raise RuntimeError( + f'The FFN tensor parallelism strategy requires `model` to have layers {TP_LAYERS}. But `model` is missing layers {TP_LAYERS - tp_layers_in_model}.', + ) + + # Generate layer plan + layer_plan: dict[str, ParallelStyle] = {} + for name, _ in model.named_modules(): + # Before the ffn layer starts, distribute the input data for proper TP use + # Inputs are currently sharded across the batch dimension (dim 0) as is done in standard DDP + # Inputs will be replicated across hidden dimension (dim 1) via allgather + if name.split('.')[-1] == 'ffn': + layer_plan[name] = PrepareModuleInput( + input_layouts=Shard(0), + desired_input_layouts=Replicate(), + use_local_output=True, + ) + # Shard the ffn.up_proj weight matrix across its columns + # Inputs are already replicated across each TP group + # Outputs will be sharded along the hidden dimension (dim 1) via allgather + elif name.split('.')[-2:] == ['ffn', 'up_proj']: + layer_plan[name] = ColwiseParallel( + input_layouts=Replicate(), + output_layouts=Shard(-1), + ) + # Shard the ffn.down_proj weight matrix across its rows + # Inputs are sharded along the hidden dimension (dim 1) + # Outputs will be sharded along batch dimension (dim 0) via allreduce + elif name.split('.')[-2:] == ['ffn', 'down_proj']: + layer_plan[name] = RowwiseParallel( + input_layouts=Shard(-1), + output_layouts=Shard(0), + ) + + return layer_plan diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py index f2d5cfc0f7..687b21b46d 100644 --- a/llmfoundry/utils/builders.py +++ b/llmfoundry/utils/builders.py @@ -7,14 +7,9 @@ import logging import os import re +import warnings from collections import OrderedDict -from typing import ( - Any, - ContextManager, - Iterable, - Optional, - Union, -) +from typing import Any, ContextManager, Iterable, Optional, Union import torch from composer.core import Algorithm, Callback, Evaluator @@ -25,6 +20,7 @@ from omegaconf import DictConfig from omegaconf import OmegaConf as om from torch.distributed.checkpoint import LoadPlanner, SavePlanner +from torch.distributed.tensor.parallel.style import ParallelStyle from torch.optim.optimizer import Optimizer from torchmetrics import Metric from transformers import AutoTokenizer, PreTrainedTokenizerBase @@ -37,6 +33,7 @@ ) from llmfoundry.utils.config_utils import to_dict_container, to_list_container from llmfoundry.utils.registry_utils import construct_from_registry +from llmfoundry.utils.warnings import experimental_function log = logging.getLogger(__name__) @@ -52,6 +49,7 @@ 'build_tokenizer', 'build_composer_model', 'build_metric', + 'build_tp_strategies', ] @@ -701,3 +699,20 @@ def _validate_cfg(icl_cfg: dict[str, Any]): ) return evaluators, logger_keys + + +@experimental_function('Tensor Parallelism') +def build_tp_strategies( + name: str, + model: ComposerModel, +) -> dict[str, ParallelStyle]: + + warnings.warn( + 'Checkpointing is not currently supported for tensor parallelism due to this pytorch bug: https://github.com/pytorch/pytorch/issues/134095#issuecomment-2345018244', + ) + return construct_from_registry( + name=name, + registry=registry.tp_strategies, + partial_function=False, + kwargs={'model': model}, + ) diff --git a/llmfoundry/utils/config_utils.py b/llmfoundry/utils/config_utils.py index ba5c5941b8..c22495993c 100644 --- a/llmfoundry/utils/config_utils.py +++ b/llmfoundry/utils/config_utils.py @@ -120,6 +120,7 @@ class TrainConfig: # Distributed training parameters dist_timeout: Union[int, float] = 600.0 fsdp_config: Optional[dict[str, Any]] = None + tp_config: Optional[dict[str, Any]] = None # Evaluation parameters eval_interval: Union[int, str] = 1 @@ -501,7 +502,11 @@ def update_batch_size_info(cfg: dict[str, Any]) -> dict[str, Any]: return cfg -def process_init_device(model_cfg: dict[str, Any], fsdp_config: Optional[dict]): +def process_init_device( + model_cfg: dict[str, Any], + fsdp_config: Optional[dict] = None, + tp_config: Optional[dict] = None, +): # Restrict model init_device to 'meta' and 'cpu', # using 'cuda' vs. 'cuda:id' is tricky and can lead to common user errors # when multiple GPUs are available. @@ -533,6 +538,13 @@ def process_init_device(model_cfg: dict[str, Any], fsdp_config: Optional[dict]): # Set defaults for mixed initialization fsdp_config.setdefault('load_monolith_rank0_only', True) + # Check we are not using tensor parallelism with MoEs + if tp_config is not None and 'ffn_config' in model_cfg and model_cfg[ + 'ffn_config'].get('ffn_type', None) in ffns_with_megablocks: + raise ValueError( + 'Tensor Parallelism is not currently supported for MoE models.', + ) + # Set ffn_config.device_mesh using fsdp_config if fsdp_config is not None and 'ffn_config' in model_cfg and model_cfg[ 'ffn_config'].get('ffn_type', None) in ffns_with_megablocks: diff --git a/tests/test_registry.py b/tests/test_registry.py index 5108a7d46c..90ef3bfaac 100644 --- a/tests/test_registry.py +++ b/tests/test_registry.py @@ -47,6 +47,7 @@ def test_expected_registries_exist(): 'config_transforms', 'load_planners', 'save_planners', + 'tp_strategies', } assert existing_registries == expected_registry_names diff --git a/tests/tp/__init__.py b/tests/tp/__init__.py new file mode 100644 index 0000000000..80950cb7b4 --- /dev/null +++ b/tests/tp/__init__.py @@ -0,0 +1,2 @@ +# Copyright 2024 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 diff --git a/tests/tp/test_tp_strategies.py b/tests/tp/test_tp_strategies.py new file mode 100644 index 0000000000..fd2fa384ce --- /dev/null +++ b/tests/tp/test_tp_strategies.py @@ -0,0 +1,133 @@ +# Copyright 2024 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +from pathlib import Path +from tempfile import TemporaryDirectory + +import pytest +from omegaconf import OmegaConf as om +from torch.distributed._tensor import Replicate, Shard +from torch.distributed.tensor.parallel import ( + ColwiseParallel, + PrepareModuleInput, + RowwiseParallel, +) + +from llmfoundry.command_utils.train import train +from llmfoundry.models.mpt.modeling_mpt import ComposerMPTCausalLM +from llmfoundry.utils.builders import build_tp_strategies +from llmfoundry.utils.config_utils import process_init_device +from tests.data_utils import create_c4_dataset_xxsmall, gpt_tiny_cfg + + +@pytest.mark.gpu +@pytest.mark.filterwarnings( + 'ignore:tp_strategies is experimental and may change with future versions.', +) +def test_ffn_tp_strategy(): + """Test the FFN tensor parallelism strategy is correct.""" + # Create layer plan from fnn tp_strategy + tp_config = { + 'strategy': 'ffn', + } + + model_cfg = { + 'name': 'mpt_causal_lm', + 'd_model': 128, + 'n_heads': 4, + 'n_layers': 3, + 'expansion_ratio': 1, + 'max_seq_len': 16, + 'vocab_size': 50368, + } + model = ComposerMPTCausalLM(**model_cfg) + layer_plan = build_tp_strategies(tp_config['strategy'], model) + + # Expected layer plan + _expected_layer_plan = { + 'ffn': + PrepareModuleInput( + input_layouts=Shard(0), + desired_input_layouts=Replicate(), + use_local_output=True, + ), + 'ffn.down_proj': + RowwiseParallel( + input_layouts=Shard(-1), + output_layouts=Shard(0), + ), + 'ffn.up_proj': + ColwiseParallel( + input_layouts=Replicate(), + output_layouts=Shard(-1), + ), + } + expected_layer_plan = { + f'model.transformer.blocks.{layer_idx}.{name}': layer_plan + for name, layer_plan in _expected_layer_plan.items() + for layer_idx in range(model_cfg['n_layers']) + } + + # Compare expected and actual layer plans + for (n1, lp1), (n2, lp2) in zip( + sorted(expected_layer_plan.items()), + sorted(layer_plan.items()), + ): + assert n1 == n2 + assert type(lp1) == type(lp2) + if isinstance( + lp1, + PrepareModuleInput, + ) and isinstance(lp2, PrepareModuleInput): + assert lp1.input_layouts == lp2.input_layouts + assert lp1.desired_input_layouts == lp2.desired_input_layouts + assert lp1.use_local_output == lp2.use_local_output + elif ( + isinstance(lp1, ColwiseParallel) and + isinstance(lp2, ColwiseParallel) + ) or ( + isinstance(lp1, RowwiseParallel) and + isinstance(lp2, RowwiseParallel) + ): + assert lp1.input_layouts == lp2.input_layouts + assert lp1.output_layouts == lp2.output_layouts + assert lp1.use_local_output == lp2.use_local_output + else: + raise ValueError(f'Layer plan of wrong type: {type(layer_plan)}') + + +@pytest.mark.gpu +def test_no_tp_with_one_gpu(): + """Test that when we have one GPU, we use DDP and not FSDP-TP.""" + with TemporaryDirectory() as tmp_path: + # Make `train_cfg`` with a tensor parallelism strategy + dataset_name = create_c4_dataset_xxsmall(Path(tmp_path)) + train_cfg = gpt_tiny_cfg(dataset_name, 'gpu') + train_cfg.tp_config = {'strategy': 'ffn'} + + # Expect a warning + with pytest.warns( + UserWarning, + match= + r'FSDP\+TP is not applicable for single-GPU training. Reverting to DDP.', + ): + train(train_cfg) + + +@pytest.mark.gpu # use gpu because `megablocks` only installed with `gpu` dependencies +def test_no_tp_with_moes(): + """Test that tensor parallelism is not compatible with MoEs.""" + # Make `cfg` for MoE model, fsdp, and tp + train_cfg_path: str = 'scripts/train/yamls/pretrain/testing-moe.yaml' + with open(train_cfg_path, 'r', encoding='utf-8') as f: + train_cfg = om.load(f) + model_cfg = train_cfg.model + fsdp_cfg = train_cfg.fsdp_config + tp_cfg = {'strategy': 'ffn'} + + # Expect an error + with pytest.raises( + ValueError, + match='Tensor Parallelism is not currently supported for MoE models.', + ): + process_init_device(model_cfg, fsdp_cfg, tp_cfg)