remove weight parallelism #137

Merged: 6 commits, Aug 12, 2024
2 changes: 0 additions & 2 deletions megablocks/layers/arguments.py
@@ -40,8 +40,6 @@ class Arguments:
     # Parallelism arguments.
     moe_expert_model_parallelism: bool = False
     expert_parallel_group: Optional[torch.distributed.ProcessGroup] = None
-    moe_weight_parallelism: bool = False
-    weight_parallel_group: Optional[torch.distributed.ProcessGroup] = None
     pipeline_model_parallel_size: int = 1
     num_layers_per_virtual_pipeline_stage: Optional[int] = None

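For callers that previously set the removed flags, the remaining parallelism knobs are the ones shown above. A minimal sketch of configuring expert parallelism only, assuming `torch.distributed` is already initialized and that the other `Arguments` fields keep their defaults:

```python
import torch.distributed as dist

from megablocks.layers.arguments import Arguments

# Hypothetical expert-parallel group spanning all ranks; adjust to your topology.
expert_group = dist.new_group(ranks=list(range(dist.get_world_size())))

args = Arguments(
    moe_expert_model_parallelism=True,
    expert_parallel_group=expert_group,
    # moe_weight_parallelism / weight_parallel_group no longer exist after this PR.
)
```

Since `Arguments` is a dataclass, passing either of the removed keywords should now fail at construction time, which makes stale configs easy to catch.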
3 changes: 0 additions & 3 deletions megablocks/layers/glu.py
@@ -44,9 +44,6 @@ def __init__(self, args: Arguments):
             self._should_set_parallelism_attribute,
         )

-        if self.args.moe_weight_parallelism:
-            raise NotImplementedError('Weight parallelism not yet supported with GLU.',)
-
     def forward(self, x, topo):
         if self.args.memory_optimized_mlp:
             raise NotImplementedError(
50 changes: 4 additions & 46 deletions megablocks/layers/mlp.py
@@ -9,7 +9,6 @@

 from megablocks import grouped_gemm_util as gg
 from megablocks.layers import common, gelu, mpu
-from megablocks.layers import weight_parallel as wp
 from megablocks.layers.activation_fn import act_fn
 from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn

@@ -180,21 +179,7 @@ def create_dmoe_expert_weights(
         columns,
         init_method,
     )
-    weights = weights.view([-1, columns])
-    rows, columns = weights.shape
-
-    if not args.moe_weight_parallelism:
-        return weights
-
-    # Caclculate the number of rows on this weight parallel partition.
-    # 'rows' must be divisible by weight parallel world size.
-    weight_parallel_world_size = mpu.get_weight_parallel_world_size(args)
-    assert (rows % weight_parallel_world_size) == 0
-    num_rows_per_rank = rows // weight_parallel_world_size
-    rank = mpu.get_weight_parallel_rank(args)
-    start_row = rank * num_rows_per_rank
-    end_row = (rank + 1) * num_rows_per_rank
-    return weights[start_row:end_row]
+    return weights.view([-1, columns])


 class MemoryOptimizedMLP(torch.autograd.Function):
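The hunk above collapses `create_dmoe_expert_weights` to a plain reshape. For reference, a standalone sketch of the row slicing it deletes, with `world_size` and `rank` as illustrative placeholders rather than values read from a process group:

```python
import torch


def slice_rows_for_rank(weights: torch.Tensor, columns: int, world_size: int, rank: int) -> torch.Tensor:
    """Illustrative re-creation of the deleted weight-parallel partitioning."""
    weights = weights.view([-1, columns])
    rows = weights.shape[0]
    # 'rows' must be divisible by the weight-parallel world size.
    assert rows % world_size == 0
    num_rows_per_rank = rows // world_size
    start_row = rank * num_rows_per_rank
    end_row = (rank + 1) * num_rows_per_rank
    return weights[start_row:end_row]


# After this PR, the function simply returns weights.view([-1, columns]) on every rank.
```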
@@ -323,8 +308,7 @@ class SparseMLP(torch.nn.Module):
     def __init__(self, args: Arguments):
         super().__init__()
         self.args = args
-        self._num_rows_per_rank = ((mpu.experts_per_rank(args) * mpu.features_per_rank(args)) //
-                                   mpu.get_weight_parallel_world_size(args))
+        self._num_rows_per_rank = mpu.experts_per_rank(args) * mpu.features_per_rank(args)

         self.w1 = torch.nn.Parameter(
             torch.empty(
@@ -371,7 +355,7 @@ def __init__(self, args: Arguments):
             ),
         )

-        self._should_set_parallelism_attribute = (args.moe_expert_model_parallelism or args.moe_weight_parallelism)
+        self._should_set_parallelism_attribute = args.moe_expert_model_parallelism
         mpu.set_expert_model_parallel_attributes(
             self.w1,
             self._should_set_parallelism_attribute,
@@ -390,33 +374,10 @@ def scale_grad(self, w):
             return w
         return scale_gradient(w, self.gradient_scale)

-    def parallel_forward(self, x, topo):
-        group = self.args.weight_parallel_group
-        w1, w2 = (self.scale_grad(self.w1), self.scale_grad(self.w2))
-        if self.args.memory_optimized_mlp:
-            if self.args.activation_fn is not DEFAULT_ACTIVATION_FN:
-                raise NotImplementedError(
-                    f'memory_optimized_weight_parallel_mlp not implemented for custom activation_fn={self.args.activation_fn}.',
-                )
-            return wp.memory_optimized_weight_parallel_mlp(
-                x,
-                w1,
-                w2,
-                topo,
-                group,
-            )
-
-        # Compute the MLP.
-        x = wp.sdd_nt(x, w1, topo, group)
-        activation_fn_out = act_fn(x, self.args.activation_fn)
-        return wp.dsd_nn(activation_fn_out, w2, group)
-
     def forward(self, x, topo):
         w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2)
         w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2)
-        if self.args.moe_weight_parallelism:
-            return self.parallel_forward(x, topo)
-        elif self.args.memory_optimized_mlp:
+        if self.args.memory_optimized_mlp:
             return memory_optimized_mlp(
                 x,
                 w1,
@@ -542,9 +503,6 @@ def forward(self, x, tokens_per_expert):
         w1 = resolve_dtensor(w1).view(ne, -1, self.args.hidden_size)
         w2 = resolve_dtensor(w2).view(ne, -1, self.args.hidden_size)

-        if self.args.moe_weight_parallelism:
-            raise NotImplementedError('Weight parallelism not yet supported with GroupedMLP.',)
-
         if self.args.memory_optimized_mlp:
             return memory_optimized_grouped_mlp(
                 x,
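The parameter-sizing consequence of the `SparseMLP` changes above: each rank now allocates the full set of expert rows rather than a weight-parallel slice. A toy arithmetic sketch with hypothetical sizes, not taken from any real config:

```python
# Hypothetical sizes for illustration only.
experts_per_rank = 4
features_per_rank = 4096          # ffn hidden features owned by this rank
weight_parallel_world_size = 2    # the factor the old code divided by

rows_before = (experts_per_rank * features_per_rank) // weight_parallel_world_size  # 8192
rows_after = experts_per_rank * features_per_rank                                   # 16384

assert rows_after == rows_before * weight_parallel_world_size
```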
8 changes: 0 additions & 8 deletions megablocks/layers/mpu.py
@@ -42,14 +42,6 @@ def copy_expert_model_parallel_attributes(
         )


-def get_weight_parallel_world_size(args: Arguments) -> int:
-    return (torch.distributed.get_world_size(args.weight_parallel_group) if args.moe_weight_parallelism else 1)
-
-
-def get_weight_parallel_rank(args: Arguments) -> int:
-    return (torch.distributed.get_rank(args.weight_parallel_group) if args.moe_weight_parallelism else 0)
-
-
 def synchronized_print(group, *x):
     world_size = torch.distributed.get_world_size(group)
     rank = torch.distributed.get_rank(group)
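The deleted helpers above returned a world size of 1 and a rank of 0 whenever `moe_weight_parallelism` was off (its default), so the division and slicing they fed elsewhere in this diff were no-ops in that case. A tiny self-contained check of that default path:

```python
import torch

# Default-path values the deleted helpers produced when the flag was False.
world_size, rank = 1, 0

weights = torch.empty(8, 4)
num_rows_per_rank = weights.shape[0] // world_size            # == 8
sliced = weights[rank * num_rows_per_rank:(rank + 1) * num_rows_per_rank]

# Slicing the full row range keeps the tensor unchanged, so removing the
# weight-parallel code path is behavior-preserving for default configs.
assert sliced.shape == weights.shape
```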