From 27d3d2c32319e75caa87b0a7860d64cd556cc26d Mon Sep 17 00:00:00 2001
From: Eitan Turok <150733043+eitanturok@users.noreply.github.com>
Date: Mon, 12 Aug 2024 17:03:52 -0400
Subject: [PATCH] remove weight parallelism (#137)

* remove weight parallelism

* fix linting

* remove parallel forward from mlp

* remove weight parallel

* cleanup
---
 megablocks/layers/arguments.py       |   2 -
 megablocks/layers/glu.py             |   3 -
 megablocks/layers/mlp.py             |  50 +---
 megablocks/layers/mpu.py             |   8 -
 megablocks/layers/weight_parallel.py | 416 ---------------------------
 tests/layers/parallelism_test.py     | 153 ----------
 6 files changed, 4 insertions(+), 628 deletions(-)
 delete mode 100644 megablocks/layers/weight_parallel.py
 delete mode 100644 tests/layers/parallelism_test.py

diff --git a/megablocks/layers/arguments.py b/megablocks/layers/arguments.py
index efe131d..ddbe2b7 100644
--- a/megablocks/layers/arguments.py
+++ b/megablocks/layers/arguments.py
@@ -40,8 +40,6 @@ class Arguments:
     # Parallelism arguments.
     moe_expert_model_parallelism: bool = False
     expert_parallel_group: Optional[torch.distributed.ProcessGroup] = None
-    moe_weight_parallelism: bool = False
-    weight_parallel_group: Optional[torch.distributed.ProcessGroup] = None
     pipeline_model_parallel_size: int = 1
     num_layers_per_virtual_pipeline_stage: Optional[int] = None
 
diff --git a/megablocks/layers/glu.py b/megablocks/layers/glu.py
index fa888a6..4654576 100644
--- a/megablocks/layers/glu.py
+++ b/megablocks/layers/glu.py
@@ -44,9 +44,6 @@ def __init__(self, args: Arguments):
             self._should_set_parallelism_attribute,
         )
 
-        if self.args.moe_weight_parallelism:
-            raise NotImplementedError('Weight parallelism not yet supported with GLU.',)
-
     def forward(self, x, topo):
         if self.args.memory_optimized_mlp:
             raise NotImplementedError(
diff --git a/megablocks/layers/mlp.py b/megablocks/layers/mlp.py
index 1cae4fb..f7cb782 100644
--- a/megablocks/layers/mlp.py
+++ b/megablocks/layers/mlp.py
@@ -9,7 +9,6 @@
 
 from megablocks import grouped_gemm_util as gg
 from megablocks.layers import common, gelu, mpu
-from megablocks.layers import weight_parallel as wp
 from megablocks.layers.activation_fn import act_fn
 from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn
 
@@ -180,21 +179,7 @@ def create_dmoe_expert_weights(
         columns,
         init_method,
     )
-    weights = weights.view([-1, columns])
-    rows, columns = weights.shape
-
-    if not args.moe_weight_parallelism:
-        return weights
-
-    # Caclculate the number of rows on this weight parallel partition.
-    # 'rows' must be divisible by weight parallel world size.
-    weight_parallel_world_size = mpu.get_weight_parallel_world_size(args)
-    assert (rows % weight_parallel_world_size) == 0
-    num_rows_per_rank = rows // weight_parallel_world_size
-    rank = mpu.get_weight_parallel_rank(args)
-    start_row = rank * num_rows_per_rank
-    end_row = (rank + 1) * num_rows_per_rank
-    return weights[start_row:end_row]
+    return weights.view([-1, columns])
 
 
 class MemoryOptimizedMLP(torch.autograd.Function):
@@ -323,8 +308,7 @@ class SparseMLP(torch.nn.Module):
     def __init__(self, args: Arguments):
         super().__init__()
         self.args = args
-        self._num_rows_per_rank = ((mpu.experts_per_rank(args) * mpu.features_per_rank(args)) //
-                                   mpu.get_weight_parallel_world_size(args))
+        self._num_rows_per_rank = mpu.experts_per_rank(args) * mpu.features_per_rank(args)
 
         self.w1 = torch.nn.Parameter(
             torch.empty(
@@ -371,7 +355,7 @@ def __init__(self, args: Arguments):
             ),
         )
 
-        self._should_set_parallelism_attribute = (args.moe_expert_model_parallelism or args.moe_weight_parallelism)
+        self._should_set_parallelism_attribute = args.moe_expert_model_parallelism
         mpu.set_expert_model_parallel_attributes(
             self.w1,
             self._should_set_parallelism_attribute,
@@ -390,33 +374,10 @@ def scale_grad(self, w):
             return w
         return scale_gradient(w, self.gradient_scale)
 
-    def parallel_forward(self, x, topo):
-        group = self.args.weight_parallel_group
-        w1, w2 = (self.scale_grad(self.w1), self.scale_grad(self.w2))
-        if self.args.memory_optimized_mlp:
-            if self.args.activation_fn is not DEFAULT_ACTIVATION_FN:
-                raise NotImplementedError(
-                    f'memory_optimized_weight_parallel_mlp not implemented for custom activation_fn={self.args.activation_fn}.',
-                )
-            return wp.memory_optimized_weight_parallel_mlp(
-                x,
-                w1,
-                w2,
-                topo,
-                group,
-            )
-
-        # Compute the MLP.
-        x = wp.sdd_nt(x, w1, topo, group)
-        activation_fn_out = act_fn(x, self.args.activation_fn)
-        return wp.dsd_nn(activation_fn_out, w2, group)
-
     def forward(self, x, topo):
         w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2)
         w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2)
-        if self.args.moe_weight_parallelism:
-            return self.parallel_forward(x, topo)
-        elif self.args.memory_optimized_mlp:
+        if self.args.memory_optimized_mlp:
             return memory_optimized_mlp(
                 x,
                 w1,
@@ -542,9 +503,6 @@ def forward(self, x, tokens_per_expert):
         w1 = resolve_dtensor(w1).view(ne, -1, self.args.hidden_size)
         w2 = resolve_dtensor(w2).view(ne, -1, self.args.hidden_size)
 
-        if self.args.moe_weight_parallelism:
-            raise NotImplementedError('Weight parallelism not yet supported with GroupedMLP.',)
-
         if self.args.memory_optimized_mlp:
             return memory_optimized_grouped_mlp(
                 x,
diff --git a/megablocks/layers/mpu.py b/megablocks/layers/mpu.py
index 6aa0015..239f75f 100644
--- a/megablocks/layers/mpu.py
+++ b/megablocks/layers/mpu.py
@@ -42,14 +42,6 @@ def copy_expert_model_parallel_attributes(
         )
 
 
-def get_weight_parallel_world_size(args: Arguments) -> int:
-    return (torch.distributed.get_world_size(args.weight_parallel_group) if args.moe_weight_parallelism else 1)
-
-
-def get_weight_parallel_rank(args: Arguments) -> int:
-    return (torch.distributed.get_rank(args.weight_parallel_group) if args.moe_weight_parallelism else 0)
-
-
 def synchronized_print(group, *x):
     world_size = torch.distributed.get_world_size(group)
     rank = torch.distributed.get_rank(group)
diff --git a/megablocks/layers/weight_parallel.py b/megablocks/layers/weight_parallel.py
deleted file mode 100644
index 82effec..0000000
--- a/megablocks/layers/weight_parallel.py
+++ /dev/null
@@ -1,416 +0,0 @@
-# Copyright 2024 Databricks
-# SPDX-License-Identifier: Apache-2.0
-
-import stk
-import torch
-
-from megablocks.layers import gelu
-
-
-def _gather_weights(w, group, parallel_w=None, async_op=False):
-    """Gather the weights across the process group.
-
-    Args:
-      w: torch.Tensor, local shard of the weights.
-      group: ProcessGroup, the group to gather across.
-      parallel_w: torch.Tensor, option output tensor to use
-       for the gather.
-      async_op: Whether to gather asynchronously.
-
-    Returns:
-      The gathered weights tensor and a handle for asynchronous
-      communication.
-    """
-    n, k = w.shape
-    world_size = torch.distributed.get_world_size(group)
-
-    if parallel_w is None:
-        parallel_w = torch.empty(
-            n * world_size,
-            k,
-            device=w.device,
-            dtype=w.dtype,
-        )
-    handle = torch.distributed.all_gather_into_tensor(
-        parallel_w,
-        w,
-        group=group,
-        async_op=async_op,
-    )
-    return parallel_w, handle
-
-
-def _scaled_reduce_scatter(parallel_dw, group, dw=None, async_op=False):
-    """Scatter reduce the weights across the process group.
-
-    Args:
-      parallel_dw: torch.Tensor, local shard of the weights.
-      group: ProcessGroup, the group to scatter-reduce across.
-      dw: torch.Tensor, option output tensor to use for the op.
-      async_op: Whether to scatter reduce asynchronously.
-
-    Returns:
-      The reduced weights tensor, scaled by 1 / world_size, and
-      a handle for asynchronous communication.
-    """
-    n, k = parallel_dw.shape
-    world_size = torch.distributed.get_world_size(group)
-    assert (n % world_size) == 0
-
-    # Pre-scale the gradients by the world size.
-    #
-    # NOTE: Reduce in float32, always.
-    parallel_dw = parallel_dw.float() / world_size
-
-    if dw is None:
-        dw = torch.empty(
-            n // world_size,
-            k,
-            device=parallel_dw.device,
-            dtype=torch.float32,
-        )
-    handle = torch.distributed.reduce_scatter_tensor(
-        dw,
-        parallel_dw,
-        group=group,
-        async_op=async_op,
-    )
-    return dw, handle
-
-
-class WeightParallelSddNt(torch.autograd.Function):
-
-    @staticmethod
-    @torch.cuda.amp.custom_fwd
-    def forward(ctx, x, w, topo, group):
-        # Cast inputs using ctx dtype from AMP
-        if ctx._fwd_used_autocast:
-            x = x.to(ctx._dtype)
-            w = w.to(ctx._dtype)
-        # [m, k] x [n, k] = [m, n]
-        if not x.is_contiguous() or not w.is_contiguous():
-            raise ValueError("Expected contiguous 'x' and 'w'.")
-
-        ctx.group = group
-        ctx.shape = topo.shape
-        ctx.save_for_backward(
-            x,
-            w,
-            topo.row_indices,
-            topo.column_indices,
-            topo.offsets,
-            topo.column_indices_t,
-            topo.offsets_t,
-            topo.block_offsets_t,
-        )
-
-        # TODO(tgale): Support prefetching forward weights.
-        parallel_w, _ = _gather_weights(w, group)
-        return stk.ops.sdd(x, parallel_w.t(), topo).data
-
-    @staticmethod
-    @torch.cuda.amp.custom_bwd
-    def backward(ctx, grad):
-        x, w = ctx.saved_tensors[:2]
-        grad = stk.Matrix(ctx.shape, grad, *ctx.saved_tensors[2:])
-
-        # Start the weight gather asynchronously to overlap with the
-        # weight gradient computation.
-        parallel_w, handle = _gather_weights(w, ctx.group, async_op=True)
-        parallel_dw = None
-        if ctx.needs_input_grad[1]:
-            parallel_dw = stk.ops.dsd(grad.t(), x)
-
-        # Start the weight gradient reduce scatter to overlap with the
-        # data gradient computation.
-        handle.wait()
-        dw, handle = _scaled_reduce_scatter(
-            parallel_dw,
-            ctx.group,
-            async_op=True,
-        )
-        dx = None
-        if ctx.needs_input_grad[0]:
-            dx = stk.ops.dsd(grad, parallel_w)
-
-        # NOTE: Be careful to wait and only cast dw to the output dtype once
-        # we've blocked on the asynchronous NCCL operation.
-        handle.wait()
-        dw = dw.to(w.dtype)
-        return dx, dw, None, None
-
-
-def sdd_nt(a, b, topo, group):
-    return stk.Matrix(
-        topo.size(),
-        WeightParallelSddNt.apply(a, b, topo, group),
-        topo.row_indices,
-        topo.column_indices,
-        topo.offsets,
-        topo.column_indices_t,
-        topo.offsets_t,
-        topo.block_offsets_t,
-    )
-
-
-class WeightParallelDsdNn(torch.autograd.Function):
-
-    @staticmethod
-    @torch.cuda.amp.custom_fwd
-    def forward(
-        ctx,
-        shape,
-        data,
-        row_indices,
-        column_indices,
-        offsets,
-        column_indices_t,
-        offsets_t,
-        block_offsets_t,
-        w,
-        group,
-    ):
-        # [m, k] x [k, n] = [m, n]
-        # Cast inputs using ctx dtype from AMP
-        if ctx._fwd_used_autocast:
-            data = data.to(ctx._dtype)
-            w = w.to(ctx._dtype)
-        if not data.is_contiguous() or not w.is_contiguous():
-            raise ValueError("Expected contiguous 'data' and 'w'.")
-
-        ctx.group = group
-        ctx.shape = shape
-        ctx.save_for_backward(
-            data,
-            row_indices,
-            column_indices,
-            offsets,
-            column_indices_t,
-            offsets_t,
-            block_offsets_t,
-            w,
-        )
-        x = stk.Matrix(
-            shape,
-            data,
-            row_indices,
-            column_indices,
-            offsets,
-            column_indices_t,
-            offsets_t,
-            block_offsets_t,
-        )
-
-        # TODO(tgale): Support prefetching forward weights.
-        parallel_w, _ = _gather_weights(w, group)
-        return stk.ops.dsd(x, parallel_w)
-
-    @staticmethod
-    @torch.cuda.amp.custom_bwd
-    def backward(ctx, grad):
-        x = stk.Matrix(ctx.shape, *ctx.saved_tensors[:-1])
-        w = ctx.saved_tensors[-1]
-
-        # Start the weight gather asynchronously to overlap with the
-        # weight gradient computation.
-        parallel_w, handle = _gather_weights(w, ctx.group, async_op=True)
-        parallel_dw = None
-        if ctx.needs_input_grad[-2]:
-            parallel_dw = stk.ops.dsd(x.t(), grad)
-
-        # Start the weight gradient reduce scatter to overlap with the
-        # data gradient computation.
-        handle.wait()
-        dw, handle = _scaled_reduce_scatter(
-            parallel_dw,
-            ctx.group,
-            async_op=True,
-        )
-        dx = None
-        if ctx.needs_input_grad[1]:
-            dx = stk.ops.sdd(grad, parallel_w.t(), x)
-
-        # NOTE: Be careful to wait and only cast dw to the output dtype once
-        # we've blocked on the asynchronous NCCL operation.
-        handle.wait()
-        dw = dw.to(w.dtype)
-        return None, dx.data, None, None, None, None, None, None, dw, None
-
-
-def dsd_nn(a, b, group):
-    return WeightParallelDsdNn.apply(
-        a.size(),
-        a.data,
-        a.row_indices,
-        a.column_indices,
-        a.offsets,
-        a.column_indices_t,
-        a.offsets_t,
-        a.block_offsets_t,
-        b,
-        group,
-    )
-
-
-class MemoryOptimizedWeightParallelMLP(torch.autograd.Function):
-    """Sparse MLP with manually scheduled memory reuse."""
-
-    @staticmethod
-    @torch.cuda.amp.custom_fwd
-    def forward(ctx, x, w1, w2, topo, group):
-        # Cast inputs using ctx dtype from AMP
-        if ctx._fwd_used_autocast:
-            x = x.to(ctx._dtype)
-            w1 = w1.to(ctx._dtype)
-            w2 = w2.to(ctx._dtype)
-        # x: [m, k], w1: [n, k], w2: [n, k]
-        if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()):
-            raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.")
-
-        # Layer 0: x @ w1.t().
-        parallel_w1, _ = _gather_weights(w1, group)
-        sdd_out = stk.ops.sdd(x, parallel_w1.t(), topo)
-
-        # GeLU.
-        gelu_out = gelu.gelu(sdd_out)
-
-        # Layer 1: x @ w2.
-        #
-        # NOTE: Reuse the buffer for the w1 weight gather.
-        parallel_w2, _ = _gather_weights(w2, group, parallel_w1)
-        dsd_out = stk.ops.dsd(gelu_out, parallel_w2)
-
-        # NOTE: Save the input to the layer and the gelu input for
-        # gradient computation. We'll re-compute the gelu forward
-        # pass in the backward pass to avoid materializing another
-        # intermediate.
-        ctx.group = group
-        ctx.shape = topo.shape
-        ctx.save_for_backward(
-            x,
-            w1,
-            w2,
-            sdd_out.data,
-            topo.row_indices,
-            topo.column_indices,
-            topo.offsets,
-            topo.column_indices_t,
-            topo.offsets_t,
-            topo.block_offsets_t,
-        )
-        return dsd_out
-
-    @staticmethod
-    @torch.cuda.amp.custom_bwd
-    def backward(ctx, ddsd_out):
-        x, w1, w2 = ctx.saved_tensors[:3]
-        sdd_out = stk.Matrix(ctx.shape, *ctx.saved_tensors[3:])
-
-        if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
-            raise ValueError('Expected all MLP inputs to need grad.')
-
-        # Start the weight gather asynchronously to overlap with the
-        # weight gradient computation and gelu recompute.
-        parallel_w2, handle = _gather_weights(w2, ctx.group, async_op=True)
-
-        # Compute dw2 with recomputed gelu output.
-        gelu_out = gelu.gelu(sdd_out)
-        parallel_dw2 = stk.ops.dsd(gelu_out.t(), ddsd_out)
-
-        # Start the weight gradient reduce scatter to overlap with the
-        # data gradient computation.
-        handle.wait()
-        dw2, handle = _scaled_reduce_scatter(
-            parallel_dw2,
-            ctx.group,
-            async_op=True,
-        )
-
-        # Compute dgelu_out.
-        #
-        # NOTE: We reuse the gelu_out allocation.
-        stk.backend.triton_kernels.sdd(
-            ddsd_out,
-            parallel_w2.t(),
-            sdd_out.shape,
-            gelu_out.data,
-            sdd_out.offsets,
-            sdd_out.row_indices,
-            sdd_out.column_indices,
-        )
-        dgelu_out = gelu_out
-
-        # NOTE: Be careful to wait and only cast dw to the output dtype once
-        # we've blocked on the asynchronous NCCL operation.
-        handle.wait()
-        dw2 = dw2.to(w2.dtype)
-
-        # Start the weight gather asynchronously to overlap with the
-        # weight and gelu gradient computation.
-        #
-        # NOTE: Reuse the buffer from the w2 weight gather.
-        parallel_w1, handle = _gather_weights(
-            w1,
-            ctx.group,
-            parallel_w2,
-            async_op=True,
-        )
-
-        # Compute dsdd_out.
-        #
-        # NOTE: This reuses the dgelu_out allocation.
-        dsdd_out = gelu.gelu_backward_(dgelu_out, sdd_out)
-
-        # Compute dw1.
-        #
-        # NOTE: This reuses the parallel_dw2 allocation.
-        stk.backend.triton_kernels.dsd(
-            dsdd_out.t().shape,
-            dsdd_out.data,
-            dsdd_out.offsets,
-            dsdd_out.row_indices,
-            dsdd_out.column_indices,
-            dsdd_out.offsets_t,
-            dsdd_out.column_indices_t,
-            dsdd_out.block_offsets_t,
-            True,  # transpose_a
-            x,
-            parallel_dw2,
-        )
-        parallel_dw1 = parallel_dw2
-
-        # Start the weight gradient reduce scatter to overlap with the
-        # data gradient computation.
-        handle.wait()
-        dw1, handle = _scaled_reduce_scatter(
-            parallel_dw1,
-            ctx.group,
-            async_op=True,
-        )
-
-        # Compute dx.
-        #
-        # NOTE: This reuses the ddsd_out allocation.
-        stk.backend.triton_kernels.dsd(
-            dsdd_out.shape,
-            dsdd_out.data,
-            dsdd_out.offsets,
-            dsdd_out.row_indices,
-            dsdd_out.column_indices,
-            dsdd_out.offsets_t,
-            dsdd_out.column_indices_t,
-            dsdd_out.block_offsets_t,
-            False,
-            parallel_w1,
-            ddsd_out,
-        )
-        dx = ddsd_out
-
-        # NOTE: Be careful to wait and only cast dw to the output dtype once
-        # we've blocked on the asynchronous NCCL operation.
-        handle.wait()
-        dw1 = dw1.to(w1.dtype)
-        return dx, dw1, dw2, None, None
-
-
-memory_optimized_weight_parallel_mlp = MemoryOptimizedWeightParallelMLP.apply
diff --git a/tests/layers/parallelism_test.py b/tests/layers/parallelism_test.py
deleted file mode 100644
index 35e40a0..0000000
--- a/tests/layers/parallelism_test.py
+++ /dev/null
@@ -1,153 +0,0 @@
-# Copyright 2024 Databricks
-# SPDX-License-Identifier: Apache-2.0
-
-import functools
-
-import numpy as np
-import pytest
-import torch
-
-from megablocks.layers import arguments, dmoe, mpu
-
-_PARALLELISM_TESTS = (
-    (64, 1024, 512, 2048, 64, 1, False),
-    (64, 1024, 512, 2048, 64, 1, True),
-    # Test with fewer experts than ranks to verify tensor
-    # sharding in tandem with expert sharding.
-    (4, 1, 512, 2048, 4, 1, False),
-    (4, 1, 512, 2048, 4, 1, True),
-)
-
-
-# Todo: Fix this long term
-@pytest.fixture
-def group():
-    return None
-
-
-@pytest.mark.world_size(2)
-@pytest.mark.gpu
-@pytest.mark.parametrize((
-    'batch_size',
-    'sequence_length',
-    'hidden_size',
-    'ffn_hidden_size',
-    'num_experts',
-    'top_k',
-    'memory_optimized',
-), _PARALLELISM_TESTS)
-def test_expert_parallel_versus_weight_parallel(
-    group,
-    batch_size: int,
-    sequence_length: int,
-    hidden_size: int,
-    ffn_hidden_size: int,
-    num_experts: int,
-    top_k: int,
-    memory_optimized: bool,
-):
-
-    init_fn = functools.partial(torch.nn.init.normal_, mean=0.0, std=0.1)
-    ep_args = arguments.Arguments(
-        hidden_size=hidden_size,
-        ffn_hidden_size=ffn_hidden_size,
-        moe_num_experts=num_experts,
-        moe_top_k=top_k,
-        moe_expert_model_parallelism=True,
-        expert_parallel_group=group,
-        fp16=False,
-        bf16=False,
-        device=torch.cuda.current_device(),
-        init_method=init_fn,
-        memory_optimized_mlp=memory_optimized,
-    )
-    wp_args = arguments.Arguments(
-        hidden_size=hidden_size,
-        ffn_hidden_size=ffn_hidden_size,
-        moe_num_experts=num_experts,
-        moe_top_k=top_k,
-        moe_weight_parallelism=True,
-        weight_parallel_group=group,
-        fp16=False,
-        bf16=False,
-        device=torch.cuda.current_device(),
-        init_method=init_fn,
-        memory_optimized_mlp=memory_optimized,
-    )
-
-    # NOTE: Reset the seed so that the models get identical weights.
-    torch.manual_seed(1234)
-    ep = dmoe.dMoE(ep_args)
-    torch.manual_seed(1234)
-    wp = dmoe.dMoE(wp_args)
-
-    # NOTE: Include the rank in the seed so we get different data per rank.
-    rank = torch.distributed.get_rank(group)
-    torch.manual_seed(1234 * rank)
-    x = torch.randn((batch_size, sequence_length, hidden_size), device=torch.cuda.current_device(),
-                    dtype=torch.float32).requires_grad_(True)
-
-    # Test forward.
-    out, _ = wp(x)
-    expected_out, _ = ep(x)
-
-    # Check the forward outputs.
-    for i in range(torch.distributed.get_world_size(group)):
-        torch.distributed.barrier(group)
-        if i == rank:
-            assert np.testing.assert_allclose(
-                out.detach().float().cpu(),
-                expected_out.detach().float().cpu(),
-                rtol=1e-4,
-                atol=1e-4,
-            ) is None
-
-    # Test backward.
-    out.mean().backward()
-    expected_out.mean().backward()
-
-    # NOTE: If tensor parallelism is used different weights can be on
-    # different ranks. Gather the full grads to rank 0 to compare.
-    def gather(x):
-        m, n = x.shape
-        world_size = torch.distributed.get_world_size(group)
-        out = torch.empty(m * world_size, n, device=x.device, dtype=x.dtype)
-        torch.distributed.all_gather_into_tensor(out, x, group=group)
-        return out
-
-    def permute(x):
-        esd = mpu.expert_sharding_degree(ep_args)
-        hsd = mpu.hidden_sharding_degree(ep_args)
-        out = x.view(hsd, esd, -1).transpose(1, 0).contiguous()
-        return out.view(num_experts * ffn_hidden_size, hidden_size)
-
-    wp_w2_grad = gather(wp.experts.mlp.w2.grad)
-    ep_w2_grad = permute(gather(ep.experts.mlp.w2.grad))
-    if rank == 0:
-        assert np.testing.assert_allclose(
-            wp_w2_grad.float().cpu(),
-            ep_w2_grad.float().cpu(),
-            rtol=1e-5,
-            atol=1e-5,
-        ) is None
-
-    wp_w1_grad = gather(wp.experts.mlp.w1.grad)
-    ep_w1_grad = permute(gather(ep.experts.mlp.w1.grad))
-    if rank == 0:
-        assert np.testing.assert_allclose(
-            wp_w1_grad.float().cpu(),
-            ep_w1_grad.float().cpu(),
-            rtol=1e-5,
-            atol=1e-5,
-        ) is None
-
-    # Verify the router weight gradient, which is not sharded.
-    for i in range(torch.distributed.get_world_size(group)):
-        torch.distributed.barrier(group)
-        if i == rank:
-            assert np.testing.assert_allclose(
-                wp.router.layer.weight.grad.float().cpu(),
-                ep.router.layer.weight.grad.float().cpu(),
-                rtol=1e-5,
-                atol=1e-5,
-            ) is None
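
With weight parallelism removed, expert model parallelism is the remaining parallelism path for a distributed dMoE. The sketch below is adapted from the expert-parallel branch of the deleted test above and is illustrative only: it assumes torch.distributed has already been initialized elsewhere, uses None to select the default process group (as the deleted test fixture did), and reuses the sizes from _PARALLELISM_TESTS.

import functools

import torch

from megablocks.layers import arguments, dmoe

# Assumed to have been created elsewhere, e.g. after
# torch.distributed.init_process_group(); None selects the default group.
expert_group = None

init_fn = functools.partial(torch.nn.init.normal_, mean=0.0, std=0.1)
args = arguments.Arguments(
    hidden_size=512,
    ffn_hidden_size=2048,
    moe_num_experts=64,
    moe_top_k=1,
    moe_expert_model_parallelism=True,  # the remaining parallelism knob after this patch
    expert_parallel_group=expert_group,
    fp16=False,
    bf16=False,
    device=torch.cuda.current_device(),
    init_method=init_fn,
    memory_optimized_mlp=False,
)
layer = dmoe.dMoE(args)

x = torch.randn(
    (64, 1024, 512),  # (batch_size, sequence_length, hidden_size)
    device=torch.cuda.current_device(),
    dtype=torch.float32,
)
out, _ = layer(x)  # the layer returns a tuple, as in the deleted test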