
Commit

more tests
eitanturok committed Sep 27, 2024
1 parent ee45600 commit 4cdaf08
Showing 3 changed files with 41 additions and 5 deletions.
4 changes: 2 additions & 2 deletions llmfoundry/command_utils/train.py
@@ -5,7 +5,6 @@
 import os
 import time
 import warnings
-from copy import deepcopy
 from typing import Any, Optional, Union
 
 import torch
@@ -351,7 +350,7 @@ def train(cfg: DictConfig) -> Trainer:
     # Initialize context
     init_context = process_init_device(model_config, fsdp_config, tp_config)
     logged_cfg.update({'fsdp_config': fsdp_config}, merge=True)
-    logged_cfg.update({'tp_config': deepcopy(tp_config)}, merge=True)
+    logged_cfg.update({'tp_config': tp_config}, merge=True)
 
     # Build tokenizer
     log.info('Building tokenizer...')
@@ -517,6 +516,7 @@ def train(cfg: DictConfig) -> Trainer:
 
     # TP config
     if tp_config is not None:
+
         strategy = tp_config.pop('strategy', None)
         assert isinstance(strategy, str), '`strategy` must be in `tp_config`.'
         tp_config['layer_plan'] = build_tp_strategies(strategy, model)
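For orientation, the `layer_plan` returned by `build_tp_strategies(strategy, model)` is presumably a mapping from module paths to the `torch.distributed.tensor.parallel` styles that the tests below import. A minimal sketch of such a plan, assuming hypothetical module names and sharding choices rather than llm-foundry's actual 'ffn' strategy:

# A hedged sketch of the kind of layer plan build_tp_strategies('ffn', model)
# might return; module paths and sharding choices are illustrative assumptions.
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel


def example_ffn_layer_plan(num_layers: int) -> dict:
    """Shard each FFN's up projection column-wise and its down projection
    row-wise so the hidden activations stay sharded between the two matmuls."""
    plan = {}
    for i in range(num_layers):
        prefix = f'model.transformer.blocks.{i}.ffn'  # hypothetical module path
        plan[f'{prefix}.up_proj'] = ColwiseParallel(
            input_layouts=Replicate(),   # replicated block input
            output_layouts=Shard(-1),    # shard the hidden dimension
        )
        plan[f'{prefix}.down_proj'] = RowwiseParallel(
            input_layouts=Shard(-1),     # consume the sharded hidden dimension
            output_layouts=Replicate(),  # gather back to a replicated output
        )
    return plan

Pairing column-wise and row-wise sharding like this is the usual way to keep the FFN's intermediate activations sharded between the two projections.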
15 changes: 15 additions & 0 deletions llmfoundry/utils/config_utils.py
@@ -538,6 +538,21 @@ def process_init_device(
             # Set defaults for mixed initialization
             fsdp_config.setdefault('load_monolith_rank0_only', True)
 
+    if tp_config is not None:
+        # Check tp_config has required fields
+        if 'strategy' not in tp_config or 'tensor_parallel_degree' not in tp_config:
+            raise ValueError(
+                "`tp_config` requires 'strategy' and 'tensor_parallel_degree' values. "
+            )
+
+        # Check we are not using tensor parallelism with MoEs
+        if 'ffn_config' in model_cfg and model_cfg['ffn_config'].get(
+            'ffn_type', None
+        ) in ffns_with_megablocks:
+            raise ValueError(
+                'Tensor Parallelism is not currently supported for MoE models.',
+            )
+
     # Check we are not using tensor parallelism with MoEs
     if tp_config is not None and 'ffn_config' in model_cfg and model_cfg[
         'ffn_config'].get('ffn_type', None) in ffns_with_megablocks:
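The checks added above only require that `tp_config` carries both keys and that the model is not a megablocks MoE. A minimal sketch of configs that would pass or trip them, assuming the validation is reached through `process_init_device(model_cfg, fsdp_config, tp_config)`; the MoE `ffn_type` name is an assumed example:

# Hypothetical configs illustrating the new validation in process_init_device.
valid_tp_config = {
    'strategy': 'ffn',            # required key
    'tensor_parallel_degree': 2,  # required key
}

# Missing 'tensor_parallel_degree' -> ValueError:
# "`tp_config` requires 'strategy' and 'tensor_parallel_degree' values."
invalid_tp_config = {'strategy': 'ffn'}

# A model whose ffn_type is in ffns_with_megablocks ('mb_dmoe' is an assumed
# example name) raises, with any tp_config:
# 'Tensor Parallelism is not currently supported for MoE models.'
moe_model_cfg = {'ffn_config': {'ffn_type': 'mb_dmoe'}}

The same `{'strategy': ..., 'tensor_parallel_degree': 2}` shape is what the new `test_tp_train` below assigns to `train_cfg.tp_config`.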
27 changes: 24 additions & 3 deletions tests/tp/test_tp_strategies.py
@@ -5,6 +5,7 @@
 from tempfile import TemporaryDirectory
 
 import pytest
+from icecream import install
 from omegaconf import OmegaConf as om
 from torch.distributed._tensor import Replicate, Shard
 from torch.distributed.tensor.parallel import (
@@ -19,6 +20,8 @@
 from llmfoundry.utils.config_utils import process_init_device
 from tests.data_utils import create_c4_dataset_xxsmall, gpt_tiny_cfg
 
+install()
+
 
 @pytest.mark.gpu
 @pytest.mark.filterwarnings(
@@ -97,8 +100,26 @@ def test_ffn_tp_strategy():
 
 
 @pytest.mark.gpu
-def test_no_tp_with_one_gpu():
-    """Test that when we have one GPU, we use DDP and not FSDP-TP."""
+@pytest.mark.world_size(4)
+@pytest.mark.parametrize('tp_strategy', ['ffn'])
+def test_tp_train(tp_strategy: str):
+    """Test that we can train with FSDP-TP."""
+    with TemporaryDirectory() as tmp_path:
+        # Make `train_cfg`` with a tensor parallelism strategy
+        dataset_name = create_c4_dataset_xxsmall(Path(tmp_path))
+        train_cfg = gpt_tiny_cfg(dataset_name, 'gpu')
+        train_cfg.tp_config = {
+            'strategy': tp_strategy,
+            'tensor_parallel_degree': 2,
+        }
+
+        # Train
+        train(train_cfg)
+
+
+@pytest.mark.gpu
+def test_tp_train_with_one_gpu():
+    """Test that when we have one GPU, we train DDP and not FSDP-TP."""
     with TemporaryDirectory() as tmp_path:
         # Make `train_cfg`` with a tensor parallelism strategy
         dataset_name = create_c4_dataset_xxsmall(Path(tmp_path))
@@ -115,7 +136,7 @@ def test_no_tp_with_one_gpu():
 
 
 @pytest.mark.gpu  # use gpu because `megablocks` only installed with `gpu` dependencies
-def test_no_tp_with_moes():
+def test_tp_train_with_moes():
     """Test that tensor parallelism is not compatible with MoEs."""
     # Make `cfg` for MoE model, fsdp, and tp
     train_cfg_path: str = 'scripts/train/yamls/pretrain/testing-moe.yaml'
