
Commit 43f389a

modify
Eitan Turok committed Aug 8, 2024
1 parent 3b342ff commit 43f389a
Showing 3 changed files with 3 additions and 32 deletions.
3 changes: 0 additions & 3 deletions megablocks/layers/glu.py
@@ -44,9 +44,6 @@ def __init__(self, args: Arguments):
             self._should_set_parallelism_attribute,
         )
 
-        if self.args.moe_weight_parallelism:
-            raise NotImplementedError('Weight parallelism not yet supported with GLU.',)
-
     def forward(self, x, topo):
         if self.args.memory_optimized_mlp:
             raise NotImplementedError(
24 changes: 3 additions & 21 deletions megablocks/layers/mlp.py
@@ -181,20 +181,7 @@ def create_dmoe_expert_weights(
         init_method,
     )
     weights = weights.view([-1, columns])
-    rows, columns = weights.shape
-
-    if not args.moe_weight_parallelism:
-        return weights
-
-    # Caclculate the number of rows on this weight parallel partition.
-    # 'rows' must be divisible by weight parallel world size.
-    weight_parallel_world_size = mpu.get_weight_parallel_world_size(args)
-    assert (rows % weight_parallel_world_size) == 0
-    num_rows_per_rank = rows // weight_parallel_world_size
-    rank = mpu.get_weight_parallel_rank(args)
-    start_row = rank * num_rows_per_rank
-    end_row = (rank + 1) * num_rows_per_rank
-    return weights[start_row:end_row]
+    return weights
 
 
 class MemoryOptimizedMLP(torch.autograd.Function):
@@ -371,7 +358,7 @@ def __init__(self, args: Arguments):
             ),
         )
 
-        self._should_set_parallelism_attribute = (args.moe_expert_model_parallelism or args.moe_weight_parallelism)
+        self._should_set_parallelism_attribute = args.moe_expert_model_parallelism
         mpu.set_expert_model_parallel_attributes(
             self.w1,
             self._should_set_parallelism_attribute,
@@ -414,9 +401,7 @@ def parallel_forward(self, x, topo):
     def forward(self, x, topo):
         w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2)
         w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2)
-        if self.args.moe_weight_parallelism:
-            return self.parallel_forward(x, topo)
-        elif self.args.memory_optimized_mlp:
+        if self.args.memory_optimized_mlp:
             return memory_optimized_mlp(
                 x,
                 w1,
@@ -542,9 +527,6 @@ def forward(self, x, tokens_per_expert):
         w1 = resolve_dtensor(w1).view(ne, -1, self.args.hidden_size)
         w2 = resolve_dtensor(w2).view(ne, -1, self.args.hidden_size)
 
-        if self.args.moe_weight_parallelism:
-            raise NotImplementedError('Weight parallelism not yet supported with GroupedMLP.',)
-
         if self.args.memory_optimized_mlp:
             return memory_optimized_grouped_mlp(
                 x,
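For context on what the mlp.py change drops: the deleted branch in create_dmoe_expert_weights sharded the flattened weight matrix by giving each weight-parallel rank a contiguous block of rows. A minimal standalone sketch of that deleted logic (shard_rows is an illustrative name, not a function in the repository):

import torch


def shard_rows(weights: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    # Each weight-parallel rank keeps a contiguous block of rows; the
    # total row count must divide evenly across the ranks.
    rows, _ = weights.shape
    assert (rows % world_size) == 0
    num_rows_per_rank = rows // world_size
    start_row = rank * num_rows_per_rank
    end_row = (rank + 1) * num_rows_per_rank
    return weights[start_row:end_row]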
8 changes: 0 additions & 8 deletions megablocks/layers/mpu.py
@@ -42,14 +42,6 @@ def copy_expert_model_parallel_attributes(
         )
 
 
-def get_weight_parallel_world_size(args: Arguments) -> int:
-    return (torch.distributed.get_world_size(args.weight_parallel_group) if args.moe_weight_parallelism else 1)
-
-
-def get_weight_parallel_rank(args: Arguments) -> int:
-    return (torch.distributed.get_rank(args.weight_parallel_group) if args.moe_weight_parallelism else 0)
-
-
 def synchronized_print(group, *x):
     world_size = torch.distributed.get_world_size(group)
     rank = torch.distributed.get_rank(group)
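The two helpers removed from mpu.py supplied the rank and world size for that sharding, falling back to a trivial single-rank view (world size 1, rank 0) whenever moe_weight_parallelism was disabled. A hedged restatement of that pattern outside the diff, for readers skimming the hunks (the args fields mirror the deleted signatures; nothing here remains in the repository):

import torch


def weight_parallel_world_size(args) -> int:
    # Report a single rank when weight parallelism is disabled, so callers
    # never touch torch.distributed in that case.
    if not args.moe_weight_parallelism:
        return 1
    return torch.distributed.get_world_size(args.weight_parallel_group)


def weight_parallel_rank(args) -> int:
    if not args.moe_weight_parallelism:
        return 0
    return torch.distributed.get_rank(args.weight_parallel_group)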
