remove weight parallelism (#137)
* remove weight parallelism

* fix linting

* remove parallel forward from mlp

* remove weight parallel

* cleanup
eitanturok authored Aug 12, 2024
1 parent f87b26f commit 27d3d2c
Showing 6 changed files with 4 additions and 628 deletions.
2 changes: 0 additions & 2 deletions megablocks/layers/arguments.py
@@ -40,8 +40,6 @@ class Arguments:
# Parallelism arguments.
moe_expert_model_parallelism: bool = False
expert_parallel_group: Optional[torch.distributed.ProcessGroup] = None
moe_weight_parallelism: bool = False
weight_parallel_group: Optional[torch.distributed.ProcessGroup] = None
pipeline_model_parallel_size: int = 1
num_layers_per_virtual_pipeline_stage: Optional[int] = None

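After this change, parallelism is configured through expert model parallelism alone. A minimal configuration sketch (not from the commit; it assumes Arguments is built directly with keyword arguments, that unlisted fields keep their defaults, and that torch.distributed is already initialized):

import torch.distributed as dist

from megablocks.layers.arguments import Arguments

# Hypothetical expert-parallel process group; new_group() spans all ranks by default.
expert_group = dist.new_group()

args = Arguments(
    moe_expert_model_parallelism=True,
    expert_parallel_group=expert_group,
    pipeline_model_parallel_size=1,
    # moe_weight_parallelism and weight_parallel_group no longer exist.
)
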
3 changes: 0 additions & 3 deletions megablocks/layers/glu.py
@@ -44,9 +44,6 @@ def __init__(self, args: Arguments):
self._should_set_parallelism_attribute,
)

if self.args.moe_weight_parallelism:
raise NotImplementedError('Weight parallelism not yet supported with GLU.',)

def forward(self, x, topo):
if self.args.memory_optimized_mlp:
raise NotImplementedError(
50 changes: 4 additions & 46 deletions megablocks/layers/mlp.py
Expand Up @@ -9,7 +9,6 @@

from megablocks import grouped_gemm_util as gg
from megablocks.layers import common, gelu, mpu
from megablocks.layers import weight_parallel as wp
from megablocks.layers.activation_fn import act_fn
from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn

@@ -180,21 +179,7 @@ def create_dmoe_expert_weights(
columns,
init_method,
)
weights = weights.view([-1, columns])
rows, columns = weights.shape

if not args.moe_weight_parallelism:
return weights

# Caclculate the number of rows on this weight parallel partition.
# 'rows' must be divisible by weight parallel world size.
weight_parallel_world_size = mpu.get_weight_parallel_world_size(args)
assert (rows % weight_parallel_world_size) == 0
num_rows_per_rank = rows // weight_parallel_world_size
rank = mpu.get_weight_parallel_rank(args)
start_row = rank * num_rows_per_rank
end_row = (rank + 1) * num_rows_per_rank
return weights[start_row:end_row]
return weights.view([-1, columns])


class MemoryOptimizedMLP(torch.autograd.Function):
@@ -323,8 +308,7 @@ class SparseMLP(torch.nn.Module):
def __init__(self, args: Arguments):
super().__init__()
self.args = args
self._num_rows_per_rank = ((mpu.experts_per_rank(args) * mpu.features_per_rank(args)) //
mpu.get_weight_parallel_world_size(args))
self._num_rows_per_rank = mpu.experts_per_rank(args) * mpu.features_per_rank(args)

self.w1 = torch.nn.Parameter(
torch.empty(
@@ -371,7 +355,7 @@ def __init__(self, args: Arguments):
),
)

self._should_set_parallelism_attribute = (args.moe_expert_model_parallelism or args.moe_weight_parallelism)
self._should_set_parallelism_attribute = args.moe_expert_model_parallelism
mpu.set_expert_model_parallel_attributes(
self.w1,
self._should_set_parallelism_attribute,
@@ -390,33 +374,10 @@ def scale_grad(self, w):
return w
return scale_gradient(w, self.gradient_scale)

def parallel_forward(self, x, topo):
group = self.args.weight_parallel_group
w1, w2 = (self.scale_grad(self.w1), self.scale_grad(self.w2))
if self.args.memory_optimized_mlp:
if self.args.activation_fn is not DEFAULT_ACTIVATION_FN:
raise NotImplementedError(
f'memory_optimized_weight_parallel_mlp not implemented for custom activation_fn={self.args.activation_fn}.',
)
return wp.memory_optimized_weight_parallel_mlp(
x,
w1,
w2,
topo,
group,
)

# Compute the MLP.
x = wp.sdd_nt(x, w1, topo, group)
activation_fn_out = act_fn(x, self.args.activation_fn)
return wp.dsd_nn(activation_fn_out, w2, group)

def forward(self, x, topo):
w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2)
w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2)
if self.args.moe_weight_parallelism:
return self.parallel_forward(x, topo)
elif self.args.memory_optimized_mlp:
if self.args.memory_optimized_mlp:
return memory_optimized_mlp(
x,
w1,
@@ -542,9 +503,6 @@ def forward(self, x, tokens_per_expert):
w1 = resolve_dtensor(w1).view(ne, -1, self.args.hidden_size)
w2 = resolve_dtensor(w2).view(ne, -1, self.args.hidden_size)

if self.args.moe_weight_parallelism:
raise NotImplementedError('Weight parallelism not yet supported with GroupedMLP.',)

if self.args.memory_optimized_mlp:
return memory_optimized_grouped_mlp(
x,
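The net effect on SparseMLP is that each rank now stores its full expert weight matrices rather than a row-sliced shard of them. An illustrative back-of-the-envelope check (variable names mirror the diff above; the concrete numbers are made up):

# Per-rank row count of SparseMLP's w1/w2 before and after this commit.
experts_per_rank = 4        # experts owned by this expert-parallel rank
features_per_rank = 4096    # FFN hidden features owned by this rank

# Before: rows were additionally split across the weight-parallel group.
weight_parallel_world_size = 2
rows_before = (experts_per_rank * features_per_rank) // weight_parallel_world_size  # 8192

# After: the weight-parallel split is gone, so every rank keeps all rows.
rows_after = experts_per_rank * features_per_rank  # 16384
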
8 changes: 0 additions & 8 deletions megablocks/layers/mpu.py
@@ -42,14 +42,6 @@ def copy_expert_model_parallel_attributes(
)


def get_weight_parallel_world_size(args: Arguments) -> int:
return (torch.distributed.get_world_size(args.weight_parallel_group) if args.moe_weight_parallelism else 1)


def get_weight_parallel_rank(args: Arguments) -> int:
return (torch.distributed.get_rank(args.weight_parallel_group) if args.moe_weight_parallelism else 0)


def synchronized_print(group, *x):
world_size = torch.distributed.get_world_size(group)
rank = torch.distributed.get_rank(group)
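
If downstream code still calls the deleted mpu helpers, a hypothetical compatibility shim (not part of this commit) would simply hard-code the old "weight parallelism disabled" fallback values:

from megablocks.layers.arguments import Arguments


def get_weight_parallel_world_size(args: Arguments) -> int:
    # With weight parallelism removed, every rank behaves like a
    # single-member weight-parallel group.
    return 1


def get_weight_parallel_rank(args: Arguments) -> int:
    return 0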
