Skip to content

Commit

Permalink
test passes
Browse files Browse the repository at this point in the history
  • Loading branch information
eitanturok committed Sep 30, 2024
1 parent 3844c78 commit 8be472f
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 2 deletions.
1 change: 1 addition & 0 deletions tests/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,7 @@ def create_c4_dataset_xxsmall(path: Path) -> str:
shutil.copytree(
os.path.join(c4_dir, 'val_xxsmall'),
os.path.join(c4_dir, mocked_split),
dirs_exist_ok=True,
)
assert os.path.exists(c4_dir)
return c4_dir
Expand Down
19 changes: 17 additions & 2 deletions tests/tp/test_tp_strategies.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

import os
from pathlib import Path
from shutil import rmtree
from tempfile import TemporaryDirectory

import pytest
from composer.utils import dist
from icecream import install
from omegaconf import OmegaConf as om
from torch.distributed._tensor import Replicate, Shard
Expand Down Expand Up @@ -104,17 +107,25 @@ def test_ffn_tp_strategy():
@pytest.mark.parametrize('tp_strategy', ['ffn'])
def test_tp_train(tp_strategy: str):
"""Test that we can train with FSDP-TP."""
with TemporaryDirectory() as tmp_path:
data_dir = '/my-data-dir/'
try:
# Make `train_cfg` with a tensor parallelism strategy
dataset_name = create_c4_dataset_xxsmall(Path(tmp_path))
dataset_name = create_c4_dataset_xxsmall(Path(data_dir))
train_cfg = gpt_tiny_cfg(dataset_name, 'gpu')
train_cfg.variables.run_name = 'tp-test'
train_cfg.tp_config = {
'strategy': tp_strategy,
'tensor_parallel_degree': 2,
}

# Train
train(train_cfg)
except Exception as e:
raise e
finally:
# Always remove the data directory, even if training raised
if os.path.isdir(data_dir):
rmtree(data_dir)


@pytest.mark.gpu
Expand Down Expand Up @@ -152,3 +163,7 @@ def test_tp_train_with_moes():
match='Tensor Parallelism is not currently supported for MoE models.',
):
process_init_device(model_cfg, fsdp_cfg, tp_cfg)


if __name__ == '__main__':
test_tp_train('ffn')

0 comments on commit 8be472f

Please sign in to comment.