From d2774b2caf1db5e5c120e88fbf6434f065ff87fc Mon Sep 17 00:00:00 2001
From: Michael Gokhman
Date: Thu, 11 Jul 2024 21:53:01 +0300
Subject: [PATCH] don't save moe lb-loss tensors if args.moe_loss_weight=0 (#119)

It takes GPU memory, and can also cause a leak if
clear_load_balancing_loss() is not called.

Co-authored-by: Michael Gokhman
---
 megablocks/layers/moe.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/megablocks/layers/moe.py b/megablocks/layers/moe.py
index b264a3aa..b3c64e43 100644
--- a/megablocks/layers/moe.py
+++ b/megablocks/layers/moe.py
@@ -29,6 +29,9 @@ def clear_load_balancing_loss():
 
 
 def batched_load_balancing_loss(args : Arguments):
+    if args.moe_loss_weight == 0:
+        return 0.0
+
     # tokens_per_expert[i].shape = (num_experts)
     # expert_scores[i].shape = (tokens, num_experts)
     tokens_per_expert, expert_scores = zip(*get_load_balancing_loss())
@@ -424,7 +427,7 @@ def forward(self, x, scores, expert_weights, top_experts):
         # Compute the experts.
         x, tokens_per_expert = self.forward_fn(
             x, expert_weights, top_experts)
-        if self.training:
+        if self.training and self.args.moe_loss_weight > 0:
             save_load_balancing_loss((tokens_per_expert, scores))
         x = x.view(in_shape)
         if self.bias is not None:
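
Note (not part of the patch): below is a minimal sketch of the accumulator pattern the new guard protects. It reuses the helper names from megablocks/layers/moe.py but with a standalone global list and illustrative tensor shapes, so it is an assumption-laden sketch rather than the library's code. Without the moe_loss_weight check, every training forward appends its (tokens_per_expert, expert_scores) tensors to the module-level list, and a training loop that never calls batched_load_balancing_loss() or clear_load_balancing_loss() (natural when the loss weight is 0) keeps those tensors, and their GPU memory, alive indefinitely.

import torch

# Standalone stand-in for the module-level buffer in megablocks/layers/moe.py.
_LOAD_BALANCING_LOSS = []

def save_load_balancing_loss(loss):
    _LOAD_BALANCING_LOSS.append(loss)

def get_load_balancing_loss():
    return _LOAD_BALANCING_LOSS

def clear_load_balancing_loss():
    _LOAD_BALANCING_LOSS.clear()

moe_loss_weight = 0.0  # hypothetical config value for this sketch

for step in range(3):
    tokens_per_expert = torch.ones(8)   # (num_experts,), illustrative shape
    expert_scores = torch.rand(16, 8)   # (tokens, num_experts), illustrative shape
    # The patched condition: skip saving entirely when the loss is disabled,
    # so nothing accumulates even if clear_load_balancing_loss() is never called.
    if moe_loss_weight > 0:
        save_load_balancing_loss((tokens_per_expert, expert_scores))

print(len(get_load_balancing_loss()))   # 0 with the guard; 3 without it
clear_load_balancing_loss()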