
Commit

remove weight parallelism
Eitan Turok committed Aug 9, 2024
1 parent 6adb1fb commit 3016e09
Showing 5 changed files with 5 additions and 191 deletions.
2 changes: 0 additions & 2 deletions megablocks/layers/arguments.py
@@ -40,8 +40,6 @@ class Arguments:
     # Parallelism arguments.
     moe_expert_model_parallelism: bool = False
     expert_parallel_group: Optional[torch.distributed.ProcessGroup] = None
-    moe_weight_parallelism: bool = False
-    weight_parallel_group: Optional[torch.distributed.ProcessGroup] = None
     pipeline_model_parallel_size: int = 1
     num_layers_per_virtual_pipeline_stage: Optional[int] = None

3 changes: 0 additions & 3 deletions megablocks/layers/glu.py
@@ -44,9 +44,6 @@ def __init__(self, args: Arguments):
             self._should_set_parallelism_attribute,
         )

-        if self.args.moe_weight_parallelism:
-            raise NotImplementedError('Weight parallelism not yet supported with GLU.',)
-
     def forward(self, x, topo):
         if self.args.memory_optimized_mlp:
             raise NotImplementedError(
29 changes: 5 additions & 24 deletions megablocks/layers/mlp.py
@@ -181,20 +181,7 @@ def create_dmoe_expert_weights(
         init_method,
     )
     weights = weights.view([-1, columns])
-    rows, columns = weights.shape
-
-    if not args.moe_weight_parallelism:
-        return weights
-
-    # Caclculate the number of rows on this weight parallel partition.
-    # 'rows' must be divisible by weight parallel world size.
-    weight_parallel_world_size = mpu.get_weight_parallel_world_size(args)
-    assert (rows % weight_parallel_world_size) == 0
-    num_rows_per_rank = rows // weight_parallel_world_size
-    rank = mpu.get_weight_parallel_rank(args)
-    start_row = rank * num_rows_per_rank
-    end_row = (rank + 1) * num_rows_per_rank
-    return weights[start_row:end_row]
+    return weights


 class MemoryOptimizedMLP(torch.autograd.Function):
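For readers skimming the diff: the block dropped from create_dmoe_expert_weights sliced each rank's flattened expert weights along the row dimension. A minimal worked example of that old arithmetic follows; the concrete numbers (8192 rows, 4 weight-parallel shards, rank 1) are illustrative assumptions, not values taken from megablocks.

# Sketch of the removed row partitioning, with assumed numbers.
rows = 8192                     # flattened expert-weight rows before slicing
weight_parallel_world_size = 4  # former weight-parallel group size (assumed)
rank = 1                        # former weight-parallel rank (assumed)

assert rows % weight_parallel_world_size == 0
num_rows_per_rank = rows // weight_parallel_world_size  # 2048
start_row = rank * num_rows_per_rank                    # 2048
end_row = (rank + 1) * num_rows_per_rank                # 4096
# The old code returned weights[start_row:end_row] (rows 2048-4095);
# after this commit the full, unsliced weights tensor is returned.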
@@ -323,8 +310,7 @@ class SparseMLP(torch.nn.Module):
     def __init__(self, args: Arguments):
         super().__init__()
         self.args = args
-        self._num_rows_per_rank = ((mpu.experts_per_rank(args) * mpu.features_per_rank(args)) //
-                                   mpu.get_weight_parallel_world_size(args))
+        self._num_rows_per_rank = mpu.experts_per_rank(args) * mpu.features_per_rank(args)

         self.w1 = torch.nn.Parameter(
             torch.empty(
@@ -371,7 +357,7 @@ def __init__(self, args: Arguments):
             ),
         )

-        self._should_set_parallelism_attribute = (args.moe_expert_model_parallelism or args.moe_weight_parallelism)
+        self._should_set_parallelism_attribute = args.moe_expert_model_parallelism
         mpu.set_expert_model_parallel_attributes(
             self.w1,
             self._should_set_parallelism_attribute,
@@ -391,7 +377,7 @@ def scale_grad(self, w):
         return scale_gradient(w, self.gradient_scale)

     def parallel_forward(self, x, topo):
-        group = self.args.weight_parallel_group
+        group = None
         w1, w2 = (self.scale_grad(self.w1), self.scale_grad(self.w2))
         if self.args.memory_optimized_mlp:
             if self.args.activation_fn is not DEFAULT_ACTIVATION_FN:
@@ -414,9 +400,7 @@ def parallel_forward(self, x, topo):
     def forward(self, x, topo):
         w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2)
         w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2)
-        if self.args.moe_weight_parallelism:
-            return self.parallel_forward(x, topo)
-        elif self.args.memory_optimized_mlp:
+        if self.args.memory_optimized_mlp:
             return memory_optimized_mlp(
                 x,
                 w1,
@@ -542,9 +526,6 @@ def forward(self, x, tokens_per_expert):
         w1 = resolve_dtensor(w1).view(ne, -1, self.args.hidden_size)
         w2 = resolve_dtensor(w2).view(ne, -1, self.args.hidden_size)

-        if self.args.moe_weight_parallelism:
-            raise NotImplementedError('Weight parallelism not yet supported with GroupedMLP.',)
-
         if self.args.memory_optimized_mlp:
             return memory_optimized_grouped_mlp(
                 x,
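The sizing change in SparseMLP.__init__ follows the same simplification: each rank now allocates rows for all of its local experts instead of a weight-parallel shard of them. A rough sketch of the arithmetic, using assumed values for experts_per_rank and features_per_rank (placeholders, not values from the diff):

# Assumed values for illustration only.
experts_per_rank = 2      # experts owned by this expert-parallel rank
features_per_rank = 4096  # FFN rows contributed per expert on this rank

# New sizing after this commit:
num_rows_per_rank = experts_per_rank * features_per_rank  # 8192 rows

# The old sizing additionally divided by the weight-parallel world size,
# e.g. with 4 weight shards: 8192 // 4 = 2048 rows per rank.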
9 changes: 0 additions & 9 deletions megablocks/layers/mpu.py
@@ -41,15 +41,6 @@ def copy_expert_model_parallel_attributes(
             getattr(source_tensor, 'expert_model_parallel'),
         )

-
-def get_weight_parallel_world_size(args: Arguments) -> int:
-    return (torch.distributed.get_world_size(args.weight_parallel_group) if args.moe_weight_parallelism else 1)
-
-
-def get_weight_parallel_rank(args: Arguments) -> int:
-    return (torch.distributed.get_rank(args.weight_parallel_group) if args.moe_weight_parallelism else 0)
-
-
 def synchronized_print(group, *x):
     world_size = torch.distributed.get_world_size(group)
     rank = torch.distributed.get_rank(group)
153 changes: 0 additions & 153 deletions tests/layers/parallelism_test.py

This file was deleted.
