Skip to content

Commit

Permalink
Merge pull request #10 from stanford-futuredata/benchmark-load-balance
Browse files Browse the repository at this point in the history
Add flag to force uniform assignment to experts for load balancing.
  • Loading branch information
tgale96 authored Aug 1, 2023
2 parents c0cb172 + bfce1a5 commit 5501101
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 3 deletions.
3 changes: 3 additions & 0 deletions megablocks/layers/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ class Arguments:
init_method : InitFn = partial(torch.nn.init.normal_, mean=0.0, std=0.02)
output_layer_init_method : InitFn = init_method

# Benchmarking arguments.
uniform_expert_assignment : bool = False


def from_megatron(megatron_args):
args = Arguments()
Expand Down
32 changes: 29 additions & 3 deletions megablocks/layers/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,22 @@
from megablocks.layers.arguments import Arguments
import torch


# NOTE: To enable end-to-end benchmarking without convergence we
# support a flag to force the router to assign tokens uniformly
# across the experts. We do this with a custom autograd operation
# so that PyTorch still executes the full set of router operations.
class _UniformExpertAssignment(torch.autograd.Function):
    """Autograd op that replaces expert indices with a round-robin pattern.

    Only ``forward`` is defined on this class. The values of the input
    tensor are ignored; just its shape, dtype and device are used.
    """

    @staticmethod
    def forward(ctx, x, num_experts):
        """Return a tensor shaped like ``x`` holding ``i % num_experts``.

        Args:
            ctx: Autograd context (unused).
            x: Tensor whose shape, dtype and device the result copies.
            num_experts: Modulus applied to the flat element position.

        Returns:
            A tensor with ``x``'s shape where the element at flat
            position ``i`` equals ``i % num_experts``.
        """
        # Enumerate every element position, then wrap it into the
        # range [0, num_experts).
        positions = torch.arange(x.numel(), dtype=x.dtype, device=x.device)
        return positions.remainder(num_experts).view(x.shape)


_uniform_expert_assignment = _UniformExpertAssignment.apply


class LearnedRouter(torch.nn.Module):

def __init__(self, args : Arguments):
Expand All @@ -27,12 +43,22 @@ def jitter(self, x):
noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
return low + noise * (high - low)

def _top_k(self, scores):
    """Select the highest-scoring entries along the last dimension.

    Args:
        scores: Tensor of router scores.

    Returns:
        The ``(values, indices)`` pair produced by ``scores.max(dim=-1)``
        when ``self.args.moe_top_k`` is 1, otherwise the pair produced
        by ``torch.topk`` with ``k = self.args.moe_top_k``.
    """
    k = self.args.moe_top_k
    if k != 1:
        return torch.topk(scores, k, dim=-1)
    # The k == 1 case goes through ``max`` rather than ``topk``.
    return scores.max(dim=-1)


def forward(self, x):
    """Compute router scores and pick the experts for each token.

    During training, when ``self.args.moe_jitter_eps`` is set, the input
    is first multiplied by ``self.jitter(x)``. The input is then
    flattened to ``(-1, hs)``, passed through ``self.layer`` and
    softmax-normalized over the last dimension. Experts are chosen with
    ``self._top_k``; when ``self.args.uniform_expert_assignment`` is
    enabled the chosen indices are replaced by a uniform assignment
    across ``self.args.moe_num_experts`` experts.

    Args:
        x: Input tensor whose ``size()`` unpacks into exactly three
            dimensions (presumably sequence length, batch size and
            hidden size -- confirm against callers).

    Returns:
        A ``(scores, expert_weights, expert_indices)`` tuple.
    """
    if self.training and self.args.moe_jitter_eps is not None:
        x = x * self.jitter(x)

    # The three-way unpack also rejects inputs that are not 3-D.
    sl, bs, hs = x.size()
    scores = self.layer(x.view(-1, hs)).softmax(dim=-1)
    expert_weights, expert_indices = self._top_k(scores)

    # The stale early returns that used to sit here made this branch
    # unreachable, so the benchmarking flag had no effect.
    if self.args.uniform_expert_assignment:
        expert_indices = _uniform_expert_assignment(
            expert_indices, self.args.moe_num_experts)
    return scores, expert_weights, expert_indices

0 comments on commit 5501101

Please sign in to comment.