From da67cba7751b64b02a3511bbb099bd684b1169d1 Mon Sep 17 00:00:00 2001
From: Mehant Kammakomati
Date: Tue, 15 Oct 2024 10:37:57 +0530
Subject: [PATCH] feat: support tensor parallel using PyTorch 2.0

Signed-off-by: Mehant Kammakomati
---
 src/accelerate/accelerator.py       |  9 +++++++
 src/accelerate/data_loader.py       | 37 +++++++++++++++++++++++++++--
 src/accelerate/state.py             |  4 ++++
 src/accelerate/utils/__init__.py    |  1 +
 src/accelerate/utils/constants.py   |  2 +-
 src/accelerate/utils/dataclasses.py | 23 ++++++++++++++++++
 6 files changed, 73 insertions(+), 3 deletions(-)

diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py
index ab949c42e43..7d9953151b0 100755
--- a/src/accelerate/accelerator.py
+++ b/src/accelerate/accelerator.py
@@ -67,6 +67,7 @@
     ProjectConfiguration,
     RNGType,
     TorchDynamoPlugin,
+    TorchTensorParallelPlugin,
     apply_fp8_autowrap,
     check_os_kernel,
     clean_state_dict_for_safetensors,
@@ -188,6 +189,9 @@ class Accelerator:
         fsdp_plugin ([`~utils.FullyShardedDataParallelPlugin`], *optional*):
             Tweak your FSDP related args using this argument. This argument is optional and can be configured directly
             using *accelerate config*
+        torch_tp_plugin ([`~utils.TorchTensorParallelPlugin`], *optional*):
+            Tweak your torch tensor parallel related args using this argument. This argument is optional and can be
+            configured directly using *accelerate config*
         megatron_lm_plugin ([`~utils.MegatronLMPlugin`], *optional*):
             Tweak your MegatronLM related args using this argument. This argument is optional and can be configured
             directly using *accelerate config*
@@ -254,6 +258,7 @@ def __init__(
         dataloader_config: DataLoaderConfiguration | None = None,
         deepspeed_plugin: DeepSpeedPlugin | dict[str, DeepSpeedPlugin] | None = None,
         fsdp_plugin: FullyShardedDataParallelPlugin | None = None,
+        torch_tp_plugin: TorchTensorParallelPlugin | None = None,
         megatron_lm_plugin: MegatronLMPlugin | None = None,
         rng_types: list[str | RNGType] | None = None,
         log_with: str | LoggerType | GeneralTracker | list[str | LoggerType | GeneralTracker] | None = None,
@@ -418,6 +423,7 @@ def __init__(
             dynamo_plugin=dynamo_plugin,
             deepspeed_plugin=deepspeed_plugins,
             fsdp_plugin=fsdp_plugin,
+            torch_tp_plugin=torch_tp_plugin,
             megatron_lm_plugin=megatron_lm_plugin,
             _from_accelerator=True,
             **kwargs,
         )
@@ -1461,6 +1467,8 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
                 )
                 if self.ddp_handler is not None:
                     self.ddp_handler.register_comm_hook(model)
+        elif self.distributed_type == DistributedType.TP:
+            model.tensor_parallel(self.state.torch_tp_plugin.torch_device_mesh["tp"])
         elif self.distributed_type == DistributedType.FSDP:
             # We need to fix the optimizer *before* sharding the model
             from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
@@ -2117,6 +2125,7 @@ def prepare_data_loader(
             data_seed=self.dataloader_config.data_seed,
             non_blocking=self.non_blocking,
             use_stateful_dataloader=self.use_stateful_dataloader,
+            torch_device_mesh=self.state.torch_tp_plugin.torch_device_mesh if self.state.torch_tp_plugin else None,
         )
         self._dataloaders.append(prepared_data_loader)
         return prepared_data_loader
diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py
index bf3f35fb7e8..851a796885c 100644
--- a/src/accelerate/data_loader.py
+++ b/src/accelerate/data_loader.py
@@ -713,6 +713,7 @@ def __init__(
         _drop_last: bool = False,
         _non_blocking: bool = False,
         slice_fn=None,
+        torch_device_mesh=None,
         **kwargs,
     ):
         shuffle = False
@@ -732,15 +733,37 @@ def __init__(
         self._drop_last = _drop_last
         self._non_blocking = _non_blocking
         self.skip_batches = skip_batches
+        self.torch_device_mesh = torch_device_mesh
         self.slice_fn = slice_tensors if slice_fn is None else slice_fn
         self.iteration = 0
+
+        # if a device mesh is provided extract each dimension (tp and dp)
+        # device mesh will be used only if there is tp involved
+        # otherwise the default behavior should be sufficient
+        self.submesh_tp = None
+        self.submesh_dp = None
+        if self.torch_device_mesh and "tp" in self.torch_device_mesh.mesh_dim_names:
+            # extract torch sub device mesh objects
+            self.submesh_tp = self.torch_device_mesh["tp"]
+            if "dp" in self.torch_device_mesh.mesh_dim_names:
+                self.submesh_dp = self.torch_device_mesh["dp"]
+        if self.submesh_tp and self.submesh_dp:
+            raise ValueError("TP + DDP / TP + FSDP is not yet supported")
+

     def _fetch_batches(self, iterator):
         batches, batch = None, None
         # On process 0, we gather the batch to dispatch.
         if self.state.process_index == 0:
+            # Procedure to support TP only is simpler
+            # since we want to dispatch the same batch of samples across all ranks
+            # this removes complexity of handling multiple tp rank groups when TP + DP
+            # combination is involved.
             try:
+                # for TP case avoid using split_batches
+                # since it would mean that the dataloader should be spilling out
+                # duplicates of batches.
                 if self.split_batches:
                     # One batch of the main iterator is dispatched and split.
                     self._update_state_dict()
                     batch = next(iterator)
@@ -749,9 +772,15 @@ def _fetch_batches(self, iterator):
                     # num_processes batches of the main iterator are concatenated then dispatched and split.
                     # We add the batches one by one so we have the remainder available when drop_last=False.
                     batches = []
-                    for _ in range(self.state.num_processes):
+                    if self.submesh_tp:
+                        # when tp, extract single batch and then replicate
                         self._update_state_dict()
-                        batches.append(next(iterator))
+                        batch = next(iterator)
+                        batches = [batch] * self.state.num_processes
+                    else:
+                        for _ in range(self.state.num_processes):
+                            self._update_state_dict()
+                            batches.append(next(iterator))
                     try:
                         batch = concatenate(batches, dim=0)
                     except RuntimeError as e:
@@ -942,6 +971,7 @@ def prepare_data_loader(
     data_seed: Optional[int] = None,
     non_blocking: bool = False,
     use_stateful_dataloader: bool = False,
+    torch_device_mesh: torch.distributed._tensor.DeviceMesh = None,
 ) -> DataLoader:
     """
     Wraps a PyTorch `DataLoader` to generate batches for one of the processes only.
@@ -1009,6 +1039,8 @@ def prepare_data_loader(
             "If set to true, the dataloader prepared by the Accelerator will be backed by "
             "[torchdata.StatefulDataLoader](https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader).
             This requires `torchdata` version 0.8.0 or higher that supports StatefulDataLoader to be installed."
+        torch_device_mesh (`torch.distributed._tensor.DeviceMesh`, *optional*, defaults to `None`):
+            PyTorch device mesh; when it has a "tp" dimension, the same batch is dispatched to all tensor parallel ranks.

     Returns:
@@ -1144,6 +1176,7 @@ def prepare_data_loader(
             _non_blocking=non_blocking,
             slice_fn=slice_fn_for_dispatch,
             use_stateful_dataloader=use_stateful_dataloader,
+            torch_device_mesh=torch_device_mesh,
             **kwargs,
         )
     elif sampler_is_batch_sampler:
diff --git a/src/accelerate/state.py b/src/accelerate/state.py
index 47d718704a6..8d226ac305b 100644
--- a/src/accelerate/state.py
+++ b/src/accelerate/state.py
@@ -850,6 +850,7 @@ def __init__(
         dynamo_plugin=None,
         deepspeed_plugin=None,
         fsdp_plugin=None,
+        torch_tp_plugin=None,
         megatron_lm_plugin=None,
         _from_accelerator: bool = False,
         **kwargs,
@@ -864,6 +865,7 @@ def __init__(
         if not self.initialized:
             self.deepspeed_plugins = None
             self.use_ipex = None
+            self.torch_tp_plugin = torch_tp_plugin
             mixed_precision = (
                 parse_choice_from_env("ACCELERATE_MIXED_PRECISION", "no")
                 if mixed_precision is None
@@ -921,6 +923,8 @@ def __init__(
                     self.distributed_type = DistributedType.MEGATRON_LM
                     megatron_lm_plugin.set_mixed_precision(self._mixed_precision)
                     self.megatron_lm_plugin = megatron_lm_plugin
+                if os.environ.get("ACCELERATE_USE_TP", "false") == "true" or self.torch_tp_plugin is not None:
+                    self.distributed_type = DistributedType.TP
             elif self.distributed_type in [DistributedType.MULTI_CPU, DistributedType.MULTI_XPU, DistributedType.NO]:
                 if is_ipex_available():
                     # check if user disables it explicitly
diff --git a/src/accelerate/utils/__init__.py b/src/accelerate/utils/__init__.py
index 5b8917fcd48..558cc2f4769 100644
--- a/src/accelerate/utils/__init__.py
+++ b/src/accelerate/utils/__init__.py
@@ -57,6 +57,7 @@
     SageMakerDistributedType,
     TensorInformation,
     TorchDynamoPlugin,
+    TorchTensorParallelPlugin,
     add_model_config_to_megatron_parser,
 )
 from .environment import (
diff --git a/src/accelerate/utils/constants.py b/src/accelerate/utils/constants.py
index a6d7d262678..f0669797bbe 100644
--- a/src/accelerate/utils/constants.py
+++ b/src/accelerate/utils/constants.py
@@ -76,7 +76,7 @@
     "master_port",
 ]

-CUDA_DISTRIBUTED_TYPES = ["DEEPSPEED", "MULTI_GPU", "FSDP", "MEGATRON_LM"]
+CUDA_DISTRIBUTED_TYPES = ["DEEPSPEED", "MULTI_GPU", "FSDP", "MEGATRON_LM", "TP"]
 TORCH_DISTRIBUTED_OPERATION_TYPES = CUDA_DISTRIBUTED_TYPES + [
     "MULTI_NPU",
     "MULTI_MLU",
diff --git a/src/accelerate/utils/dataclasses.py b/src/accelerate/utils/dataclasses.py
index 39e048a6039..165016835dc 100644
--- a/src/accelerate/utils/dataclasses.py
+++ b/src/accelerate/utils/dataclasses.py
@@ -540,6 +540,7 @@ class DistributedType(str, enum.Enum):
     MULTI_XPU = "MULTI_XPU"
     DEEPSPEED = "DEEPSPEED"
     FSDP = "FSDP"
+    TP = "TP"
     XLA = "XLA"
     MEGATRON_LM = "MEGATRON_LM"
@@ -1810,6 +1811,28 @@ def set_mixed_precision(self, mixed_precision, buffer_autocast=False, override=F
         self.mixed_precision_policy = MixedPrecision(**self.mixed_precision_policy)


+@dataclass
+class TorchTensorParallelPlugin:
+    """
+    This plugin is used to enable tensor parallelism using PyTorch >= 2.0.
+    """
+
+    tp_size: int = field(
+        default=1,
+        metadata={"help": "tensor parallel size will be used in the device mesh preparation"},
+    )
+
+    # type has to be "torch.distributed.DeviceMesh"
+    torch_device_mesh: torch.distributed.DeviceMesh = field(default=None)
+
+    def __post_init__(self):
+        from torch.distributed.device_mesh import init_device_mesh
+
+        mesh_dim_name = "tp"
+        device = "cuda"  # support for other devices has to be investigated
+        self.torch_device_mesh = init_device_mesh(device, (self.tp_size,), mesh_dim_names=(mesh_dim_name,))
+
+
 @dataclass
 class MegatronLMPlugin:
     """
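
The following is a minimal usage sketch of the API introduced by this patch; it is not part of the patch itself. It assumes two CUDA devices and a distributed launch such as `torchrun --nproc-per-node 2 tp_usage_sketch.py`. The file name, the `ToyModel` class, and its column-/row-wise sharding plan are illustrative stand-ins for any model that implements the `tensor_parallel(device_mesh)` method that `Accelerator.prepare_model` calls when the distributed type is `DistributedType.TP`; only `TorchTensorParallelPlugin`, the `torch_tp_plugin=` argument, and the `tensor_parallel` call come from the patch.

# tp_usage_sketch.py -- illustrative sketch only, not part of the patch.
# Launch with e.g.: torchrun --nproc-per-node 2 tp_usage_sketch.py
import torch
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, parallelize_module
from torch.utils.data import DataLoader, TensorDataset

from accelerate import Accelerator
from accelerate.utils import DataLoaderConfiguration, TorchTensorParallelPlugin


class ToyModel(torch.nn.Module):
    # Stand-in for a model that implements `tensor_parallel(device_mesh)`,
    # which Accelerator.prepare_model() invokes for DistributedType.TP.
    def __init__(self):
        super().__init__()
        self.up = torch.nn.Linear(16, 64, bias=False)
        self.down = torch.nn.Linear(64, 2, bias=False)

    def forward(self, x):
        return self.down(torch.relu(self.up(x)))

    def tensor_parallel(self, tp_mesh):
        # Shard the two linears column-/row-wise over the "tp" mesh dimension.
        parallelize_module(self, tp_mesh, {"up": ColwiseParallel(), "down": RowwiseParallel()})


# __post_init__ builds a 1-D CUDA device mesh of size tp_size with mesh_dim_names=("tp",).
tp_plugin = TorchTensorParallelPlugin(tp_size=2)

# dispatch_batches=True routes the dataloader through the DataLoaderDispatcher path
# touched by this patch: process 0 reads a single batch and the same batch is sent
# to every tensor parallel rank.
accelerator = Accelerator(
    torch_tp_plugin=tp_plugin,
    dataloader_config=DataLoaderConfiguration(dispatch_batches=True),
)

model = ToyModel()
dataset = TensorDataset(torch.randn(64, 16), torch.randint(0, 2, (64,)))
dataloader = DataLoader(dataset, batch_size=8)

# prepare() detects DistributedType.TP and calls
# model.tensor_parallel(torch_tp_plugin.torch_device_mesh["tp"]).
model, dataloader = accelerator.prepare(model, dataloader)

for inputs, labels in dataloader:
    outputs = model(inputs)
    print(accelerator.process_index, outputs.shape)

The `dispatch_batches=True` configuration is only there so the sketch exercises the `DataLoaderDispatcher` changes above; with the default configuration a map-style dataset would go through the regular sharded dataloader path instead.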