feat: add tests for cli usage of TP and plugin
Signed-off-by: Mehant Kammakomati <[email protected]>
kmehant committed Dec 13, 2024
1 parent 780ae7b commit e08c364
Showing 3 changed files with 87 additions and 0 deletions.
6 changes: 6 additions & 0 deletions src/accelerate/test_utils/testing.py
@@ -338,6 +338,12 @@ def require_fsdp(test_case):
"""
return unittest.skipUnless(is_torch_version(">=", "1.12.0"), "test requires torch version >= 1.12.0")(test_case)

def require_tp(test_case):
"""
Decorator marking a test that requires FSDP installed. These tests are skipped when FSDP isn't installed
"""
return unittest.skipUnless(is_torch_version(">=", "2.3.0"), "test requires torch version >= 2.3.0")(test_case)


def require_torch_min_version(test_case=None, version=None):
"""
Expand Down
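As a usage note, the new helper is meant to be stacked like the other require_* decorators (the test class below applies it at class level); a minimal, hypothetical sketch, not part of this diff:

# Hypothetical usage of the new decorator:
@require_tp
def test_needs_tensor_parallel():
    ...  # runs only when torch >= 2.3.0; otherwise the test is skipped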
4 changes: 4 additions & 0 deletions src/accelerate/utils/dataclasses.py
@@ -1827,6 +1827,10 @@ class TorchTensorParallelPlugin:
    torch_device_mesh: torch.distributed.DeviceMesh = field(default=None)

    def __post_init__(self):
        self.tp_size = int(os.environ.get("TP_SIZE", self.tp_size))
        if self.tp_size == 1:
            raise ValueError("Provide TP degree > 1.")

        if is_torch_version("<", BETA_TP_AVAILABLE_PYTORCH_VERSION):
            raise ValueError(
                f"Minimum PyTorch version {BETA_TP_AVAILABLE_PYTORCH_VERSION} needed to use tensor parallel."
77 changes: 77 additions & 0 deletions tests/tp/test_tp.py
@@ -0,0 +1,77 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from transformers.testing_utils import mockenv_context
from transformers.trainer_utils import set_seed


from accelerate.test_utils.testing import (
    AccelerateTestCase,
    TempDirTestCase,
    execute_subprocess_async,
    get_launch_command,
    path_in_accelerate_package,
    require_multi_device,
    require_non_cpu,
    require_non_torch_xla,
    require_tp,
    slow,
)
from accelerate.utils import patch_environment
from accelerate.utils.dataclasses import TorchTensorParallelPlugin


set_seed(42)


@require_tp
@require_non_cpu
@require_non_torch_xla
class TPPluginIntegration(AccelerateTestCase):
    def setUp(self):
        super().setUp()

        self.dist_env = dict(
            MASTER_ADDR="localhost",
            MASTER_PORT="10999",
            RANK="0",
            LOCAL_RANK="0",
            WORLD_SIZE="1",
        )

        self.tp_env = dict(ACCELERATE_USE_TP="true", TP_SIZE="2", **self.dist_env)

    def test_device_mesh_init(self):
        with mockenv_context(**self.tp_env):
            tp_plugin = TorchTensorParallelPlugin()
            assert tp_plugin.torch_device_mesh["tp"].size() == int(self.tp_env["TP_SIZE"])

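For context (not part of this diff), the assertion above relies on the plugin exposing a one-dimensional "tp" device mesh whose size equals the requested degree. A minimal sketch of how such a mesh is typically built with torch's DeviceMesh API, assuming a CUDA backend and an initialized distributed environment:

# Hypothetical sketch of the mesh construction the plugin is expected to perform:
from torch.distributed.device_mesh import init_device_mesh

tp_size = 2  # the degree requested via TP_SIZE in the test above
# Build a 1-D mesh whose only dimension is named "tp"; its size equals the TP degree.
device_mesh = init_device_mesh("cuda", (tp_size,), mesh_dim_names=("tp",))
assert device_mesh["tp"].size() == tp_size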

@require_non_torch_xla
@require_tp
@require_multi_device
@slow
class TPIntegrationTest(TempDirTestCase):
    test_scripts_folder = path_in_accelerate_package("test_utils", "scripts", "external_deps")

    def setUp(self):
        super().setUp()
        self.test_tp_size = 2

    def test_working_of_tp(self):
        self.test_file_path = self.test_scripts_folder / "test_performance.py"
        cmd = get_launch_command(
            num_processes=self.test_tp_size, num_machines=1, machine_rank=0, use_tp=True, tp_size=self.test_tp_size
        )
        cmd.append(str(self.test_file_path))
        with patch_environment(omp_num_threads=1):
            execute_subprocess_async(cmd)

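For reference, the subprocess launched here should correspond to a CLI invocation along these lines; the exact flag spelling is inferred from the keyword arguments above and is an assumption, not something verified against the launcher:

# Hypothetical equivalent of the command built by get_launch_command above (2 GPUs):
#   accelerate launch --num_processes 2 --num_machines 1 --machine_rank 0 \
#       --use_tp --tp_size 2 <accelerate package>/test_utils/scripts/external_deps/test_performance.py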