From 24bae3ecb61de97842c587b28b184374ffeeae46 Mon Sep 17 00:00:00 2001
From: Muennighoff
Date: Mon, 10 Jun 2024 23:22:15 +0000
Subject: [PATCH 1/5] Add zloss

---
 megablocks/layers/arguments.py |  1 +
 megablocks/layers/moe.py       | 15 +++++++++------
 megablocks/layers/router.py    |  7 ++++---
 3 files changed, 14 insertions(+), 9 deletions(-)

diff --git a/megablocks/layers/arguments.py b/megablocks/layers/arguments.py
index 9b6c49bb..c14b1721 100644
--- a/megablocks/layers/arguments.py
+++ b/megablocks/layers/arguments.py
@@ -29,6 +29,7 @@ class Arguments:
     moe_capacity_factor : int = 1
     moe_normalize_expert_weights : Optional[Union[int, float]] = None
     moe_loss_weight : float = 0.1
+    moe_zloss_weight : float = 0.001
     moe_jitter_eps : Optional[float] = None
     moe_lbl_in_fp32 : bool = False
 
diff --git a/megablocks/layers/moe.py b/megablocks/layers/moe.py
index b264a3aa..9e7260d1 100644
--- a/megablocks/layers/moe.py
+++ b/megablocks/layers/moe.py
@@ -31,7 +31,8 @@ def clear_load_balancing_loss():
 def batched_load_balancing_loss(args : Arguments):
     # tokens_per_expert[i].shape = (num_experts)
     # expert_scores[i].shape = (tokens, num_experts)
-    tokens_per_expert, expert_scores = zip(*get_load_balancing_loss())
+    # tokens_per_expert, expert_scores = zip(*get_load_balancing_loss())
+    tokens_per_expert, expert_scores, expert_logits = zip(*get_load_balancing_loss())
     num_layers_per_pipeline_stage = (
         args.num_layers // args.pipeline_model_parallel_size)
     if args.num_layers_per_virtual_pipeline_stage is not None:
@@ -74,6 +75,7 @@ def batched_load_balancing_loss(args : Arguments):
     else:
         expert_scores = torch.cat(expert_scores, dim=1).mean(dim=0)
     tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype)
+    expert_logits = torch.cat(expert_logits, dim=0).to(expert_scores.dtype)
 
     expected_values = num_layers_per_pipeline_stage * args.moe_num_experts
     assert tokens_per_expert.numel() == expected_values
@@ -92,7 +94,8 @@ def batched_load_balancing_loss(args : Arguments):
         args.moe_top_k
     )
     scale = scale_numerator / scale_denominator
-    return scale * torch.dot(tokens_per_expert, expert_scores)
+    zloss = (torch.log(torch.exp(expert_logits).sum(dim=-1)) ** 2).sum() / scale_denominator
+    return scale * torch.dot(tokens_per_expert, expert_scores) + args.moe_zloss_weight * zloss
 
 
 # NOTE: This class defines MoE expert computation, including expert model parallel
@@ -418,14 +421,14 @@ def parallel_forward_once(self, x, expert_weights, top_experts):
             self.top_k)
         return x, tokens_per_expert.flatten()
 
-    def forward(self, x, scores, expert_weights, top_experts):
+    def forward(self, x, scores, logits, expert_weights, top_experts):
         in_shape = x.size()
 
         # Compute the experts.
         x, tokens_per_expert = self.forward_fn(
             x, expert_weights, top_experts)
         if self.training:
-            save_load_balancing_loss((tokens_per_expert, scores))
+            save_load_balancing_loss((tokens_per_expert, scores, logits))
         x = x.view(in_shape)
         if self.bias is not None:
             if self.args.return_bias:
@@ -459,10 +462,10 @@ def forward(self, x):
         x = common.cast_if_autocast_enabled(x)
 
         # Compute the expert scores and assignments.
-        scores, expert_weights, top_experts = self.router(x)
+        scores, logits, expert_weights, top_experts = self.router(x)
 
         # Compute the experts.
-        out = self.experts(x, scores, expert_weights, top_experts)
+        out = self.experts(x, scores, logits, expert_weights, top_experts)
         if self.shared_expert is not None:
             shared_expert_out = self.shared_expert(x)
             out = self.shared_expert.add_experts_sharedexpert(shared_expert_out, out)
diff --git a/megablocks/layers/router.py b/megablocks/layers/router.py
index e1abddf0..ec598728 100644
--- a/megablocks/layers/router.py
+++ b/megablocks/layers/router.py
@@ -51,8 +51,9 @@ def _top_k(self, scores):
     def forward(self, x):
         if self.training and self.args.moe_jitter_eps is not None:
             x = x * self.jitter(x)
-
-        scores = self.layer(x.view(-1, x.shape[-1])).softmax(dim=-1)
+        # scores = self.layer(x.view(-1, x.shape[-1])).softmax(dim=-1)
+        logits = self.layer(x.view(-1, x.shape[-1]))
+        scores = logits.softmax(dim=-1)
         expert_weights, expert_indices = self._top_k(scores)
         if self.args.moe_normalize_expert_weights:
             expert_weights = expert_weights / torch.norm(
@@ -62,4 +63,4 @@ def forward(self, x):
                 _uniform_expert_assignment(expert_indices, self.args.moe_num_experts)
                 if self.args.uniform_expert_assignment else expert_indices
             )
-        return scores, expert_weights, expert_indices
+        return scores, logits, expert_weights, expert_indices

From 2783cd41aafd5561f3880a8ba8d506fadb07e50a Mon Sep 17 00:00:00 2001
From: Niklas Muennighoff
Date: Wed, 19 Jun 2024 21:51:47 -0700
Subject: [PATCH 2/5] Allow logging zloss

---
 megablocks/layers/moe.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/megablocks/layers/moe.py b/megablocks/layers/moe.py
index 9e7260d1..dac13413 100644
--- a/megablocks/layers/moe.py
+++ b/megablocks/layers/moe.py
@@ -95,7 +95,7 @@ def batched_load_balancing_loss(args : Arguments):
     )
     scale = scale_numerator / scale_denominator
     zloss = (torch.log(torch.exp(expert_logits).sum(dim=-1)) ** 2).sum() / scale_denominator
-    return scale * torch.dot(tokens_per_expert, expert_scores) + args.moe_zloss_weight * zloss
+    return scale * torch.dot(tokens_per_expert, expert_scores), args.moe_zloss_weight * zloss
 
 
 # NOTE: This class defines MoE expert computation, including expert model parallel

From e430ad707bed4d45016f315da9372e16acb55a1c Mon Sep 17 00:00:00 2001
From: Niklas Muennighoff
Date: Sat, 10 Aug 2024 16:42:21 -0700
Subject: [PATCH 3/5] Use torch.logsumexp

---
 megablocks/layers/moe.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/megablocks/layers/moe.py b/megablocks/layers/moe.py
index e0e260ce..22202178 100644
--- a/megablocks/layers/moe.py
+++ b/megablocks/layers/moe.py
@@ -100,7 +100,7 @@ def batched_load_balancing_loss(args : Arguments):
         args.moe_top_k
     )
     scale = scale_numerator / scale_denominator
-    zloss = (torch.log(torch.exp(expert_logits).sum(dim=-1)) ** 2).sum() / scale_denominator
+    zloss = (torch.logsumexp(expert_logits, dim=-1) ** 2).sum() / scale_denominator
     return scale * torch.dot(tokens_per_expert, expert_scores), args.moe_zloss_weight * zloss
 
 

From a3a35b769d21dd2c0fbe8f2ba1adbd8eadef83f7 Mon Sep 17 00:00:00 2001
From: Niklas Muennighoff
Date: Sat, 7 Sep 2024 15:11:52 -0700
Subject: [PATCH 4/5] Fix rtn type

---
 tests/layers/dmoe_test.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/layers/dmoe_test.py b/tests/layers/dmoe_test.py
index 3d6565c8..7f54d565 100644
--- a/tests/layers/dmoe_test.py
+++ b/tests/layers/dmoe_test.py
@@ -134,7 +134,7 @@ def test_dmoe_forward_backward(
     out, _ = layer(x)
     assert out.shape == x.shape
 
-    loss = out.sum() + batched_load_balancing_loss(args)
+    loss = out.sum() + batched_load_balancing_loss(args)[0]
     loss.backward()
     assert x.grad is not None
     layer.zero_grad(set_to_none=True)

From b39400be535c9581ae877f8d79a66280580611f0 Mon Sep 17 00:00:00 2001
From: Niklas Muennighoff
Date: Sat, 7 Sep 2024 15:18:46 -0700
Subject: [PATCH 5/5] Fix rtn type

---
 tests/layers/moe_test.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/layers/moe_test.py b/tests/layers/moe_test.py
index ffd32cbf..37527d68 100644
--- a/tests/layers/moe_test.py
+++ b/tests/layers/moe_test.py
@@ -102,7 +102,7 @@ def test_moe_forward_backward(
     out, _ = layer(x)
     assert out.shape == x.shape
 
-    loss = out.sum() + batched_load_balancing_loss(args)
+    loss = out.sum() + batched_load_balancing_loss(args)[0]
     loss.backward()
     layer.zero_grad(set_to_none=True)
     x.grad = None
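
Note (editorial sketch, not part of the patch series): after these five patches the router
returns its raw logits alongside the softmax scores, each MoE layer saves them with the
load-balancing statistics, and batched_load_balancing_loss returns a two-element tuple of
(load-balancing loss, weighted z-loss) instead of a single scalar. The z-loss squares the
logsumexp of the router logits, discouraging the router from producing large logits, and is
already multiplied by args.moe_zloss_weight when returned. Callers that previously added the
old scalar return value straight into the training loss need the same fix the two test
patches apply. A minimal usage sketch follows; lm_loss and args are assumed to come from the
surrounding training loop and are not defined by these patches:

    from megablocks.layers.moe import (
        batched_load_balancing_loss,
        clear_load_balancing_loss,
    )

    # The training-mode forward pass of a model containing MoE layers fills the buffer
    # that batched_load_balancing_loss reads (tokens_per_expert, scores, logits per layer).
    lbl, zloss = batched_load_balancing_loss(args)

    # Both terms are scalars; the z-loss is pre-weighted by args.moe_zloss_weight, so it
    # can be added directly to the loss or logged on its own.
    total_loss = lm_loss + lbl + zloss
    total_loss.backward()

    # Reset the saved per-layer statistics before the next forward pass.
    clear_load_balancing_loss()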