Don't save MoE lb-loss tensors if args.moe_loss_weight=0 (#119)
Saving these tensors takes GPU memory, and can also cause a memory leak if
clear_load_balancing_loss() is not called.

Co-authored-by: Michael Gokhman <[email protected]>
michael-go and Michael Gokhman authored Jul 11, 2024
1 parent f1a83bd commit d2774b2
Showing 1 changed file with 4 additions and 1 deletion.
megablocks/layers/moe.py (5 changes: 4 additions & 1 deletion)

@@ -29,6 +29,9 @@ def clear_load_balancing_loss():
 
 
 def batched_load_balancing_loss(args : Arguments):
+    if args.moe_loss_weight == 0:
+        return 0.0
+
     # tokens_per_expert[i].shape = (num_experts)
     # expert_scores[i].shape = (tokens, num_experts)
     tokens_per_expert, expert_scores = zip(*get_load_balancing_loss())

@@ -424,7 +427,7 @@ def forward(self, x, scores, expert_weights, top_experts):
         # Compute the experts.
         x, tokens_per_expert = self.forward_fn(
             x, expert_weights, top_experts)
-        if self.training:
+        if self.training and self.args.moe_loss_weight > 0:
             save_load_balancing_loss((tokens_per_expert, scores))
         x = x.view(in_shape)
         if self.bias is not None:
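Below is a minimal sketch (not part of this commit) of the training-step pattern the fix targets. The two imported functions come from the megablocks/layers/moe.py module shown above; model, optimizer, batch, and the output.loss attribute are hypothetical stand-ins, and args is assumed to be a megablocks Arguments instance with moe_loss_weight configured.

# Minimal sketch of a training step using the megablocks load-balancing loss
# helpers. Everything except the two imported functions is a hypothetical
# stand-in for illustration.
from megablocks.layers.moe import (
    batched_load_balancing_loss,
    clear_load_balancing_loss,
)


def train_step(model, optimizer, batch, args):
    optimizer.zero_grad()

    # During training, the MoE layers save (tokens_per_expert, scores) via
    # save_load_balancing_loss(); after this commit they only do so when
    # args.moe_loss_weight > 0.
    output = model(batch)
    loss = output.loss  # hypothetical task loss

    # After this commit, batched_load_balancing_loss() returns 0.0 when
    # args.moe_loss_weight == 0, so the addition below is a no-op in that case.
    loss = loss + batched_load_balancing_loss(args)

    loss.backward()
    optimizer.step()

    # Skipping this call lets the saved tensors accumulate across steps and
    # leak GPU memory, which is what motivated guarding the save with
    # args.moe_loss_weight > 0.
    clear_load_balancing_loss()

The design choice in the commit is the cheaper safeguard: when the loss weight is zero, the tensors are never saved in the first place, so a forgotten clear_load_balancing_loss() cannot leak them.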
