feat: add dataloader for TP and n-dim parallel in non-dispatch mode
Signed-off-by: Mehant Kammakomati <[email protected]>
kmehant committed Dec 13, 2024
1 parent 8ade23c commit 0991429
Showing 6 changed files with 117 additions and 3 deletions.
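For orientation, here is a rough usage sketch of the API this commit introduces (not part of the diff itself): a `TorchTensorParallelPlugin` is handed to `Accelerator`, and `prepare()` then routes the model through the new TP branch and the dataloader through the device-mesh-aware preparation. The checkpoint name and the `tp_size=4` launch are illustrative assumptions; any model whose `supports_tp_plan` is `True` should fit.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from transformers import AutoModelForCausalLM

from accelerate import Accelerator
from accelerate.utils import TorchTensorParallelPlugin

# hypothetical checkpoint; the model must expose a tensor-parallel plan
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B")

# toy dataset standing in for real training data
dataset = TensorDataset(torch.randint(0, 1000, (64, 128)))
dataloader = DataLoader(dataset, batch_size=8)

# tp_size must match the number of processes in the TP group, e.g. torchrun --nproc_per_node 4
accelerator = Accelerator(torch_tp_plugin=TorchTensorParallelPlugin(tp_size=4))
model, dataloader = accelerator.prepare(model, dataloader)
```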
11 changes: 11 additions & 0 deletions src/accelerate/accelerator.py
@@ -67,6 +67,7 @@
ProjectConfiguration,
RNGType,
TorchDynamoPlugin,
TorchTensorParallelPlugin,
apply_fp8_autowrap,
check_os_kernel,
clean_state_dict_for_safetensors,
@@ -188,6 +189,9 @@ class Accelerator:
fsdp_plugin ([`~utils.FullyShardedDataParallelPlugin`], *optional*):
Tweak your FSDP related args using this argument. This argument is optional and can be configured directly
using *accelerate config*
torch_tp_plugin ([`~utils.TorchTensorParallelPlugin`], *optional*):
Tweak your torch tensor parallel related args using this argument. This argument is optional and can be
configured directly using *accelerate config*
megatron_lm_plugin ([`~utils.MegatronLMPlugin`], *optional*):
Tweak your MegatronLM related args using this argument. This argument is optional and can be configured
directly using *accelerate config*
@@ -254,6 +258,7 @@ def __init__(
dataloader_config: DataLoaderConfiguration | None = None,
deepspeed_plugin: DeepSpeedPlugin | dict[str, DeepSpeedPlugin] | None = None,
fsdp_plugin: FullyShardedDataParallelPlugin | None = None,
torch_tp_plugin: TorchTensorParallelPlugin | None = None,
megatron_lm_plugin: MegatronLMPlugin | None = None,
rng_types: list[str | RNGType] | None = None,
log_with: str | LoggerType | GeneralTracker | list[str | LoggerType | GeneralTracker] | None = None,
@@ -418,6 +423,7 @@ def __init__(
dynamo_plugin=dynamo_plugin,
deepspeed_plugin=deepspeed_plugins,
fsdp_plugin=fsdp_plugin,
torch_tp_plugin=torch_tp_plugin,
megatron_lm_plugin=megatron_lm_plugin,
_from_accelerator=True,
**kwargs,
@@ -1461,6 +1467,10 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
)
if self.ddp_handler is not None:
self.ddp_handler.register_comm_hook(model)
elif self.distributed_type == DistributedType.TP:
if not model.supports_tp_plan:
raise NotImplementedError("Provided model does not support tensor parallelism")
model.tensor_parallel(self.state.torch_tp_plugin.torch_device_mesh["tp"])
elif self.distributed_type == DistributedType.FSDP:
# We need to fix the optimizer *before* sharding the model
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
@@ -2117,6 +2127,7 @@ def prepare_data_loader(
data_seed=self.dataloader_config.data_seed,
non_blocking=self.non_blocking,
use_stateful_dataloader=self.use_stateful_dataloader,
torch_device_mesh=self.state.torch_tp_plugin.torch_device_mesh if self.state.torch_tp_plugin else None,
)
self._dataloaders.append(prepared_data_loader)
return prepared_data_loader
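The data_loader.py changes that follow slice a named `torch.distributed.DeviceMesh` into `dp`/`fsdp`/`tp` submeshes. As a rough illustration of that PyTorch API (not part of the commit), under a hypothetical 12-process `torchrun` launch:

```python
from torch.distributed.device_mesh import init_device_mesh

# assumes 12 processes, each bound to a CUDA device
mesh = init_device_mesh("cuda", (2, 2, 3), mesh_dim_names=("dp", "fsdp", "tp"))

print(mesh.mesh_dim_names)  # ("dp", "fsdp", "tp")
print(mesh["tp"].size())    # 3 -> submesh_tp_size in prepare_data_loader below
print(mesh["dp"].size())    # 2 -> submesh_dp_size
print(mesh["fsdp"].size())  # 2 -> submesh_fsdp_size
```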
73 changes: 71 additions & 2 deletions src/accelerate/data_loader.py
@@ -528,6 +528,7 @@ def __init__(
use_stateful_dataloader=False,
_drop_last: bool = False,
_non_blocking: bool = False,
torch_device_mesh=None,
**kwargs,
):
super().__init__(dataset, use_stateful_dataloader=use_stateful_dataloader, **kwargs)
@@ -713,6 +714,7 @@ def __init__(
_drop_last: bool = False,
_non_blocking: bool = False,
slice_fn=None,
torch_device_mesh=None,
**kwargs,
):
shuffle = False
@@ -732,26 +734,66 @@
self._drop_last = _drop_last
self._non_blocking = _non_blocking
self.skip_batches = skip_batches
self.torch_device_mesh = torch_device_mesh

self.slice_fn = slice_tensors if slice_fn is None else slice_fn
self.iteration = 0

# if a device mesh is provided, extract each dimension (dp, fsdp, tp)
# the device mesh may hold any number of dimensions, but the code
# below provides targeted support only for dp, fsdp and tp

# the device mesh is used only when tp is involved, either alone
# or as part of multi-dimensional parallelism that includes tp:
# (dp, tp), (fsdp, tp), (dp, fsdp, tp)
# otherwise the default behaviour without a device mesh is sufficient,
# since multi-dimensional parallelism without tp would anyway need
# a different batch for each process, irrespective of dp or fsdp
self.submesh_tp = None
self.submesh_dp = None
self.submesh_fsdp = None
if self.torch_device_mesh and "tp" in self.torch_device_mesh.mesh_dim_names:
self.submesh_tp = self.torch_device_mesh["tp"]
if "dp" in self.torch_device_mesh.mesh_dim_names:
self.submesh_dp = self.torch_device_mesh["dp"]
if "fsdp" in self.torch_device_mesh.mesh_dim_names:
self.submesh_fsdp = self.torch_device_mesh["fsdp"]
if self.submesh_tp and (self.submesh_dp or self.submesh_fsdp):
raise ValueError("TP + (DP/FSDP) is not yet supported in dispatch mode")

def _fetch_batches(self, iterator):
batches, batch = None, None
# On process 0, we gather the batch to dispatch.
if self.state.process_index == 0:
# The TP-only procedure is simpler, since the same batch of samples
# is dispatched to all ranks; this avoids the complexity of handling
# multiple tp rank groups when a TP + DP combination is involved.

try:
# in the TP case, avoid using split_batches,
# since it would require the dataloader to
# produce duplicate batches.
if self.split_batches:
# One batch of the main iterator is dispatched and split.
if self.submesh_tp:
logger.warning("Use of split_batches for TP would need the dataloader to produce duplicate batches, "
"otherwise, use dispatch_batches=True instead.")
self._update_state_dict()
batch = next(iterator)
else:
# num_processes batches of the main iterator are concatenated then dispatched and split.
# We add the batches one by one so we have the remainder available when drop_last=False.
batches = []
if self.submesh_tp:
# when tp, extract a single batch and then replicate it
self._update_state_dict()
batch = next(iterator)
batches = [batch] * self.state.num_processes
else:
for _ in range(self.state.num_processes):
self._update_state_dict()
batches.append(next(iterator))
try:
batch = concatenate(batches, dim=0)
except RuntimeError as e:
@@ -942,6 +984,7 @@ def prepare_data_loader(
data_seed: Optional[int] = None,
non_blocking: bool = False,
use_stateful_dataloader: bool = False,
torch_device_mesh: torch.distributed.DeviceMesh = None,
) -> DataLoader:
"""
Wraps a PyTorch `DataLoader` to generate batches for one of the processes only.
@@ -1009,6 +1052,8 @@
"If set to true, the dataloader prepared by the Accelerator will be backed by "
"[torchdata.StatefulDataLoader](https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader).
This requires `torchdata` version 0.8.0 or higher that supports StatefulDataLoader to be installed."
torch_device_mesh (`torch.distributed.DeviceMesh`, *optional*, defaults to `None`):
PyTorch device mesh. When a mesh containing a `tp` dimension is passed, the same batch is served to all tp
ranks, while different batches go to different dp/fsdp ranks.
Returns:
@@ -1033,8 +1078,31 @@
state = PartialState()
if num_processes is None:
num_processes = state.num_processes

# when a device mesh is used, specifically one involving TP,
# process_index and num_processes need to be updated so that
# the same batch is generated across TP ranks and a different
# batch across FSDP and DP ranks.
# Example:
# if the device mesh is (dp, fsdp, tp) = (2, 2, 3)
# ranks range from 0...11
# from the data angle, ranks should look like 0 0 0 1 1 1 2 2 2 3 3 3
# processes with the same data rank/id receive the same batch
submesh_fsdp_size = 1
submesh_dp_size = 1
submesh_tp_size = 1
if torch_device_mesh:
if "tp" in torch_device_mesh.mesh_dim_names:
submesh_tp_size = torch_device_mesh["tp"].size()
if "dp" in torch_device_mesh.mesh_dim_names:
submesh_dp_size = torch_device_mesh["dp"].size()
if "fsdp" in torch_device_mesh.mesh_dim_names:
submesh_fsdp_size = torch_device_mesh["fsdp"].size()
num_processes = (submesh_fsdp_size * submesh_dp_size)
if process_index is None:
process_index = state.process_index
if torch_device_mesh:
process_index = process_index // submesh_tp_size

# Sanity check
if split_batches:
@@ -1144,6 +1212,7 @@ def prepare_data_loader(
_non_blocking=non_blocking,
slice_fn=slice_fn_for_dispatch,
use_stateful_dataloader=use_stateful_dataloader,
torch_device_mesh=torch_device_mesh,
**kwargs,
)
elif sampler_is_batch_sampler:
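To make the rank arithmetic in `prepare_data_loader` above concrete, here is a small standalone check of the `(dp, fsdp, tp) = (2, 2, 3)` example from the comment (plain Python, no distributed launch needed):

```python
submesh_dp_size, submesh_fsdp_size, submesh_tp_size = 2, 2, 3

# only dp x fsdp distinct batches are needed per step; tp ranks within a group share a batch
num_processes = submesh_fsdp_size * submesh_dp_size
data_ranks = [rank // submesh_tp_size for rank in range(12)]

print(num_processes)  # 4
print(data_ranks)     # [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]
```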
4 changes: 4 additions & 0 deletions src/accelerate/state.py
@@ -850,6 +850,7 @@ def __init__(
dynamo_plugin=None,
deepspeed_plugin=None,
fsdp_plugin=None,
torch_tp_plugin=None,
megatron_lm_plugin=None,
_from_accelerator: bool = False,
**kwargs,
@@ -864,6 +865,7 @@
if not self.initialized:
self.deepspeed_plugins = None
self.use_ipex = None
self.torch_tp_plugin = torch_tp_plugin
mixed_precision = (
parse_choice_from_env("ACCELERATE_MIXED_PRECISION", "no")
if mixed_precision is None
@@ -921,6 +923,8 @@ def __init__(
self.distributed_type = DistributedType.MEGATRON_LM
megatron_lm_plugin.set_mixed_precision(self._mixed_precision)
self.megatron_lm_plugin = megatron_lm_plugin
if os.environ.get("ACCELERATE_USE_TP", "false") == "true" or self.torch_tp_plugin is not None:
self.distributed_type = DistributedType.TP
elif self.distributed_type in [DistributedType.MULTI_CPU, DistributedType.MULTI_XPU, DistributedType.NO]:
if is_ipex_available():
# check if user disables it explicitly
1 change: 1 addition & 0 deletions src/accelerate/utils/__init__.py
@@ -57,6 +57,7 @@
SageMakerDistributedType,
TensorInformation,
TorchDynamoPlugin,
TorchTensorParallelPlugin,
add_model_config_to_megatron_parser,
)
from .environment import (
3 changes: 2 additions & 1 deletion src/accelerate/utils/constants.py
@@ -46,6 +46,7 @@
ELASTIC_LOG_LINE_PREFIX_TEMPLATE_PYTORCH_VERSION = "2.2.0"
XPU_PROFILING_AVAILABLE_PYTORCH_VERSION = "2.4.0"
MITA_PROFILING_AVAILABLE_PYTORCH_VERSION = "2.1.0"
BETA_TP_AVAILABLE_PYTORCH_VERSION = "2.3.0"

STR_OPERATION_TO_FUNC = {">": op.gt, ">=": op.ge, "==": op.eq, "!=": op.ne, "<=": op.le, "<": op.lt}

@@ -76,7 +77,7 @@
"master_port",
]

CUDA_DISTRIBUTED_TYPES = ["DEEPSPEED", "MULTI_GPU", "FSDP", "MEGATRON_LM", "TP"]
TORCH_DISTRIBUTED_OPERATION_TYPES = CUDA_DISTRIBUTED_TYPES + [
"MULTI_NPU",
"MULTI_MLU",
28 changes: 28 additions & 0 deletions src/accelerate/utils/dataclasses.py
@@ -30,6 +30,7 @@
import torch

from .constants import (
BETA_TP_AVAILABLE_PYTORCH_VERSION,
FSDP_AUTO_WRAP_POLICY,
FSDP_BACKWARD_PREFETCH,
FSDP_SHARDING_STRATEGY,
@@ -540,6 +541,7 @@ class DistributedType(str, enum.Enum):
MULTI_XPU = "MULTI_XPU"
DEEPSPEED = "DEEPSPEED"
FSDP = "FSDP"
TP = "TP"
XLA = "XLA"
MEGATRON_LM = "MEGATRON_LM"

@@ -1810,6 +1812,32 @@ def set_mixed_precision(self, mixed_precision, buffer_autocast=False, override=F
self.mixed_precision_policy = MixedPrecision(**self.mixed_precision_policy)


@dataclass
class TorchTensorParallelPlugin:
"""
This plugin is used to enable tensor parallelism using PyTorch >= 2.3.0.
"""

tp_size: int = field(
default=1,
metadata={"help": "tensor parallel size will be used in the device mesh preparation"},
)

# type has to be "torch.distributed.DeviceMesh"
torch_device_mesh: torch.distributed.DeviceMesh = field(default=None)

def __post_init__(self):
if is_torch_version("<", BETA_TP_AVAILABLE_PYTORCH_VERSION):
raise ValueError(
f"Minimum PyTorch version {BETA_TP_AVAILABLE_PYTORCH_VERSION} is required to use tensor parallelism."
)
from torch.distributed.device_mesh import init_device_mesh

mesh_dim_name = "tp"
device = "cuda" # support for other devices has to be investigated
self.torch_device_mesh = init_device_mesh(device, (self.tp_size,), mesh_dim_names=(mesh_dim_name,))


@dataclass
class MegatronLMPlugin:
"""
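As a quick sketch of what `TorchTensorParallelPlugin.__post_init__` leaves behind (assuming PyTorch >= 2.3.0 and a 4-process CUDA launch):

```python
from accelerate.utils import TorchTensorParallelPlugin

plugin = TorchTensorParallelPlugin(tp_size=4)

# __post_init__ builds a 1-D device mesh on CUDA with a single "tp" dimension
assert plugin.torch_device_mesh.mesh_dim_names == ("tp",)
assert plugin.torch_device_mesh["tp"].size() == 4
```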
