From 3016e093c05def743ba858e642cfbba1fce44c4a Mon Sep 17 00:00:00 2001
From: Eitan Turok
Date: Fri, 9 Aug 2024 18:04:57 +0000
Subject: [PATCH] remove weight parallelism

---
 megablocks/layers/arguments.py   |   2 -
 megablocks/layers/glu.py         |   3 -
 megablocks/layers/mlp.py         |  29 +-----
 megablocks/layers/mpu.py         |   9 --
 tests/layers/parallelism_test.py | 153 -------------------------------
 5 files changed, 5 insertions(+), 191 deletions(-)
 delete mode 100644 tests/layers/parallelism_test.py

diff --git a/megablocks/layers/arguments.py b/megablocks/layers/arguments.py
index efe131d..ddbe2b7 100644
--- a/megablocks/layers/arguments.py
+++ b/megablocks/layers/arguments.py
@@ -40,8 +40,6 @@ class Arguments:
     # Parallelism arguments.
     moe_expert_model_parallelism: bool = False
     expert_parallel_group: Optional[torch.distributed.ProcessGroup] = None
-    moe_weight_parallelism: bool = False
-    weight_parallel_group: Optional[torch.distributed.ProcessGroup] = None
     pipeline_model_parallel_size: int = 1
     num_layers_per_virtual_pipeline_stage: Optional[int] = None
 
diff --git a/megablocks/layers/glu.py b/megablocks/layers/glu.py
index fa888a6..4654576 100644
--- a/megablocks/layers/glu.py
+++ b/megablocks/layers/glu.py
@@ -44,9 +44,6 @@ def __init__(self, args: Arguments):
             self._should_set_parallelism_attribute,
         )
 
-        if self.args.moe_weight_parallelism:
-            raise NotImplementedError('Weight parallelism not yet supported with GLU.',)
-
     def forward(self, x, topo):
         if self.args.memory_optimized_mlp:
             raise NotImplementedError(
diff --git a/megablocks/layers/mlp.py b/megablocks/layers/mlp.py
index 1cae4fb..361f9c9 100644
--- a/megablocks/layers/mlp.py
+++ b/megablocks/layers/mlp.py
@@ -181,20 +181,7 @@ def create_dmoe_expert_weights(
         init_method,
     )
     weights = weights.view([-1, columns])
-    rows, columns = weights.shape
-
-    if not args.moe_weight_parallelism:
-        return weights
-
-    # Caclculate the number of rows on this weight parallel partition.
-    # 'rows' must be divisible by weight parallel world size.
-    weight_parallel_world_size = mpu.get_weight_parallel_world_size(args)
-    assert (rows % weight_parallel_world_size) == 0
-    num_rows_per_rank = rows // weight_parallel_world_size
-    rank = mpu.get_weight_parallel_rank(args)
-    start_row = rank * num_rows_per_rank
-    end_row = (rank + 1) * num_rows_per_rank
-    return weights[start_row:end_row]
+    return weights
 
 
 class MemoryOptimizedMLP(torch.autograd.Function):
@@ -323,8 +310,7 @@ class SparseMLP(torch.nn.Module):
     def __init__(self, args: Arguments):
         super().__init__()
         self.args = args
-        self._num_rows_per_rank = ((mpu.experts_per_rank(args) * mpu.features_per_rank(args)) //
-                                   mpu.get_weight_parallel_world_size(args))
+        self._num_rows_per_rank = mpu.experts_per_rank(args) * mpu.features_per_rank(args)
 
         self.w1 = torch.nn.Parameter(
             torch.empty(
@@ -371,7 +357,7 @@ def __init__(self, args: Arguments):
             ),
         )
 
-        self._should_set_parallelism_attribute = (args.moe_expert_model_parallelism or args.moe_weight_parallelism)
+        self._should_set_parallelism_attribute = args.moe_expert_model_parallelism
         mpu.set_expert_model_parallel_attributes(
             self.w1,
             self._should_set_parallelism_attribute,
@@ -391,7 +377,7 @@ def __init__(self, args: Arguments):
         return scale_gradient(w, self.gradient_scale)
 
     def parallel_forward(self, x, topo):
-        group = self.args.weight_parallel_group
+        group = None
         w1, w2 = (self.scale_grad(self.w1), self.scale_grad(self.w2))
         if self.args.memory_optimized_mlp:
             if self.args.activation_fn is not DEFAULT_ACTIVATION_FN:
@@ -414,9 +400,7 @@ def forward(self, x, topo):
         w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2)
         w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2)
 
-        if self.args.moe_weight_parallelism:
-            return self.parallel_forward(x, topo)
-        elif self.args.memory_optimized_mlp:
+        if self.args.memory_optimized_mlp:
             return memory_optimized_mlp(
                 x,
                 w1,
@@ -542,9 +526,6 @@ def forward(self, x, tokens_per_expert):
         w1 = resolve_dtensor(w1).view(ne, -1, self.args.hidden_size)
         w2 = resolve_dtensor(w2).view(ne, -1, self.args.hidden_size)
 
-        if self.args.moe_weight_parallelism:
-            raise NotImplementedError('Weight parallelism not yet supported with GroupedMLP.',)
-
         if self.args.memory_optimized_mlp:
             return memory_optimized_grouped_mlp(
                 x,
diff --git a/megablocks/layers/mpu.py b/megablocks/layers/mpu.py
index 6aa0015..63de336 100644
--- a/megablocks/layers/mpu.py
+++ b/megablocks/layers/mpu.py
@@ -41,15 +41,6 @@ def copy_expert_model_parallel_attributes(
             getattr(source_tensor, 'expert_model_parallel'),
         )
 
-
-def get_weight_parallel_world_size(args: Arguments) -> int:
-    return (torch.distributed.get_world_size(args.weight_parallel_group) if args.moe_weight_parallelism else 1)
-
-
-def get_weight_parallel_rank(args: Arguments) -> int:
-    return (torch.distributed.get_rank(args.weight_parallel_group) if args.moe_weight_parallelism else 0)
-
-
 def synchronized_print(group, *x):
     world_size = torch.distributed.get_world_size(group)
     rank = torch.distributed.get_rank(group)
diff --git a/tests/layers/parallelism_test.py b/tests/layers/parallelism_test.py
deleted file mode 100644
index 35e40a0..0000000
--- a/tests/layers/parallelism_test.py
+++ /dev/null
@@ -1,153 +0,0 @@
-# Copyright 2024 Databricks
-# SPDX-License-Identifier: Apache-2.0
-
-import functools
-
-import numpy as np
-import pytest
-import torch
-
-from megablocks.layers import arguments, dmoe, mpu
-
-_PARALLELISM_TESTS = (
-    (64, 1024, 512, 2048, 64, 1, False),
-    (64, 1024, 512, 2048, 64, 1, True),
-    # Test with fewer experts than ranks to verify tensor
-    # sharding in tandem with expert sharding.
-    (4, 1, 512, 2048, 4, 1, False),
-    (4, 1, 512, 2048, 4, 1, True),
-)
-
-
-# Todo: Fix this long term
-@pytest.fixture
-def group():
-    return None
-
-
-@pytest.mark.world_size(2)
-@pytest.mark.gpu
-@pytest.mark.parametrize((
-    'batch_size',
-    'sequence_length',
-    'hidden_size',
-    'ffn_hidden_size',
-    'num_experts',
-    'top_k',
-    'memory_optimized',
-), _PARALLELISM_TESTS)
-def test_expert_parallel_versus_weight_parallel(
-    group,
-    batch_size: int,
-    sequence_length: int,
-    hidden_size: int,
-    ffn_hidden_size: int,
-    num_experts: int,
-    top_k: int,
-    memory_optimized: bool,
-):
-
-    init_fn = functools.partial(torch.nn.init.normal_, mean=0.0, std=0.1)
-    ep_args = arguments.Arguments(
-        hidden_size=hidden_size,
-        ffn_hidden_size=ffn_hidden_size,
-        moe_num_experts=num_experts,
-        moe_top_k=top_k,
-        moe_expert_model_parallelism=True,
-        expert_parallel_group=group,
-        fp16=False,
-        bf16=False,
-        device=torch.cuda.current_device(),
-        init_method=init_fn,
-        memory_optimized_mlp=memory_optimized,
-    )
-    wp_args = arguments.Arguments(
-        hidden_size=hidden_size,
-        ffn_hidden_size=ffn_hidden_size,
-        moe_num_experts=num_experts,
-        moe_top_k=top_k,
-        moe_weight_parallelism=True,
-        weight_parallel_group=group,
-        fp16=False,
-        bf16=False,
-        device=torch.cuda.current_device(),
-        init_method=init_fn,
-        memory_optimized_mlp=memory_optimized,
-    )
-
-    # NOTE: Reset the seed so that the models get identical weights.
-    torch.manual_seed(1234)
-    ep = dmoe.dMoE(ep_args)
-    torch.manual_seed(1234)
-    wp = dmoe.dMoE(wp_args)
-
-    # NOTE: Include the rank in the seed so we get different data per rank.
-    rank = torch.distributed.get_rank(group)
-    torch.manual_seed(1234 * rank)
-    x = torch.randn((batch_size, sequence_length, hidden_size), device=torch.cuda.current_device(),
-                    dtype=torch.float32).requires_grad_(True)
-
-    # Test forward.
-    out, _ = wp(x)
-    expected_out, _ = ep(x)
-
-    # Check the forward outputs.
-    for i in range(torch.distributed.get_world_size(group)):
-        torch.distributed.barrier(group)
-        if i == rank:
-            assert np.testing.assert_allclose(
-                out.detach().float().cpu(),
-                expected_out.detach().float().cpu(),
-                rtol=1e-4,
-                atol=1e-4,
-            ) is None
-
-    # Test backward.
-    out.mean().backward()
-    expected_out.mean().backward()
-
-    # NOTE: If tensor parallelism is used different weights can be on
-    # different ranks. Gather the full grads to rank 0 to compare.
-    def gather(x):
-        m, n = x.shape
-        world_size = torch.distributed.get_world_size(group)
-        out = torch.empty(m * world_size, n, device=x.device, dtype=x.dtype)
-        torch.distributed.all_gather_into_tensor(out, x, group=group)
-        return out
-
-    def permute(x):
-        esd = mpu.expert_sharding_degree(ep_args)
-        hsd = mpu.hidden_sharding_degree(ep_args)
-        out = x.view(hsd, esd, -1).transpose(1, 0).contiguous()
-        return out.view(num_experts * ffn_hidden_size, hidden_size)
-
-    wp_w2_grad = gather(wp.experts.mlp.w2.grad)
-    ep_w2_grad = permute(gather(ep.experts.mlp.w2.grad))
-    if rank == 0:
-        assert np.testing.assert_allclose(
-            wp_w2_grad.float().cpu(),
-            ep_w2_grad.float().cpu(),
-            rtol=1e-5,
-            atol=1e-5,
-        ) is None
-
-    wp_w1_grad = gather(wp.experts.mlp.w1.grad)
-    ep_w1_grad = permute(gather(ep.experts.mlp.w1.grad))
-    if rank == 0:
-        assert np.testing.assert_allclose(
-            wp_w1_grad.float().cpu(),
-            ep_w1_grad.float().cpu(),
-            rtol=1e-5,
-            atol=1e-5,
-        ) is None
-
-    # Verify the router weight gradient, which is not sharded.
-    for i in range(torch.distributed.get_world_size(group)):
-        torch.distributed.barrier(group)
-        if i == rank:
-            assert np.testing.assert_allclose(
-                wp.router.layer.weight.grad.float().cpu(),
-                ep.router.layer.weight.grad.float().cpu(),
-                rtol=1e-5,
-                atol=1e-5,
-            ) is None
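
Note (not part of the patch): after this change, expert model parallelism is the only parallel path left for dMoE layers; moe_weight_parallelism and weight_parallel_group no longer exist on Arguments. The following is a minimal, illustrative sketch of constructing a dMoE with an expert-parallel process group, using only arguments that survive this diff (they all appear in the deleted test). It assumes torch.distributed and a CUDA device have already been initialized by the caller, and the name expert_group is a hypothetical placeholder for whatever group the caller sets up.

# Illustrative sketch (not part of the patch). Assumes torch.distributed.init_process_group()
# and torch.cuda.set_device() were already called by the surrounding training script.
import functools

import torch

from megablocks.layers import arguments, dmoe

expert_group = torch.distributed.new_group()  # assumed expert-parallel group; replace with your own
args = arguments.Arguments(
    hidden_size=512,
    ffn_hidden_size=2048,
    moe_num_experts=4,
    moe_top_k=1,
    moe_expert_model_parallelism=True,
    expert_parallel_group=expert_group,
    device=torch.cuda.current_device(),
    init_method=functools.partial(torch.nn.init.normal_, mean=0.0, std=0.1),
)
layer = dmoe.dMoE(args)
# Forward pass on a (batch, sequence, hidden) activation tensor.
out, _ = layer(torch.randn(64, 1024, 512, device=args.device))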