From abc0638ef05f3b35f74368c5919838a9237dc80b Mon Sep 17 00:00:00 2001
From: Jose Javier <26491792+josejg@users.noreply.github.com>
Date: Sun, 8 Sep 2024 09:54:42 +0200
Subject: [PATCH] Router zloss

---
 megablocks/layers/arguments.py |  4 ++++
 megablocks/layers/router.py    | 36 +++++++++++++++++++++++++++++++++++-
 2 files changed, 39 insertions(+), 1 deletion(-)

diff --git a/megablocks/layers/arguments.py b/megablocks/layers/arguments.py
index 892cb91..0b6d666 100644
--- a/megablocks/layers/arguments.py
+++ b/megablocks/layers/arguments.py
@@ -68,6 +68,10 @@ class Arguments:
         int] = None  # hidden size of the shared expert IF we want to set it to something different from hidden_size
     shared_expert_weighted_sum: bool = False  # enable using weighted sum for shared expert output (wieghted by number of experts used)
 
+    # Router Z-loss arguments
+    moe_zloss_weight: float = 0  # 1e-3 is a reasonable value
+    moe_zloss_in_fp32: bool = False
+
     def __post_init__(self):
         if self.__getattribute__('mlp_impl') == 'grouped':
             grouped_gemm.assert_grouped_gemm_is_available()
diff --git a/megablocks/layers/router.py b/megablocks/layers/router.py
index 9499870..a0b4b4e 100644
--- a/megablocks/layers/router.py
+++ b/megablocks/layers/router.py
@@ -7,6 +7,38 @@
 from megablocks.layers import common
 from megablocks.layers.arguments import Arguments
 
+_ROUTER_LOGITS = []
+
+def _save_router_logits(logits: torch.Tensor, args: Arguments):
+    if args.moe_zloss_weight == 0:
+        return
+    global _ROUTER_LOGITS
+    _ROUTER_LOGITS.append(logits)
+
+def clear_router_zloss():
+    global _ROUTER_LOGITS
+    _ROUTER_LOGITS.clear()
+
+def batched_router_zloss(args: Arguments):
+    global _ROUTER_LOGITS
+
+    if args.moe_zloss_weight == 0:
+        import warnings
+        warnings.warn("Call to batched_router_zloss, but moe_zloss_weight=0")
+        return 0
+
+    logits_per_router = _ROUTER_LOGITS
+
+    if args.moe_zloss_in_fp32:
+        logits_per_router = [logits.float() for logits in logits_per_router]
+
+    unscaled_zloss_per_router = torch.stack([
+        torch.logsumexp(logits, dim=1).square().mean()
+        for logits in logits_per_router
+    ])
+
+    return args.moe_zloss_weight * unscaled_zloss_per_router
+
 
 # NOTE: To enable end-to-end benchmarking without convergence we
 # support a flag to force the router to assign tokens uniformly
@@ -60,7 +92,9 @@ def forward(self, x: torch.Tensor):
         if self.training and self.args.moe_jitter_eps is not None:
             x = x * self.jitter(x)
 
-        scores = self.layer(x.view(-1, x.shape[-1])).softmax(dim=-1)
+        logits = self.layer(x.view(-1, x.shape[-1]))
+        _save_router_logits(logits, self.args)
+        scores = logits.softmax(dim=-1)
         expert_weights, expert_indices = self._top_k(scores)
         if self.args.moe_normalize_expert_weights:
             expert_weights = expert_weights / torch.norm(
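
Usage note (not part of the patch): each router now stashes its pre-softmax logits, and batched_router_zloss returns one scaled term per router layer, computed as the mean squared logsumexp of the logits times moe_zloss_weight. Below is a minimal sketch of how a training step might consume it. Only clear_router_zloss, batched_router_zloss, and the two new moe_zloss_* fields come from this change; the other Arguments fields and the model construction are assumptions about the existing megablocks API and should be adapted to your setup.

from megablocks.layers import router
from megablocks.layers.arguments import Arguments

# moe_zloss_weight / moe_zloss_in_fp32 are the fields introduced by this patch;
# the remaining fields are assumed to exist in the current Arguments dataclass.
args = Arguments(
    hidden_size=1024,
    ffn_hidden_size=4096,
    moe_num_experts=8,
    moe_top_k=2,
    moe_zloss_weight=1e-3,   # 0 disables the z-loss entirely
    moe_zloss_in_fp32=True,  # compute logsumexp in fp32 for numerical stability
)

# model = ... build an MoE stack from `args` (hypothetical, not shown here) ...

def training_step(model, batch, labels, loss_fn):
    output = model(batch)                  # forward pass; routers stash their logits
    loss = loss_fn(output, labels)
    # One scaled z-loss term per router layer; fold their sum into the objective.
    zloss = router.batched_router_zloss(args).sum()
    router.clear_router_zloss()            # reset the logits cache for the next step
    (loss + zloss).backward()
    return loss.detach(), zloss.detach()

Clearing the cache once per step matters: the module-level _ROUTER_LOGITS list otherwise grows across iterations and keeps old activation graphs alive.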