diff --git a/megablocks/layers/arguments.py b/megablocks/layers/arguments.py
index 892cb91..0d5399b 100644
--- a/megablocks/layers/arguments.py
+++ b/megablocks/layers/arguments.py
@@ -35,6 +35,7 @@ class Arguments:
     moe_capacity_factor: int = 1
     moe_normalize_expert_weights: Optional[Union[int, float]] = None
     moe_loss_weight: float = 0.1
+    moe_zloss_weight: float = 0.001
     moe_jitter_eps: Optional[float] = None
     moe_lbl_in_fp32: bool = False
 
diff --git a/megablocks/layers/moe.py b/megablocks/layers/moe.py
index 9ba5edb..dc5b845 100644
--- a/megablocks/layers/moe.py
+++ b/megablocks/layers/moe.py
@@ -35,8 +35,10 @@ def batched_load_balancing_loss(args: Arguments):
     # tokens_per_expert[i].shape = (num_experts)
     # expert_scores[i].shape = (tokens, num_experts)
-    tokens_per_expert, expert_scores = zip(*get_load_balancing_loss())
+    # tokens_per_expert, expert_scores = zip(*get_load_balancing_loss())
+    tokens_per_expert, expert_scores, expert_logits = zip(*get_load_balancing_loss())
     num_layers_per_pipeline_stage = (args.num_layers // args.pipeline_model_parallel_size)
+
     if args.num_layers_per_virtual_pipeline_stage is not None:
         num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage
@@ -75,6 +77,7 @@ def batched_load_balancing_loss(args: Arguments):
     else:
         expert_scores = expert_scores.sum(dim=0)
     tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype)
+    expert_logits = torch.cat(expert_logits, dim=0).to(expert_scores.dtype)
 
     expected_values = num_layers_per_pipeline_stage * args.moe_num_experts
     assert tokens_per_expert.numel() == expected_values
@@ -86,7 +89,8 @@ def batched_load_balancing_loss(args: Arguments):
     scale_numerator = (args.moe_num_experts * args.moe_loss_weight)
     scale_denominator = (args.num_layers * tokens * args.moe_top_k)
     scale = scale_numerator / scale_denominator
-    return scale * torch.dot(tokens_per_expert, expert_scores)
+    zloss = (torch.logsumexp(expert_logits, dim=-1) ** 2).sum() / scale_denominator
+    return scale * torch.dot(tokens_per_expert, expert_scores), args.moe_zloss_weight * zloss
 
 
 # NOTE: This class defines MoE expert computation, including expert model parallel
@@ -422,13 +426,13 @@ def parallel_forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, t
         x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k)
         return x, tokens_per_expert.flatten()
 
-    def forward(self, x: torch.Tensor, scores: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+    def forward(self, x: torch.Tensor, scores: torch.Tensor, logits: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
         in_shape = x.size()
 
         # Compute the experts.
         x, tokens_per_expert = self.forward_fn(x, expert_weights, top_experts)
         if self.training and self.args.moe_loss_weight > 0:
-            save_load_balancing_loss((tokens_per_expert, scores))
+            save_load_balancing_loss((tokens_per_expert, scores, logits))
         x = x.view(in_shape)
         if self.bias is not None:
             if self.args.return_bias:
@@ -462,10 +466,10 @@ def forward(self, x: torch.Tensor):
         x = common.cast_if_autocast_enabled(x)
 
         # Compute the expert scores and assignments.
-        scores, expert_weights, top_experts = self.router(x)
+        scores, logits, expert_weights, top_experts = self.router(x)
 
         # Compute the experts.
-        out = self.experts(x, scores, expert_weights, top_experts)
+        out = self.experts(x, scores, logits, expert_weights, top_experts)
         if self.shared_expert is not None:
             shared_expert_out = self.shared_expert(x)
             out = self.shared_expert.add_experts_sharedexpert(
diff --git a/megablocks/layers/router.py b/megablocks/layers/router.py
index 9499870..a73f4a4 100644
--- a/megablocks/layers/router.py
+++ b/megablocks/layers/router.py
@@ -59,8 +59,9 @@ def _top_k(self, scores: torch.Tensor):
     def forward(self, x: torch.Tensor):
         if self.training and self.args.moe_jitter_eps is not None:
             x = x * self.jitter(x)
-
-        scores = self.layer(x.view(-1, x.shape[-1])).softmax(dim=-1)
+        # scores = self.layer(x.view(-1, x.shape[-1])).softmax(dim=-1)
+        logits = self.layer(x.view(-1, x.shape[-1]))
+        scores = logits.softmax(dim=-1)
         expert_weights, expert_indices = self._top_k(scores)
         if self.args.moe_normalize_expert_weights:
             expert_weights = expert_weights / torch.norm(
@@ -76,4 +77,4 @@ def forward(self, x: torch.Tensor):
                 self.args.moe_num_experts,
             ) if self.args.uniform_expert_assignment else expert_indices
         )
-        return scores, expert_weights, expert_indices
+        return scores, logits, expert_weights, expert_indices
diff --git a/tests/layers/dmoe_test.py b/tests/layers/dmoe_test.py
index 3d6565c..7f54d56 100644
--- a/tests/layers/dmoe_test.py
+++ b/tests/layers/dmoe_test.py
@@ -134,7 +134,7 @@ def test_dmoe_forward_backward(
     out, _ = layer(x)
     assert out.shape == x.shape
 
-    loss = out.sum() + batched_load_balancing_loss(args)
+    loss = out.sum() + batched_load_balancing_loss(args)[0]
     loss.backward()
     assert x.grad is not None
     layer.zero_grad(set_to_none=True)
diff --git a/tests/layers/moe_test.py b/tests/layers/moe_test.py
index ffd32cb..37527d6 100644
--- a/tests/layers/moe_test.py
+++ b/tests/layers/moe_test.py
@@ -102,7 +102,7 @@ def test_moe_forward_backward(
     out, _ = layer(x)
     assert out.shape == x.shape
 
-    loss = out.sum() + batched_load_balancing_loss(args)
+    loss = out.sum() + batched_load_balancing_loss(args)[0]
     loss.backward()
     layer.zero_grad(set_to_none=True)
     x.grad = None
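Note (outside the patch): after this change, batched_load_balancing_loss(args) returns a 2-tuple instead of a scalar: the scaled load-balancing loss and the router z-loss term (the squared logsumexp of the router logits, in the style of ST-MoE), with the z-loss already multiplied by args.moe_zloss_weight. Existing callers must unpack the tuple, as the updated tests do with [0]. A minimal sketch of a training step that uses both terms; `main_loss` is an illustrative placeholder for the model's primary objective, not a MegaBlocks API:

    # Sketch only, assuming `args` is the megablocks Arguments used to build the model
    # and `main_loss` is the task loss computed elsewhere in the training step.
    from megablocks.layers.moe import batched_load_balancing_loss, clear_load_balancing_loss

    lbl, zloss = batched_load_balancing_loss(args)
    loss = main_loss + lbl + zloss  # zloss already carries args.moe_zloss_weight
    loss.backward()
    clear_load_balancing_loss()  # reset the saved (tokens_per_expert, scores, logits) tuples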