tp-strat does not crash
eitanturok committed Sep 25, 2024
1 parent 76adc48 commit 90264ea
Showing 4 changed files with 25 additions and 9 deletions.
8 changes: 6 additions & 2 deletions llmfoundry/command_utils/train.py
@@ -19,7 +19,8 @@
TraceHandler,
cyclic_schedule,
)
-from composer.utils import dist, get_device, reproducibility, ParallelismConfig, TPConfig
+from composer.utils import dist, get_device, reproducibility, ParallelismConfig, TPConfig, FSDPConfig
+from icecream import install
from omegaconf import DictConfig
from omegaconf import OmegaConf as om

@@ -64,6 +65,7 @@

log = logging.getLogger(__name__)

+install()

def validate_config(train_config: TrainConfig):
"""Validates compatible model and dataloader selection."""
@@ -524,7 +526,9 @@ def train(cfg: DictConfig) -> Trainer:
tp_config['layer_plan'] |= strategy_layer_plan

# Parallelism config
-parallelism_config: ParallelismConfig = {'fsdp': fsdp_config, 'tp': tp_config}
+tp = TPConfig(**tp_config)
+fsdp = FSDPConfig(**fsdp_config)
+parallelism_config = ParallelismConfig(fsdp=fsdp, tp=tp)

# Optimizer
optimizer_name: str = train_cfg.optimizer.pop('name')
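
For context on the hunk above: rather than handing Composer a raw dict, the commit now builds typed TPConfig and FSDPConfig objects and wraps them in a ParallelismConfig. A minimal sketch of that pattern, assuming the dict keys mirror the YAML fsdp_config/tp_config blocks and that both config classes accept them as keyword arguments (values are illustrative only):

from composer.utils import FSDPConfig, ParallelismConfig, TPConfig

# Stand-ins for the dicts that train.py assembles from the YAML and from the
# registered TP strategy's layer plan.
fsdp_config = {'sharding_strategy': 'FULL_SHARD'}
tp_config = {'tensor_parallel_degree': 2, 'layer_plan': {}}

parallelism_config = ParallelismConfig(
    fsdp=FSDPConfig(**fsdp_config),
    tp=TPConfig(**tp_config),
)
# Trainer(..., parallelism_config=parallelism_config) is then expected to
# consume this object in place of the old dict.
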
1 change: 1 addition & 0 deletions llmfoundry/models/__init__.py
@@ -27,6 +27,7 @@
models.register('fmapi_chat', func=FMAPIChatAPIEvalWrapper)
tp_strategy.register('ffn', func=ffn_tp_strategy)


__all__ = [
'ComposerHFCausalLM',
'ComposerHFT5',
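
The 'ffn' name registered above is what tp_config.strategy in the YAML below refers to. A hedged sketch of how that registration could be resolved back into a layer-plan function, assuming the registry lives in llmfoundry.registry and exposes a catalogue-style get() (both assumptions, not shown in this commit):

from llmfoundry.registry import tp_strategy  # registry location assumed
from llmfoundry.models.utils.tp_strategy import ffn_tp_strategy

strategy_fn = tp_strategy.get('ffn')  # catalogue-style lookup, assumed
assert strategy_fn is ffn_tp_strategy
# strategy_layer_plan = strategy_fn(model)  # model: a ComposerModel
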
19 changes: 15 additions & 4 deletions llmfoundry/models/utils/tp_strategy.py
@@ -1,7 +1,9 @@
from typing import Optional

from composer.models import ComposerModel
-from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel
+from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, PrepareModuleInput
from torch.distributed.tensor.parallel.style import ParallelStyle
-from torch.distributed._tensor import Replicate, Shard
+from torch.distributed._tensor import Replicate, Shard, Placement


def ffn_tp_strategy(model: ComposerModel) -> dict[str, ParallelStyle]:
@@ -15,15 +17,24 @@ def ffn_tp_strategy(model: ComposerModel) -> dict[str, ParallelStyle]:
# generate layer plan
layer_plan: dict[str, ParallelStyle] = {}
for name, _ in model.named_modules():
-if 'up_proj' in name:
+ic(name)
+if name.split('.')[-2:] == ['ffn', 'up_proj']:
layer_plan[name] = ColwiseParallel(
input_layouts = Replicate(),
output_layouts = Shard(-1),
)
-elif 'down_proj' in name:
+elif name.split('.')[-2:] == ['ffn', 'down_proj']:
layer_plan[name] = RowwiseParallel(
input_layouts = Shard(-1),
output_layouts = Replicate(),
)
+elif name.split('.')[-1] == 'ffn':
+layer_plan[name] = PrepareModuleInput(
+input_layouts = Shard(0),
+desired_input_layouts = Replicate(),
+use_local_output = True,
+)

return layer_plan
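
The layer plan returned by ffn_tp_strategy is keyed by module name and is ultimately consumed by PyTorch's tensor-parallel API. A self-contained sketch of how such a plan is applied, with a toy FFN standing in for the MPT block (the real wiring happens inside Composer, not in this file; run under torchrun with 2 GPUs):

import torch
import torch.nn as nn
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    PrepareModuleInput,
    RowwiseParallel,
    parallelize_module,
)

class FFN(nn.Module):
    """Toy stand-in for an MPT ffn block (up_proj -> activation -> down_proj)."""

    def __init__(self, d_model: int = 128, expansion: int = 4):
        super().__init__()
        self.up_proj = nn.Linear(d_model, expansion * d_model)
        self.down_proj = nn.Linear(expansion * d_model, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down_proj(torch.relu(self.up_proj(x)))

tp_mesh = init_device_mesh('cuda', (2,))  # tensor_parallel_degree = 2
block = nn.ModuleDict({'ffn': FFN()}).cuda()

# Same shape of plan as ffn_tp_strategy builds, but with toy module names.
layer_plan = {
    'ffn': PrepareModuleInput(
        input_layouts=Shard(0),
        desired_input_layouts=Replicate(),
        use_local_output=True,
    ),
    'ffn.up_proj': ColwiseParallel(input_layouts=Replicate(), output_layouts=Shard(-1)),
    'ffn.down_proj': RowwiseParallel(input_layouts=Shard(-1), output_layouts=Replicate()),
}
parallelize_module(block, tp_mesh, layer_plan)
# up_proj is now column-sharded and down_proj row-sharded across the 2 ranks.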


6 changes: 3 additions & 3 deletions scripts/train/yamls/pretrain/mpt-125m.yaml
@@ -75,10 +75,10 @@ algorithms:
clipping_type: norm
clipping_threshold: 1.0

-max_duration: 4800ba # ~ 2.5B tokens
+max_duration: 100ba
eval_interval: 500ba
eval_first: false
-eval_subset_num_batches: -1
+eval_subset_num_batches: 0
global_train_batch_size: 256

# System
@@ -99,7 +99,7 @@ fsdp_config:

# TP
tp_config:
-strategy: megatron
+strategy: ffn
tensor_parallel_degree: 2

# Logging
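
With tensor_parallel_degree: 2 next to the existing fsdp_config, the devices are typically arranged as a 2-D mesh: TP pairs along one dimension, FSDP sharding along the other. How Composer carves this up internally is not shown in the commit; the sketch below only illustrates the shape implied by the two config blocks, with a world size of 8 as an assumption:

from torch.distributed.device_mesh import init_device_mesh

world_size = 8  # assumed for illustration; not set anywhere in this YAML
tp_degree = 2   # tp_config.tensor_parallel_degree

mesh = init_device_mesh(
    'cuda',
    (world_size // tp_degree, tp_degree),
    mesh_dim_names=('data_parallel', 'tensor_parallel'),
)
fsdp_mesh = mesh['data_parallel']    # parameters sharded across 4 groups
tp_mesh = mesh['tensor_parallel']    # the 'ffn' layer plan spans each pair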
