diff --git a/megablocks/layers/arguments.py b/megablocks/layers/arguments.py
index 892cb91..6492dfd 100644
--- a/megablocks/layers/arguments.py
+++ b/megablocks/layers/arguments.py
@@ -68,6 +68,10 @@ class Arguments:
         int] = None  # hidden size of the shared expert IF we want to set it to something different from hidden_size
     shared_expert_weighted_sum: bool = False  # enable using weighted sum for shared expert output (wieghted by number of experts used)
 
+    # Router Z-loss arguments
+    moe_zloss_weight: float = 0  # 1e-3 is a reasonable value
+    moe_zloss_in_fp32: bool = False
+
     def __post_init__(self):
         if self.__getattribute__('mlp_impl') == 'grouped':
             grouped_gemm.assert_grouped_gemm_is_available()
diff --git a/megablocks/layers/router.py b/megablocks/layers/router.py
index 9499870..a6deae0 100644
--- a/megablocks/layers/router.py
+++ b/megablocks/layers/router.py
@@ -7,6 +7,40 @@
 from megablocks.layers import common
 from megablocks.layers.arguments import Arguments
 
+_ROUTER_LOGITS = []
+
+
+def _save_router_logits(logits: torch.Tensor, args: Arguments):
+    if args.moe_zloss_weight == 0:
+        return
+    global _ROUTER_LOGITS
+    _ROUTER_LOGITS.append(logits)
+
+
+def clear_router_zloss():
+    global _ROUTER_LOGITS
+    _ROUTER_LOGITS.clear()
+
+
+def batched_router_zloss(args: Arguments):
+    global _ROUTER_LOGITS
+
+    if args.moe_zloss_weight == 0:
+        import warnings
+        warnings.warn('Call to batched_router_zloss, but moe_zloss_weight=0')
+        return 0
+
+    logits_per_router = _ROUTER_LOGITS
+
+    if args.moe_zloss_in_fp32:
+        logits_per_router = [logits.float() for logits in logits_per_router]
+
+    unscaled_zloss_per_router = torch.stack([
+        torch.logsumexp(logits, dim=1).square().mean() for logits in logits_per_router
+    ])
+
+    return args.moe_zloss_weight * unscaled_zloss_per_router
+
 
 # NOTE: To enable end-to-end benchmarking without convergence we
 # support a flag to force the router to assign tokens uniformly
@@ -60,7 +94,9 @@ def forward(self, x: torch.Tensor):
         if self.training and self.args.moe_jitter_eps is not None:
             x = x * self.jitter(x)
 
-        scores = self.layer(x.view(-1, x.shape[-1])).softmax(dim=-1)
+        logits = self.layer(x.view(-1, x.shape[-1]))
+        _save_router_logits(logits, self.args)
+        scores = logits.softmax(dim=-1)
         expert_weights, expert_indices = self._top_k(scores)
         if self.args.moe_normalize_expert_weights:
             expert_weights = expert_weights / torch.norm(
diff --git a/tests/layers/dmoe_test.py b/tests/layers/dmoe_test.py
index 3d6565c..d8f34ba 100644
--- a/tests/layers/dmoe_test.py
+++ b/tests/layers/dmoe_test.py
@@ -10,6 +10,7 @@
 from megablocks.layers.arguments import Arguments
 from megablocks.layers.dmoe import dMoE
 from megablocks.layers.moe import MoE, batched_load_balancing_loss, clear_load_balancing_loss
+from megablocks.layers.router import batched_router_zloss, clear_router_zloss
 from tests.layers.architectures import FFN
 
 # min size: (1, 2, 128, 2, 1)
@@ -50,6 +51,7 @@ def construct_moes(
     moe_capacity_factor: int = 1,
     moe_top_k: int = 1,
     mlp_impl: str = 'sparse',
+    moe_zloss_weight: float = 0,
 ):
     init_method = partial(torch.nn.init.normal_, mean=0.0, std=0.1)
     args = Arguments(
@@ -64,6 +66,7 @@
         mlp_impl=mlp_impl,
         fp16=False,
         bf16=True,
+        moe_zloss_weight=moe_zloss_weight,
     )
 
     mlp = FFN(args)
@@ -142,6 +145,39 @@ def test_dmoe_forward_backward(
     clear_load_balancing_loss()
 
 
+@pytest.mark.gpu
+@pytest.mark.parametrize(('bs', 'sl', 'hs', 'num_experts', 'top_k', 'mlp_impl'), _FORWARD_TESTS)
+def test_dmoe_forward_backward_with_zloss(
+    bs: int,
+    sl: int,
+    hs: int,
+    num_experts: int,
+    top_k: int,
+    mlp_impl: str,
+):
+    x = torch.randn(sl, bs, hs).to(torch.bfloat16).cuda()
+    x.requires_grad_(True)
+
+    args, _, _, layer = construct_moes(
+        hidden_size=hs,
+        ffn_hidden_size=hs * 2,
+        moe_num_experts=num_experts,
+        moe_top_k=top_k,
+        mlp_impl=mlp_impl,
+        moe_zloss_weight=1e-3,
+    )
+
+    out, _ = layer(x)
+    assert out.shape == x.shape
+    loss = out.sum() + batched_load_balancing_loss(args) + batched_router_zloss(args)
+    loss.backward()
+    assert x.grad is not None
+    layer.zero_grad(set_to_none=True)
+    x.grad = None
+    clear_load_balancing_loss()
+    clear_router_zloss()
+
+
 @pytest.mark.gpu
 @pytest.mark.parametrize(('bs', 'sl', 'hs'), _DENSE_TESTS)
 def test_dmoe_forward_vs_baseline(
diff --git a/tests/layers/moe_test.py b/tests/layers/moe_test.py
index ffd32cb..24d42c9 100644
--- a/tests/layers/moe_test.py
+++ b/tests/layers/moe_test.py
@@ -8,6 +8,7 @@
 
 from megablocks.layers.arguments import Arguments
 from megablocks.layers.moe import MoE, batched_load_balancing_loss, clear_load_balancing_loss
+from megablocks.layers.router import batched_router_zloss, clear_router_zloss
 from tests.layers.architectures import FFN
 
 _FORWARD_TESTS = (
@@ -33,11 +34,12 @@
 
 
 def construct_moe(
-    hidden_size,
-    ffn_hidden_size,
-    moe_num_experts=1,
-    moe_capacity_factor=1,
-    moe_top_k=1,
+    hidden_size: int,
+    ffn_hidden_size: int,
+    moe_num_experts: int = 1,
+    moe_capacity_factor: int = 1,
+    moe_top_k: int = 1,
+    moe_zloss_weight: float = 0,
 ):
     init_method = partial(torch.nn.init.normal_, mean=0.0, std=0.1)
     args = Arguments(
@@ -47,6 +49,7 @@
         moe_capacity_factor=moe_capacity_factor,
         moe_top_k=moe_top_k,
         init_method=init_method,
+        moe_zloss_weight=moe_zloss_weight,
     )
 
     mlp = FFN(args)
@@ -109,6 +112,37 @@ def test_moe_forward_backward(
     clear_load_balancing_loss()
 
 
+@pytest.mark.gpu
+@pytest.mark.parametrize(('bs', 'sl', 'hs', 'num_experts', 'top_k'), _FORWARD_TESTS)
+def test_moe_forward_backward_with_zloss(
+    bs: int,
+    sl: int,
+    hs: int,
+    num_experts: int,
+    top_k: int,
+):
+    x = torch.randn(sl, bs, hs).half().cuda()
+    x.requires_grad_(True)
+
+    args, _, layer = construct_moe(
+        hidden_size=hs,
+        ffn_hidden_size=hs * 2,
+        moe_num_experts=num_experts,
+        moe_top_k=top_k,
+        moe_zloss_weight=1e-3,
+    )
+
+    out, _ = layer(x)
+    assert out.shape == x.shape
+
+    loss = out.sum() + batched_load_balancing_loss(args) + batched_router_zloss(args)
+    loss.backward()
+    layer.zero_grad(set_to_none=True)
+    x.grad = None
+    clear_load_balancing_loss()
+    clear_router_zloss()
+
+
 @pytest.mark.gpu
 @pytest.mark.parametrize(('bs', 'sl', 'hs'), _DENSE_TESTS)
 def test_moe_forward_vs_dense(bs: int, sl: int, hs: int):
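
Usage note (not part of the patch): a minimal sketch of how a training step might consume the new router z-loss alongside the existing load-balancing loss, mirroring the tests above. The names model, optimizer, batch, and train_step are hypothetical stand-ins; only batched_router_zloss, clear_router_zloss, and the load-balancing helpers come from MegaBlocks.

    from megablocks.layers.arguments import Arguments
    from megablocks.layers.moe import batched_load_balancing_loss, clear_load_balancing_loss
    from megablocks.layers.router import batched_router_zloss, clear_router_zloss


    def train_step(model, optimizer, batch, args: Arguments):
        # Forward pass; each router records its logits when moe_zloss_weight > 0.
        out = model(batch)
        loss = out.sum()  # stand-in for the real task loss
        loss = loss + batched_load_balancing_loss(args)
        if args.moe_zloss_weight > 0:
            # batched_router_zloss returns one scaled z-loss per router;
            # fold them into the scalar objective.
            loss = loss + batched_router_zloss(args).sum()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
        # Both auxiliary losses accumulate in module-level buffers; clear them every step.
        clear_load_balancing_loss()
        clear_router_zloss()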