Tensor Parallelism (#1521)
Co-authored-by: Eitan Turok <[email protected]>
Co-authored-by: Mihir Patel <[email protected]>
3 people authored Sep 27, 2024
1 parent 3b1fc4a commit ee45600
Showing 10 changed files with 289 additions and 13 deletions.
2 changes: 2 additions & 0 deletions llmfoundry/__init__.py
@@ -48,6 +48,7 @@
models,
optim,
tokenizers,
tp,
utils,
)
from llmfoundry._version import __version__
@@ -87,5 +88,6 @@
'models',
'optim',
'tokenizers',
'tp',
'utils',
]
32 changes: 27 additions & 5 deletions llmfoundry/command_utils/train.py
@@ -5,6 +5,7 @@
import os
import time
import warnings
from copy import deepcopy
from typing import Any, Optional, Union

import torch
@@ -43,6 +44,7 @@
build_save_planner,
build_scheduler,
build_tokenizer,
build_tp_strategies,
)
from llmfoundry.utils.config_utils import (
TRAIN_CONFIG_KEYS,
@@ -329,16 +331,27 @@ def train(cfg: DictConfig) -> Trainer:
changing autoresume default to True...',
)

# Warn if fsdp is enabled but user only has 1 GPU
if dist.get_world_size() == 1 and fsdp_config is not None:
# Optional tp config
tp_config: Optional[dict[str, Any]] = train_cfg.tp_config

# Warn if FSDP or TP is enabled but user only has 1 GPU
if dist.get_world_size(
) == 1 and (fsdp_config is not None or tp_config is not None):
parallelism = ''
if fsdp_config is not None:
parallelism += 'FSDP'
if tp_config is not None:
parallelism += '+TP' if fsdp_config is not None else 'TP'
warnings.warn(
'FSDP is not applicable for single-GPU training. Reverting to DDP.',
f'{parallelism} is not applicable for single-GPU training. Reverting to DDP.',
)
fsdp_config = None
tp_config = None

# Initialize context
init_context = process_init_device(model_config, fsdp_config)
init_context = process_init_device(model_config, fsdp_config, tp_config)
logged_cfg.update({'fsdp_config': fsdp_config}, merge=True)
logged_cfg.update({'tp_config': deepcopy(tp_config)}, merge=True)

# Build tokenizer
log.info('Building tokenizer...')
@@ -502,6 +515,15 @@ def train(cfg: DictConfig) -> Trainer:

_log_num_params(model, logged_cfg)

# TP config
if tp_config is not None:
strategy = tp_config.pop('strategy', None)
assert isinstance(strategy, str), '`strategy` must be in `tp_config`.'
tp_config['layer_plan'] = build_tp_strategies(strategy, model)

# Parallelism config
parallelism_config = {'fsdp': fsdp_config, 'tp': tp_config}

# Optimizer
optimizer_name: str = train_cfg.optimizer.pop('name')
optimizer_cfg = train_cfg.optimizer
@@ -546,7 +568,7 @@ def train(cfg: DictConfig) -> Trainer:
precision=train_cfg.precision,
algorithms=algorithms,
device_train_microbatch_size=train_cfg.device_train_microbatch_size,
parallelism_config={'fsdp': fsdp_config},
parallelism_config=parallelism_config,
save_folder=train_cfg.save_folder,
save_filename=save_filename,
save_latest_filename=save_latest_filename,
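For context, a minimal sketch of the tp_config a run could supply through the training YAML, assuming the 'ffn' strategy registered in llmfoundry/tp/__init__.py; only 'strategy' is consumed by the code above, and tensor_parallel_degree is an assumed Composer TP option shown for illustration.

# Minimal tp_config sketch. 'strategy' is looked up in the tp_strategies
# registry; any remaining keys (tensor_parallel_degree here is an assumed
# Composer option) are forwarded to Composer via parallelism_config['tp']
# alongside the generated 'layer_plan'.
tp_config = {
    'strategy': 'ffn',
    'tensor_parallel_degree': 2,
}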
22 changes: 22 additions & 0 deletions llmfoundry/registry.py
@@ -7,6 +7,7 @@
from composer.models import ComposerModel
from composer.optim import ComposerScheduler
from torch.distributed.checkpoint import LoadPlanner, SavePlanner
from torch.distributed.tensor.parallel.style import ParallelStyle
from torch.optim import Optimizer
from torch.utils.data import DataLoader as TorchDataloader
from torch.utils.data import Dataset
@@ -389,6 +390,26 @@
description=_save_planners_description,
)

_tp_strategies_description = (
    """The tp_strategies registry is used to register strategies for tensor parallelism.
    Args:
        model (ComposerModel): The model.
    Returns:
        layer_plan (Dict[str, ParallelStyle]): The plan used to parallelize the model.
    """
)

tp_strategies = create_registry(
    'llmfoundry',
    'tp_strategies',
    generic_type=Callable[[ComposerModel], dict[str, ParallelStyle]],
    entry_points=True,
    description=_tp_strategies_description,
)

__all__ = [
'loggers',
'callbacks',
@@ -416,4 +437,5 @@
'config_transforms',
'load_planners',
'save_planners',
'tp_strategies',
]
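As a hypothetical illustration of the new registry (the strategy name and function below are invented for the example), any callable mapping a ComposerModel to a dict of ParallelStyle objects can be registered and then referenced by name from a tp_config:

from composer.models import ComposerModel
from torch.distributed.tensor.parallel import ColwiseParallel
from torch.distributed.tensor.parallel.style import ParallelStyle

from llmfoundry.registry import tp_strategies


def up_proj_only_strategy(model: ComposerModel) -> dict[str, ParallelStyle]:
    # Hypothetical strategy: shard every module named 'up_proj' column-wise
    # and leave all other modules untouched.
    return {
        name: ColwiseParallel()
        for name, _ in model.named_modules()
        if name.split('.')[-1] == 'up_proj'
    }


tp_strategies.register('up_proj_only', func=up_proj_only_strategy)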
11 changes: 11 additions & 0 deletions llmfoundry/tp/__init__.py
@@ -0,0 +1,11 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

from llmfoundry.registry import tp_strategies
from llmfoundry.tp.ffn_tp_strategy import ffn_tp_strategy

tp_strategies.register('ffn', func=ffn_tp_strategy)

__all__ = [
    'ffn_tp_strategy',
]
56 changes: 56 additions & 0 deletions llmfoundry/tp/ffn_tp_strategy.py
@@ -0,0 +1,56 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

from composer.models import ComposerModel
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    PrepareModuleInput,
    RowwiseParallel,
)
from torch.distributed.tensor.parallel.style import ParallelStyle


def ffn_tp_strategy(model: ComposerModel) -> dict[str, ParallelStyle]:
    TP_LAYERS = {'ffn', 'ffn.up_proj', 'ffn.down_proj'}

    # Validate that all TP_LAYERS are in model
    tp_layers_in_model = {
        layer for layer in TP_LAYERS for name, _ in model.named_modules()
        if layer in name
    }
    if tp_layers_in_model != TP_LAYERS:
        raise RuntimeError(
            f'The FFN tensor parallelism strategy requires `model` to have layers {TP_LAYERS}. But `model` is missing layers {TP_LAYERS - tp_layers_in_model}.',
        )

    # Generate layer plan
    layer_plan: dict[str, ParallelStyle] = {}
    for name, _ in model.named_modules():
        # Before the ffn layer starts, distribute the input data for proper TP use
        # Inputs are currently sharded across the batch dimension (dim 0) as is done in standard DDP
        # Inputs will be replicated across hidden dimension (dim 1) via allgather
        if name.split('.')[-1] == 'ffn':
            layer_plan[name] = PrepareModuleInput(
                input_layouts=Shard(0),
                desired_input_layouts=Replicate(),
                use_local_output=True,
            )
        # Shard the ffn.up_proj weight matrix across its columns
        # Inputs are already replicated across each TP group
        # Outputs will be sharded along the hidden dimension (dim 1) via allgather
        elif name.split('.')[-2:] == ['ffn', 'up_proj']:
            layer_plan[name] = ColwiseParallel(
                input_layouts=Replicate(),
                output_layouts=Shard(-1),
            )
        # Shard the ffn.down_proj weight matrix across its rows
        # Inputs are sharded along the hidden dimension (dim 1)
        # Outputs will be sharded along batch dimension (dim 0) via allreduce
        elif name.split('.')[-2:] == ['ffn', 'down_proj']:
            layer_plan[name] = RowwiseParallel(
                input_layouts=Shard(-1),
                output_layouts=Shard(0),
            )

    return layer_plan
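As a worked illustration of what this strategy returns (the module path below is hypothetical, not taken from a specific checkpoint), each transformer block contributes three entries to the plan:

# For a hypothetical module named 'model.transformer.blocks.0.ffn', the plan
# built above would contain one triple per block:
#
#   layer_plan['model.transformer.blocks.0.ffn']            -> PrepareModuleInput(...)  # Shard(0) input replicated
#   layer_plan['model.transformer.blocks.0.ffn.up_proj']    -> ColwiseParallel(...)     # output sharded on the last dim
#   layer_plan['model.transformer.blocks.0.ffn.down_proj']  -> RowwiseParallel(...)     # output sharded on dim 0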
29 changes: 22 additions & 7 deletions llmfoundry/utils/builders.py
@@ -7,14 +7,9 @@
import logging
import os
import re
import warnings
from collections import OrderedDict
from typing import (
Any,
ContextManager,
Iterable,
Optional,
Union,
)
from typing import Any, ContextManager, Iterable, Optional, Union

import torch
from composer.core import Algorithm, Callback, Evaluator
@@ -25,6 +20,7 @@
from omegaconf import DictConfig
from omegaconf import OmegaConf as om
from torch.distributed.checkpoint import LoadPlanner, SavePlanner
from torch.distributed.tensor.parallel.style import ParallelStyle
from torch.optim.optimizer import Optimizer
from torchmetrics import Metric
from transformers import AutoTokenizer, PreTrainedTokenizerBase
@@ -37,6 +33,7 @@
)
from llmfoundry.utils.config_utils import to_dict_container, to_list_container
from llmfoundry.utils.registry_utils import construct_from_registry
from llmfoundry.utils.warnings import experimental_function

log = logging.getLogger(__name__)

@@ -52,6 +49,7 @@
'build_tokenizer',
'build_composer_model',
'build_metric',
'build_tp_strategies',
]


@@ -701,3 +699,20 @@ def _validate_cfg(icl_cfg: dict[str, Any]):
)

return evaluators, logger_keys


@experimental_function('Tensor Parallelism')
def build_tp_strategies(
    name: str,
    model: ComposerModel,
) -> dict[str, ParallelStyle]:

    warnings.warn(
        'Checkpointing is not currently supported for tensor parallelism due to this pytorch bug: https://github.com/pytorch/pytorch/issues/134095#issuecomment-2345018244',
    )
    return construct_from_registry(
        name=name,
        registry=registry.tp_strategies,
        partial_function=False,
        kwargs={'model': model},
    )
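A hedged usage sketch of the new builder, mirroring what train() does with tp_config; `model` is assumed to be an already-built ComposerModel, `fsdp_config` an optional FSDP dict, and the degree value is illustrative:

from llmfoundry.utils.builders import build_tp_strategies

# Resolve the registered 'ffn' strategy into a layer plan for `model`,
# then hand it to Composer together with any FSDP settings.
layer_plan = build_tp_strategies('ffn', model)
parallelism_config = {
    'fsdp': fsdp_config,
    'tp': {'layer_plan': layer_plan, 'tensor_parallel_degree': 2},
}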
14 changes: 13 additions & 1 deletion llmfoundry/utils/config_utils.py
@@ -120,6 +120,7 @@ class TrainConfig:
# Distributed training parameters
dist_timeout: Union[int, float] = 600.0
fsdp_config: Optional[dict[str, Any]] = None
tp_config: Optional[dict[str, Any]] = None

# Evaluation parameters
eval_interval: Union[int, str] = 1
@@ -501,7 +502,11 @@ def update_batch_size_info(cfg: dict[str, Any]) -> dict[str, Any]:
return cfg


def process_init_device(model_cfg: dict[str, Any], fsdp_config: Optional[dict]):
def process_init_device(
model_cfg: dict[str, Any],
fsdp_config: Optional[dict] = None,
tp_config: Optional[dict] = None,
):
# Restrict model init_device to 'meta' and 'cpu',
# using 'cuda' vs. 'cuda:id' is tricky and can lead to common user errors
# when multiple GPUs are available.
@@ -533,6 +538,13 @@ def process_init_device(model_cfg: dict[str, Any], fsdp_config: Optional[dict]):
# Set defaults for mixed initialization
fsdp_config.setdefault('load_monolith_rank0_only', True)

# Check we are not using tensor parallelism with MoEs
if tp_config is not None and 'ffn_config' in model_cfg and model_cfg[
'ffn_config'].get('ffn_type', None) in ffns_with_megablocks:
raise ValueError(
'Tensor Parallelism is not currently supported for MoE models.',
)

# Set ffn_config.device_mesh using fsdp_config
if fsdp_config is not None and 'ffn_config' in model_cfg and model_cfg[
'ffn_config'].get('ffn_type', None) in ffns_with_megablocks:
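A sketch of the new guard, assuming 'mb_moe' is one of the ffn types in ffns_with_megablocks and that the earlier init_device checks pass:

# Combining a megablocks MoE ffn_type with tensor parallelism is rejected early.
model_cfg = {
    'name': 'mpt_causal_lm',  # illustrative model config
    'init_device': 'cpu',
    'ffn_config': {'ffn_type': 'mb_moe'},  # assumed megablocks ffn_type
}
process_init_device(model_cfg, fsdp_config=None, tp_config={'strategy': 'ffn'})
# -> ValueError: Tensor Parallelism is not currently supported for MoE models.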
1 change: 1 addition & 0 deletions tests/test_registry.py
@@ -47,6 +47,7 @@ def test_expected_registries_exist():
'config_transforms',
'load_planners',
'save_planners',
'tp_strategies',
}

assert existing_registries == expected_registry_names
2 changes: 2 additions & 0 deletions tests/tp/__init__.py
@@ -0,0 +1,2 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0