feat: add tests for cli usage of TP and plugin
Signed-off-by: Mehant Kammakomati <[email protected]>
kmehant committed Dec 13, 2024
1 parent 780ae7b commit 74bf3e2
Showing 6 changed files with 108 additions and 13 deletions.
12 changes: 8 additions & 4 deletions src/accelerate/accelerator.py
@@ -108,7 +108,7 @@
save_fsdp_optimizer,
wait_for_everyone,
)
from .utils.constants import FSDP_PYTORCH_VERSION, PROFILE_PATTERN_NAME, BETA_TP_AVAILABLE_PYTORCH_VERSION
from .utils.constants import BETA_TP_AVAILABLE_PYTORCH_VERSION, FSDP_PYTORCH_VERSION, PROFILE_PATTERN_NAME
from .utils.modeling import get_state_dict_offloaded_model
from .utils.other import is_compiled_module

@@ -349,7 +349,9 @@ def __init__(
if not is_torch_version(">=", FSDP_PYTORCH_VERSION):
raise ValueError(f"FSDP requires PyTorch >= {FSDP_PYTORCH_VERSION}")

if os.environ.get("ACCELERATE_USE_TP", "false") == "true" or isinstance(torch_tp_plugin, TorchTensorParallelPlugin):
if os.environ.get("ACCELERATE_USE_TP", "false") == "true" or isinstance(
torch_tp_plugin, TorchTensorParallelPlugin
):
if not is_torch_version(">=", BETA_TP_AVAILABLE_PYTORCH_VERSION):
raise ValueError(f"TP requires PyTorch >= {BETA_TP_AVAILABLE_PYTORCH_VERSION}")

@@ -363,12 +365,14 @@ def __init__
os.environ["ACCELERATE_USE_FSDP"] = "true" # use FSDP if plugin is provided

if torch_tp_plugin is None:
torch_tp_plugin = (TorchTensorParallelPlugin() if os.environ.get("ACCELERATE_USE_TP", "false") == "true" else None)
torch_tp_plugin = (
TorchTensorParallelPlugin() if os.environ.get("ACCELERATE_USE_TP", "false") == "true" else None
)
else:
if not isinstance(torch_tp_plugin, TorchTensorParallelPlugin):
raise TypeError("`torch_tp_plugin` must be a TorchTensorParallelPlugin object.")
os.environ["ACCELERATE_USE_TP"] = "true"

if megatron_lm_plugin is None: # init from env variables
megatron_lm_plugin = (
MegatronLMPlugin() if os.environ.get("ACCELERATE_USE_MEGATRON_LM", "false") == "true" else None
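A minimal sketch of how the plugin path above might be exercised, assuming torch >= 2.3 and a multi-process distributed launch; the tp_size keyword is inferred from the TP_SIZE handling in dataclasses.py further down, not confirmed by this hunk:

# Sketch only: needs a distributed launch (e.g. 2 processes) and torch >= 2.3.
from accelerate import Accelerator
from accelerate.utils.dataclasses import TorchTensorParallelPlugin

tp_plugin = TorchTensorParallelPlugin(tp_size=2)  # tp_size must be > 1
accelerator = Accelerator(torch_tp_plugin=tp_plugin)

# Equivalent environment-driven path, with no plugin object passed explicitly:
#   ACCELERATE_USE_TP=true TP_SIZE=2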
2 changes: 1 addition & 1 deletion src/accelerate/commands/launch.py
@@ -595,7 +595,7 @@ def launch_command_parser(subparsers=None):
type=str,
help="Decides Whether (true|false) intermediate activations are freed during the forward pass, and a checkpoint is left as a placeholder. (useful only when `use_fsdp` flag is passed).",
)

# tp args
tp_args = parser.add_argument_group("TP Arguments", "Arguments related to Tensor Parallelism using PyTorch.")
tp_args.add_argument(
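The new flags can be exercised through the test helpers used later in this commit; the exact shape of the returned command list is approximate:

# Sketch: build a TP launch command via the test utilities
# (mirrors the integration test in tests/tp/test_tp.py below).
from accelerate.test_utils.testing import get_launch_command

# Roughly: accelerate launch --num_processes 2 --use_tp --tp_size 2 ...
cmd = get_launch_command(num_processes=2, num_machines=1, machine_rank=0, use_tp=True, tp_size=2)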
18 changes: 10 additions & 8 deletions src/accelerate/data_loader.py
@@ -740,11 +740,11 @@ def __init__(
self.iteration = 0

# if a device mesh is provided extract each dimension (dp, fsdp, tp)
# device mesh may hold any number of dimensions, however,
# device mesh may hold any number of dimensions, however,
# below code is for targeted support for dp, fsdp and tp
# device mesh will be used only if there is tp involved
# or any multi-dimensional parallelism involving tp

# device mesh will be used only if there is tp involved
# or any multi-dimensional parallelism involving tp
# (dp, tp) (fsdp, tp) (dp, fsdp, tp)
# otherwise the default behaviour not using device mesh should be sufficient
# since multi dimensional parallelism devoid of tp would anyway need
@@ -777,8 +777,10 @@ def _fetch_batches(self, iterator):
if self.split_batches:
# One batch of the main iterator is dispatched and split.
if self.submesh_tp:
logger.warning("Use of split_batches for TP would need the dataloader to produce duplicate batches,"
"otherwise, use dispatch_batches=True instead.")
logger.warning(
"Use of split_batches for TP would need the dataloader to produce duplicate batches,"
"otherwise, use dispatch_batches=True instead."
)
self._update_state_dict()
batch = next(iterator)
else:
@@ -1078,7 +1080,7 @@ def prepare_data_loader(
state = PartialState()
if num_processes is None:
num_processes = state.num_processes

# when device mesh is used, specifically with TP
# then there is need to update process_index and num_processes
# to bring in the effect of generating same batch across TP ranks
@@ -1098,7 +1100,7 @@
submesh_dp_size = torch_device_mesh["dp"].size()
if "fsdp" in torch_device_mesh.mesh_dim_names:
submesh_fsdp_size = torch_device_mesh["fsdp"].size()
num_processes = (submesh_fsdp_size * submesh_dp_size)
num_processes = submesh_fsdp_size * submesh_dp_size
if process_index is None:
process_index = state.process_index
if torch_device_mesh:
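To make the rank remapping above concrete, a hedged sketch; the 2 x 2 x 2 mesh shape is only an example, and the code needs an actual 8-GPU distributed launch to run:

# Illustrative only: assumes torch >= 2.3 and 8 processes launched with torch.distributed.
from torch.distributed.device_mesh import init_device_mesh

mesh = init_device_mesh("cuda", (2, 2, 2), mesh_dim_names=("dp", "fsdp", "tp"))

# All TP ranks must see the same batch, so the dataloader is sharded only across
# the dp and fsdp dimensions: 2 * 2 = 4 effective processes instead of 8.
num_processes = mesh["dp"].size() * mesh["fsdp"].size()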
7 changes: 7 additions & 0 deletions src/accelerate/test_utils/testing.py
@@ -339,6 +339,13 @@ def require_fsdp(test_case):
return unittest.skipUnless(is_torch_version(">=", "1.12.0"), "test requires torch version >= 1.12.0")(test_case)


def require_tp(test_case):
"""
Decorator marking a test that requires tensor parallelism (TP) support. These tests are skipped when torch < 2.3.0 is installed
"""
return unittest.skipUnless(is_torch_version(">=", "2.3.0"), "test requires torch version >= 2.3.0")(test_case)


def require_torch_min_version(test_case=None, version=None):
"""
Decorator marking that a test requires a particular torch version to be tested. These tests are skipped when an
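A possible usage of the new decorator (hypothetical test, shown only to illustrate the skip behaviour):

import unittest

from accelerate.test_utils.testing import require_tp


@require_tp
class HypotheticalTPTest(unittest.TestCase):
    def test_placeholder(self):
        # Skipped automatically when torch < 2.3.0 is installed.
        assert True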
4 changes: 4 additions & 0 deletions src/accelerate/utils/dataclasses.py
@@ -1827,6 +1827,10 @@ class TorchTensorParallelPlugin:
torch_device_mesh: torch.distributed.DeviceMesh = field(default=None)

def __post_init__(self):
self.tp_size = self.tp_size if os.environ.get("TP_SIZE", "1") == "1" else int(os.environ.get("TP_SIZE", "1"))
if self.tp_size == 1:
raise ValueError("Provide TP degree > 1.")

if is_torch_version("<", BETA_TP_AVAILABLE_PYTORCH_VERSION):
raise ValueError(
f"Minimum PyTorch version {BETA_TP_AVAILABLE_PYTORCH_VERSION} needed to use tensor parallel."
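The validation added above can be seen in isolation; this sketch assumes the dataclass default is tp_size=1 (the field definition sits outside the shown hunk):

# Assumption: tp_size defaults to 1 when neither the argument nor TP_SIZE is set.
from accelerate.utils.dataclasses import TorchTensorParallelPlugin

try:
    TorchTensorParallelPlugin()
except ValueError as err:
    print(err)  # "Provide TP degree > 1."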
78 changes: 78 additions & 0 deletions tests/tp/test_tp.py
@@ -0,0 +1,78 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from transformers.testing_utils import mockenv_context
from transformers.trainer_utils import set_seed

from accelerate.test_utils.testing import (
AccelerateTestCase,
TempDirTestCase,
execute_subprocess_async,
get_launch_command,
path_in_accelerate_package,
require_multi_device,
require_non_cpu,
require_non_torch_xla,
require_tp,
slow,
)
from accelerate.utils import patch_environment
from accelerate.utils.dataclasses import TorchTensorParallelPlugin


set_seed(42)


@require_tp
@require_non_cpu
@require_non_torch_xla
class TPPluginIntegration(AccelerateTestCase):
def setUp(self):
super().setUp()

self.dist_env = dict(
MASTER_ADDR="localhost",
MASTER_PORT="10999",
RANK="0",
LOCAL_RANK="0",
WORLD_SIZE="1",
)

self.tp_env = dict(ACCELERATE_USE_TP="true", TP_SIZE="2", **self.dist_env)

def test_device_mesh_init(self):
with mockenv_context(**self.tp_env):
tp_plugin = TorchTensorParallelPlugin()
assert str(tp_plugin.torch_device_mesh["tp"].size()) == self.tp_env["TP_SIZE"]


@require_non_torch_xla
@require_tp
@require_multi_device
@slow
class TPIntegrationTest(TempDirTestCase):
test_scripts_folder = path_in_accelerate_package("test_utils", "scripts", "external_deps")

def setUp(self):
super().setUp()
self.test_tp_size = 2

def test_working_of_tp(self):
self.test_file_path = self.test_scripts_folder / "test_performance.py"
cmd = get_launch_command(
num_processes=self.test_tp_size, num_machines=1, machine_rank=0, use_tp=True, tp_size=self.test_tp_size
)
with patch_environment(omp_num_threads=1):
execute_subprocess_async(cmd)
