From 43f389a3862cf7464db5dc1b79a37bffc936761e Mon Sep 17 00:00:00 2001
From: Eitan Turok
Date: Thu, 8 Aug 2024 20:50:07 +0000
Subject: [PATCH] Remove MoE weight parallelism support

Drop the moe_weight_parallelism code path: the per-rank weight slicing in
create_dmoe_expert_weights, the weight-parallelism branch in the sparse MLP
forward pass, the NotImplementedError guards in the GLU and grouped MLP
layers, and the get_weight_parallel_world_size / get_weight_parallel_rank
helpers in mpu.py.
---
 megablocks/layers/glu.py |  3 ---
 megablocks/layers/mlp.py | 24 +++---------------------
 megablocks/layers/mpu.py |  8 --------
 3 files changed, 3 insertions(+), 32 deletions(-)

diff --git a/megablocks/layers/glu.py b/megablocks/layers/glu.py
index fa888a6..4654576 100644
--- a/megablocks/layers/glu.py
+++ b/megablocks/layers/glu.py
@@ -44,9 +44,6 @@ def __init__(self, args: Arguments):
             self._should_set_parallelism_attribute,
         )
 
-        if self.args.moe_weight_parallelism:
-            raise NotImplementedError('Weight parallelism not yet supported with GLU.',)
-
     def forward(self, x, topo):
         if self.args.memory_optimized_mlp:
             raise NotImplementedError(
diff --git a/megablocks/layers/mlp.py b/megablocks/layers/mlp.py
index 1cae4fb..00bc18b 100644
--- a/megablocks/layers/mlp.py
+++ b/megablocks/layers/mlp.py
@@ -181,20 +181,7 @@ def create_dmoe_expert_weights(
         init_method,
     )
     weights = weights.view([-1, columns])
-    rows, columns = weights.shape
-
-    if not args.moe_weight_parallelism:
-        return weights
-
-    # Caclculate the number of rows on this weight parallel partition.
-    # 'rows' must be divisible by weight parallel world size.
-    weight_parallel_world_size = mpu.get_weight_parallel_world_size(args)
-    assert (rows % weight_parallel_world_size) == 0
-    num_rows_per_rank = rows // weight_parallel_world_size
-    rank = mpu.get_weight_parallel_rank(args)
-    start_row = rank * num_rows_per_rank
-    end_row = (rank + 1) * num_rows_per_rank
-    return weights[start_row:end_row]
+    return weights
 
 
 class MemoryOptimizedMLP(torch.autograd.Function):
@@ -371,7 +358,7 @@ def __init__(self, args: Arguments):
             ),
         )
 
-        self._should_set_parallelism_attribute = (args.moe_expert_model_parallelism or args.moe_weight_parallelism)
+        self._should_set_parallelism_attribute = args.moe_expert_model_parallelism
         mpu.set_expert_model_parallel_attributes(
             self.w1,
             self._should_set_parallelism_attribute,
@@ -414,9 +401,7 @@ def parallel_forward(self, x, topo):
     def forward(self, x, topo):
         w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2)
         w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2)
-        if self.args.moe_weight_parallelism:
-            return self.parallel_forward(x, topo)
-        elif self.args.memory_optimized_mlp:
+        if self.args.memory_optimized_mlp:
             return memory_optimized_mlp(
                 x,
                 w1,
@@ -542,9 +527,6 @@ def forward(self, x, tokens_per_expert):
         w1 = resolve_dtensor(w1).view(ne, -1, self.args.hidden_size)
         w2 = resolve_dtensor(w2).view(ne, -1, self.args.hidden_size)
 
-        if self.args.moe_weight_parallelism:
-            raise NotImplementedError('Weight parallelism not yet supported with GroupedMLP.',)
-
         if self.args.memory_optimized_mlp:
             return memory_optimized_grouped_mlp(
                 x,
diff --git a/megablocks/layers/mpu.py b/megablocks/layers/mpu.py
index 6aa0015..239f75f 100644
--- a/megablocks/layers/mpu.py
+++ b/megablocks/layers/mpu.py
@@ -42,14 +42,6 @@ def copy_expert_model_parallel_attributes(
     )
 
 
-def get_weight_parallel_world_size(args: Arguments) -> int:
-    return (torch.distributed.get_world_size(args.weight_parallel_group) if args.moe_weight_parallelism else 1)
-
-
-def get_weight_parallel_rank(args: Arguments) -> int:
-    return (torch.distributed.get_rank(args.weight_parallel_group) if args.moe_weight_parallelism else 0)
-
-
 def synchronized_print(group, *x):
     world_size = torch.distributed.get_world_size(group)
     rank = torch.distributed.get_rank(group)