diff --git a/megablocks/layers/arguments.py b/megablocks/layers/arguments.py index 59880b40..c622d736 100644 --- a/megablocks/layers/arguments.py +++ b/megablocks/layers/arguments.py @@ -35,6 +35,9 @@ class Arguments: init_method : InitFn = partial(torch.nn.init.normal_, mean=0.0, std=0.02) output_layer_init_method : InitFn = init_method + # Benchmarking arguments. + uniform_expert_assignment : bool = False + def from_megatron(megatron_args): args = Arguments() diff --git a/megablocks/layers/router.py b/megablocks/layers/router.py index 6a14e99b..b4720439 100644 --- a/megablocks/layers/router.py +++ b/megablocks/layers/router.py @@ -2,6 +2,22 @@ from megablocks.layers.arguments import Arguments import torch + +# NOTE: To enable end-to-end benchmarking without convergence we +# support a flag to force the router to assign tokens uniformly +# across the experts. We do this with a custom autograd operation +# so that PyTorch still executes the full set of router operation. +class _UniformExpertAssignment(torch.autograd.Function): + + + @staticmethod + def forward(ctx, x, num_experts): + out = torch.arange(x.numel(), dtype=x.dtype, device=x.device) + out = torch.remainder(out, num_experts) + return out.view(x.shape) +_uniform_expert_assignment = _UniformExpertAssignment.apply + + class LearnedRouter(torch.nn.Module): def __init__(self, args : Arguments): @@ -27,12 +43,22 @@ def jitter(self, x): noise = torch.rand(x.size(), dtype=x.dtype, device=x.device) return low + noise * (high - low) + def _top_k(self, scores): + if self.args.moe_top_k == 1: + return scores.max(dim=-1) + return torch.topk(scores, self.args.moe_top_k, dim=-1) + + def forward(self, x): if self.training and self.args.moe_jitter_eps is not None: x = x * self.jitter(x) sl, bs, hs = x.size() scores = self.layer(x.view(-1, hs)).softmax(dim=-1) - if self.args.moe_top_k == 1: - return scores, *scores.max(dim=-1) - return scores, *torch.topk(scores, self.args.moe_top_k, dim=-1) + expert_weights, expert_indices = self._top_k(scores) + + expert_indices = ( + _uniform_expert_assignment(expert_indices, self.args.moe_num_experts) + if self.args.uniform_expert_assignment else expert_indices + ) + return scores, expert_weights, expert_indices