Commit

style
eitanturok committed Sep 25, 2024
1 parent c2d309a commit 6d65a29
Showing 7 changed files with 63 additions and 41 deletions.
8 changes: 5 additions & 3 deletions llmfoundry/command_utils/train.py
@@ -19,7 +19,8 @@
TraceHandler,
cyclic_schedule,
)
from composer.utils import dist, get_device, reproducibility, ParallelismConfig, TPConfig, FSDPConfig
from composer.utils import (FSDPConfig, ParallelismConfig, TPConfig, dist,
get_device, reproducibility,)
from omegaconf import DictConfig
from omegaconf import OmegaConf as om

@@ -335,7 +336,8 @@ def train(cfg: DictConfig) -> Trainer:
tp_config: Optional[dict[str, Any]] = train_cfg.tp_config

# Warn if FSDP or TP is enabled but user only has 1 GPU
if dist.get_world_size() == 1 and (fsdp_config is not None or tp_config is not None):
if dist.get_world_size(
) == 1 and (fsdp_config is not None or tp_config is not None):
parallelism = ''
if fsdp_config is not None:
parallelism += 'FSDP'
@@ -524,7 +526,7 @@ def train(cfg: DictConfig) -> Trainer:
tp_config['layer_plan'] |= strategy_layer_plan

# Parallelism config
parallelism_config = dict(fsdp=fsdp_config, tp=tp_config)
parallelism_config = {'fsdp': fsdp_config, 'tp': tp_config}

# Optimizer
optimizer_name: str = train_cfg.optimizer.pop('name')
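Note on how this dict is consumed: the {'fsdp': ..., 'tp': ...} mapping built in train() lines up with the ParallelismConfig, FSDPConfig, and TPConfig classes now imported at the top of this file. The snippet below is a minimal, hypothetical sketch of that correspondence; the sharding strategy, tensor-parallel degree, and empty layer plan are illustrative assumptions, not values taken from this commit.

from composer.utils import FSDPConfig, ParallelismConfig, TPConfig

# Assumed user-supplied options; a real run takes these from the training YAML.
fsdp_config = {'sharding_strategy': 'FULL_SHARD'}
tp_config = {'tensor_parallel_degree': 2, 'layer_plan': {}}  # layer_plan is merged in later from the tp strategy

# train() passes the plain dict; conceptually it corresponds to:
parallelism_config = ParallelismConfig(
    fsdp=FSDPConfig(**fsdp_config),
    tp=TPConfig(**tp_config),
)

Keeping the dict form in train() presumably defers construction of these dataclasses to Composer itself.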
1 change: 0 additions & 1 deletion llmfoundry/models/__init__.py
@@ -27,7 +27,6 @@
models.register('fmapi_chat', func=FMAPIChatAPIEvalWrapper)
tp_strategy.register('ffn', func=ffn_tp_strategy)


__all__ = [
'ComposerHFCausalLM',
'ComposerHFT5',
34 changes: 19 additions & 15 deletions llmfoundry/models/utils/tp_strategy.py
@@ -1,39 +1,43 @@
from typing import Optional

# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

from composer.models import ComposerModel
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, PrepareModuleInput
from torch.distributed.tensor.parallel.style import ParallelStyle
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.tensor.parallel import (ColwiseParallel,
PrepareModuleInput,
RowwiseParallel,)
from torch.distributed.tensor.parallel.style import ParallelStyle


def ffn_tp_strategy(model: ComposerModel) -> dict[str, ParallelStyle]:

TP_LAYERS = set(['up_proj', 'down_proj'])
TP_LAYERS = {'up_proj', 'down_proj'}

# validate that all TP_LAYERS are in model
tp_layers_in_model = set([layer for layer in TP_LAYERS for name, _ in model.named_modules() if layer in name])
tp_layers_in_model = set([
layer for layer in TP_LAYERS for name, _ in model.named_modules()
if layer in name
])
assert tp_layers_in_model == TP_LAYERS, f'The FFN tensor parallelism strategy requires `model` to have layers {TP_LAYERS}. But `model` is missing layers {TP_LAYERS - tp_layers_in_model}.'

# generate layer plan
layer_plan: dict[str, ParallelStyle] = {}
for name, _ in model.named_modules():
if name.split('.')[-2:] == ['ffn', 'up_proj']:
layer_plan[name] = ColwiseParallel(
input_layouts = Replicate(),
output_layouts = Shard(-1),
input_layouts=Replicate(),
output_layouts=Shard(-1),
)
elif name.split('.')[-2:] == ['ffn', 'down_proj']:
layer_plan[name] = RowwiseParallel(
input_layouts = Shard(-1),
output_layouts = Shard(0),
input_layouts=Shard(-1),
output_layouts=Shard(0),
)
elif name.split('.')[-1] == 'ffn':
layer_plan[name] = PrepareModuleInput(
input_layouts = Shard(0),
desired_input_layouts = Replicate(),
use_local_output = True,
input_layouts=Shard(0),
desired_input_layouts=Replicate(),
use_local_output=True,
)

return layer_plan


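For readers unfamiliar with layer plans: ffn_tp_strategy returns a mapping from module names to ParallelStyle objects, which is ultimately applied to the model over a tensor-parallel device mesh. The snippet below is a minimal, self-contained sketch of that mechanism using PyTorch's parallelize_module on a toy two-linear FFN; the toy module, the two-GPU mesh, and the default Colwise/Rowwise layouts are illustrative assumptions, not code from this commit (Composer applies the real plan internally).

import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (ColwiseParallel, RowwiseParallel,
                                               parallelize_module)


class ToyFFN(nn.Module):
    """Stand-in for an MPT block's ffn: up_proj followed by down_proj."""

    def __init__(self, d_model: int = 16, expansion: int = 4):
        super().__init__()
        self.up_proj = nn.Linear(d_model, expansion * d_model)
        self.down_proj = nn.Linear(expansion * d_model, d_model)

    def forward(self, x):
        return self.down_proj(self.up_proj(x).relu())


# Run under torchrun with at least 2 GPUs; the mesh shape here is an assumption.
tp_mesh = init_device_mesh('cuda', (2,))
layer_plan = {
    'up_proj': ColwiseParallel(),    # shard the first linear column-wise
    'down_proj': RowwiseParallel(),  # shard the second linear row-wise
}
model = parallelize_module(ToyFFN(), tp_mesh, layer_plan)

In the real strategy above, the layouts are chosen so the FFN input arrives sharded on dim 0, is replicated for up_proj, stays sharded on the hidden dimension between up_proj and down_proj, and leaves down_proj sharded on dim 0 again, matching the layouts asserted in the test further down.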
1 change: 0 additions & 1 deletion llmfoundry/utils/builders.py
@@ -14,7 +14,6 @@
Iterable,
Optional,
Union,
Callable
)

import torch
11 changes: 8 additions & 3 deletions llmfoundry/utils/config_utils.py
@@ -502,7 +502,10 @@ def update_batch_size_info(cfg: dict[str, Any]) -> dict[str, Any]:
return cfg


def process_init_device(model_cfg: dict[str, Any], fsdp_config: Optional[dict], tp_config: Optional[dict]):
def process_init_device(
model_cfg: dict[str, Any], fsdp_config: Optional[dict],
tp_config: Optional[dict]
):
# Restrict model init_device to 'meta' and 'cpu',
# using 'cuda' vs. 'cuda:id' is tricky and can lead to common user errors
# when multiple GPUs are available.
@@ -534,11 +537,13 @@ def process_init_device(model_cfg: dict[str, Any], fsdp_config: Optional[dict],
# Set defaults for mixed initialization
fsdp_config.setdefault('load_monolith_rank0_only', True)

if tp_config is not None and 'ffn_config' in model_cfg and model_cfg['ffn_config'].get('ffn_type', None) in ffns_with_megablocks:
if tp_config is not None and 'ffn_config' in model_cfg and model_cfg[
'ffn_config'].get('ffn_type', None) in ffns_with_megablocks:
raise ValueError('Cannot use TP with MoEs.')

# Set ffn_config.device_mesh using fsdp_config
if fsdp_config is not None and 'ffn_config' in model_cfg and model_cfg['ffn_config'].get('ffn_type', None) in ffns_with_megablocks:
if fsdp_config is not None and 'ffn_config' in model_cfg and model_cfg[
'ffn_config'].get('ffn_type', None) in ffns_with_megablocks:

shard_degree = fsdp_config.get('data_parallel_shard_degree', None)
replicate_degree = fsdp_config.get(
1 change: 0 additions & 1 deletion scripts/train/train.py
@@ -4,7 +4,6 @@

from llmfoundry.command_utils import train_from_yaml


if __name__ == '__main__':
yaml_path, args_list = sys.argv[1], sys.argv[2:]
train_from_yaml(yaml_path, args_list)
48 changes: 31 additions & 17 deletions tests/models/utils/test_tp_strategy.py
@@ -1,16 +1,20 @@
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, PrepareModuleInput
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

from torch.distributed._tensor import Replicate, Shard
from torch.distributed.tensor.parallel import (ColwiseParallel,
PrepareModuleInput,
RowwiseParallel,)

from llmfoundry.models.mpt.modeling_mpt import ComposerMPTCausalLM
from llmfoundry.utils.builders import build_tp_strategy


def test_ffn_tp_strategy_layer_plan():

# Actual layer plan
tp_config = {
'strategy': 'ffn',
}
}

model_cfg = {
'name': 'mpt_causal_lm',
@@ -29,31 +33,41 @@ def test_ffn_tp_strategy_layer_plan():

# Expected layer plan
_expected_layer_plan = {
'ffn': PrepareModuleInput(
input_layouts = Shard(0),
desired_input_layouts = Replicate(),
use_local_output = True,
'ffn':
PrepareModuleInput(
input_layouts=Shard(0),
desired_input_layouts=Replicate(),
use_local_output=True,
),
'ffn.down_proj': RowwiseParallel(
input_layouts = Shard(-1),
output_layouts = Shard(0),
'ffn.down_proj':
RowwiseParallel(
input_layouts=Shard(-1),
output_layouts=Shard(0),
),
'ffn.up_proj': ColwiseParallel(
input_layouts = Replicate(),
output_layouts = Shard(-1),
)
'ffn.up_proj':
ColwiseParallel(
input_layouts=Replicate(),
output_layouts=Shard(-1),
),
}
expected_layer_plan = {
f'model.transformer.blocks.{layer_idx}.{name}': layer_plan
for name, layer_plan in _expected_layer_plan.items()
for layer_idx in range(model_cfg['n_layers'])
}
expected_layer_plan = {f'model.transformer.blocks.{layer_idx}.{name}': layer_plan for name, layer_plan in _expected_layer_plan.items() for layer_idx in range(model_cfg['n_layers'])}

# Compare expected and actual layer plans
for (n1, lp1), (n2, lp2) in zip(sorted(expected_layer_plan.items()), sorted(layer_plan.items())):
for (n1, lp1), (n2, lp2) in zip(
sorted(expected_layer_plan.items()), sorted(layer_plan.items())
):
assert n1 == n2
assert type(lp1) == type(lp2)
if isinstance(lp1, PrepareModuleInput):
assert lp1.input_layouts == lp2.input_layouts
assert lp1.desired_input_layouts == lp2.desired_input_layouts
assert lp1.use_local_output == lp2.use_local_output
elif isinstance(lp1, ColwiseParallel) or isinstance(lp1, RowwiseParallel):
elif isinstance(lp1,
ColwiseParallel) or isinstance(lp1, RowwiseParallel):
assert lp1.input_layouts == lp2.input_layouts
assert lp1.output_layouts == lp2.output_layouts
assert lp1.use_local_output == lp2.use_local_output
