diff --git a/megablocks/layers/dmoe.py b/megablocks/layers/dmoe.py
index 98d5a707..4968aa63 100644
--- a/megablocks/layers/dmoe.py
+++ b/megablocks/layers/dmoe.py
@@ -2,6 +2,7 @@
 from megablocks.layers import mlp
 from megablocks.layers import moe
 from megablocks.layers import mpu
+from megablocks.layers import router
 from megablocks.layers.arguments import Arguments
 import megablocks.ops as ops
 import numpy as np
@@ -13,10 +14,10 @@ def promote_scalar(x):
     return x.view(1) if not len(x.size()) else x
 
 
-class dMoE(moe.MoE):
+class ParallelDroplessMLP(moe.ParallelMLP):
 
     def __init__(self, args : Arguments):
-        super(dMoE, self).__init__(args)
+        super(ParallelDroplessMLP, self).__init__(args)
         self.hidden_size = args.hidden_size
         self.ffn_hidden_size = mpu.features_per_rank(args)
         self.blocking = 128
@@ -307,3 +308,27 @@ def permute_and_compute(
             bins,
             expert_capactiy,
             top_k)
+
+
+class dMoE(torch.nn.Module):
+
+    def __init__(self, args : Arguments):
+        super(dMoE, self).__init__()
+
+        # Token router.
+        self.router = router.LearnedRouter(args)
+
+        # Expert computation helper.
+        self.experts = ParallelDroplessMLP(args)
+
+    def forward(self, x):
+        # NOTE: If we're going to cast the activations to lower precision
+        # do it before we permute the tokens to save bandwidth.
+        x = common.cast_if_autocast_enabled(x)
+        sl, bs, hs = x.size()
+
+        # Compute the expert scores and assignments.
+        scores, expert_weights, top_experts = self.router(x)
+
+        # Compute the experts.
+        return self.experts(x, scores, expert_weights, top_experts)
diff --git a/megablocks/layers/dmoe_test.py b/megablocks/layers/dmoe_test.py
index acdc4857..b932ef46 100644
--- a/megablocks/layers/dmoe_test.py
+++ b/megablocks/layers/dmoe_test.py
@@ -44,14 +44,14 @@ def test_modules(
 
     # Set the baseline parameters to match exactly.
     with torch.no_grad():
-        ne, hs, fhs = moe_mlp.mlp.w1.size()
-        w1 = dmoe_mlp.mlp.w1.view([ne, fhs, hs])
-        moe_mlp.mlp.w1.copy_(torch.transpose(w1, 1, 2).contiguous())
-        moe_mlp.mlp.w2.copy_(dmoe_mlp.mlp.w2.view([ne, fhs, hs]))
+        ne, hs, fhs = moe_mlp.experts.mlp.w1.size()
+        w1 = dmoe_mlp.experts.mlp.w1.view([ne, fhs, hs])
+        moe_mlp.experts.mlp.w1.copy_(torch.transpose(w1, 1, 2).contiguous())
+        moe_mlp.experts.mlp.w2.copy_(dmoe_mlp.experts.mlp.w2.view([ne, fhs, hs]))
         moe_mlp.router.layer.weight.copy_(dmoe_mlp.router.layer.weight)
         if moe_num_experts == 1:
-            mlp.w1.copy_(moe_mlp.mlp.w1.squeeze())
-            mlp.w2.copy_(moe_mlp.mlp.w2.squeeze())
+            mlp.w1.copy_(moe_mlp.experts.mlp.w1.squeeze())
+            mlp.w2.copy_(moe_mlp.experts.mlp.w2.squeeze())
     return args, mlp, moe_mlp, dmoe_mlp
 
 # min size: (1, 2, 128, 2, 1)
diff --git a/megablocks/layers/memory_test.py b/megablocks/layers/memory_test.py
index 0d752718..e3142726 100644
--- a/megablocks/layers/memory_test.py
+++ b/megablocks/layers/memory_test.py
@@ -65,8 +65,8 @@ def test_memory(
     # Calculate weight and gradient memory usage.
     weight_memory = 2 * (
         layer.router.layer.weight.numel() +
-        layer.mlp.w1.numel() +
-        layer.mlp.w2.numel())
+        layer.experts.mlp.w1.numel() +
+        layer.experts.mlp.w2.numel())
 
     def grad_numel(x):
         if x.grad is not None:
@@ -75,8 +75,8 @@ def grad_numel(x):
 
     grad_memory = 2 * (
         grad_numel(layer.router.layer.weight) +
-        grad_numel(layer.mlp.w1) +
-        grad_numel(layer.mlp.w2))
+        grad_numel(layer.experts.mlp.w1) +
+        grad_numel(layer.experts.mlp.w2))
     weight_memory += grad_memory
 
     print("Weight Memory Allocated = {:0.0f}MiB".format(
diff --git a/megablocks/layers/moe.py b/megablocks/layers/moe.py
index adf52c57..f73f0aea 100644
--- a/megablocks/layers/moe.py
+++ b/megablocks/layers/moe.py
@@ -94,10 +94,14 @@ def batched_load_balancing_loss(args : Arguments):
     return scale * torch.dot(tokens_per_expert, expert_scores)
 
 
-class MoE(torch.nn.Module):
+# NOTE: This class defines MoE expert computation, including expert model parallel
+# communication. When using FSDP on top of MegaBlocks this is the module that should
+# be wrapped s.t. the weight all-gathers can be scheduled *before* the expert model
+# parallel all2all.
+class ParallelMLP(torch.nn.Module):
 
     def __init__(self, args : Arguments):
-        super(MoE, self).__init__()
+        super(ParallelMLP, self).__init__()
         self.args = args
 
         # Calculate the number of experts in total and the number of experts
@@ -110,9 +114,6 @@ def __init__(self, args : Arguments):
         # so that we can pass it to radix sort.
         self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1)
 
-        # Token router.
-        self.router = router.LearnedRouter(args)
-
         # Expert MLP.
         self.mlp = mlp.MLP(args)
 
@@ -410,15 +411,9 @@ def parallel_forward_once(self, x, expert_weights, top_experts):
             self.args.quantize_scatter_num_bits)
         return x, tokens_per_expert.flatten()
 
-    def forward(self, x):
-        # NOTE: If we're going to cast the activations to lower precision
-        # do it before we permute the tokens to save bandwidth.
-        x = common.cast_if_autocast_enabled(x)
+    def forward(self, x, scores, expert_weights, top_experts):
         sl, bs, hs = x.size()
 
-        # Compute the expert scores and assignments.
-        scores, expert_weights, top_experts = self.router(x)
-
         # Compute the experts.
         x, tokens_per_expert = self.forward_fn(
             x, expert_weights, top_experts)
@@ -429,3 +424,27 @@ def forward(self, x):
                 return x, self.bias
             return x + self.bias
         return x
+
+
+class MoE(torch.nn.Module):
+
+    def __init__(self, args : Arguments):
+        super(MoE, self).__init__()
+
+        # Token router.
+        self.router = router.LearnedRouter(args)
+
+        # Expert computation helper.
+        self.experts = ParallelMLP(args)
+
+    def forward(self, x):
+        # NOTE: If we're going to cast the activations to lower precision
+        # do it before we permute the tokens to save bandwidth.
+        x = common.cast_if_autocast_enabled(x)
+        sl, bs, hs = x.size()
+
+        # Compute the expert scores and assignments.
+        scores, expert_weights, top_experts = self.router(x)
+
+        # Compute the experts.
+        return self.experts(x, scores, expert_weights, top_experts)
diff --git a/megablocks/layers/moe_test.py b/megablocks/layers/moe_test.py
index 43fc9d63..08f6bc8a 100644
--- a/megablocks/layers/moe_test.py
+++ b/megablocks/layers/moe_test.py
@@ -32,8 +32,8 @@ def test_modules(
     # Set the baseline parameters to match exactly.
     if moe_num_experts == 1:
         with torch.no_grad():
-            mlp.w1.copy_(moe_mlp.mlp.w1.squeeze())
-            mlp.w2.copy_(moe_mlp.mlp.w2.squeeze())
+            mlp.w1.copy_(moe_mlp.experts.mlp.w1.squeeze())
+            mlp.w2.copy_(moe_mlp.experts.mlp.w2.squeeze())
     return args, mlp, moe_mlp
 
 
@@ -126,8 +126,8 @@ def testMoE_ForwardBackwardVersusDense(self, bs, sl, hs):
         out, _ = moe_mlp(x)
         loss = out.sum()
         loss.backward()
-        w1_grad = moe_mlp.mlp.w1.grad.detach().squeeze()
-        w2_grad = moe_mlp.mlp.w2.grad.detach().squeeze()
+        w1_grad = moe_mlp.experts.mlp.w1.grad.detach().squeeze()
+        w2_grad = moe_mlp.experts.mlp.w2.grad.detach().squeeze()
         moe_mlp.zero_grad(set_to_none=True)
         x.grad = None
         moe.clear_load_balancing_loss()
diff --git a/megablocks/layers/parallelism_test.py b/megablocks/layers/parallelism_test.py
index 884cb826..aa3d6584 100644
--- a/megablocks/layers/parallelism_test.py
+++ b/megablocks/layers/parallelism_test.py
@@ -96,16 +96,16 @@ def permute(x):
         out = x.view(hsd, esd, -1).transpose(1, 0).contiguous()
         return out.view(num_experts * ffn_hidden_size, hidden_size)
 
-    wp_w2_grad = gather(wp.mlp.w2.grad)
-    ep_w2_grad = permute(gather(ep.mlp.w2.grad))
+    wp_w2_grad = gather(wp.experts.mlp.w2.grad)
+    ep_w2_grad = permute(gather(ep.experts.mlp.w2.grad))
     if rank == 0:
         np.testing.assert_allclose(
             wp_w2_grad.float().cpu(),
             ep_w2_grad.float().cpu(),
             rtol=1e-5, atol=1e-5)
 
-    wp_w1_grad = gather(wp.mlp.w1.grad)
-    ep_w1_grad = permute(gather(ep.mlp.w1.grad))
+    wp_w1_grad = gather(wp.experts.mlp.w1.grad)
+    ep_w1_grad = permute(gather(ep.experts.mlp.w1.grad))
     if rank == 0:
         np.testing.assert_allclose(
             wp_w1_grad.float().cpu(),
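
Usage sketch (not part of the patch above): the snippet below illustrates how the split modules compose after this change, i.e. dMoE/MoE as a thin router-plus-experts wrapper. The Arguments fields shown (hidden_size, ffn_hidden_size, moe_num_experts, moe_top_k) and the tensor dtype/device handling are assumptions for illustration; check megablocks/layers/arguments.py and the tests for the real signatures.

import torch

from megablocks.layers.arguments import Arguments
from megablocks.layers.dmoe import dMoE

# Hypothetical configuration; field names are assumed, not taken from this diff.
args = Arguments(
    hidden_size=1024,
    ffn_hidden_size=4096,
    moe_num_experts=8,
    moe_top_k=2)

# dMoE is now a wrapper around a router.LearnedRouter and a ParallelDroplessMLP
# ("experts") that owns the expert weights and the expert model parallel
# communication.
layer = dMoE(args).cuda()

# Per the NOTE added to moe.py, when training with FSDP wrap layer.experts
# (the ParallelMLP / ParallelDroplessMLP module) rather than the whole layer,
# so weight all-gathers can be scheduled before the expert-parallel all2all.

# Input is (sequence_length, batch_size, hidden_size); precision handling here
# is assumed and may need to match args (e.g. half precision, as in the tests).
x = torch.randn(2048, 1, 1024, device='cuda')
out = layer(x)  # As in the tests, this may be a (output, bias) tuple depending on the bias config.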