Use torch.logsumexp
Muennighoff authored Aug 10, 2024
1 parent b3c25d7 · commit e430ad7
Showing 1 changed file with 1 addition and 1 deletion.
megablocks/layers/moe.py (1 addition, 1 deletion)

@@ -100,7 +100,7 @@ def batched_load_balancing_loss(args : Arguments):
         args.moe_top_k
     )
     scale = scale_numerator / scale_denominator
-    zloss = (torch.log(torch.exp(expert_logits).sum(dim=-1)) ** 2).sum() / scale_denominator
+    zloss = (torch.logsumexp(expert_logits, dim=-1) ** 2).sum() / scale_denominator
     return scale * torch.dot(tokens_per_expert, expert_scores), args.moe_zloss_weight * zloss
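Why this change matters: the old expression exponentiates the raw router logits before summing, which overflows to inf in float32 once any logit exceeds roughly 88 (and far sooner in float16). torch.logsumexp computes the same quantity as max + log(sum(exp(x - max))), subtracting the per-row maximum first, so every intermediate stays finite. A minimal sketch of the difference (the example tensor is hypothetical, not taken from megablocks):

    import torch

    # Hypothetical router logits with large magnitudes (illustration only).
    expert_logits = torch.tensor([[1000.0, 999.0, 998.0]])

    # Old form: exp(1000.0) overflows float32 to inf, so the log is inf
    # and the squared z-loss term becomes inf as well.
    naive = torch.log(torch.exp(expert_logits).sum(dim=-1))
    print(naive)   # tensor([inf])

    # New form: logsumexp subtracts the per-row max before exponentiating,
    # keeping the computation finite.
    stable = torch.logsumexp(expert_logits, dim=-1)
    print(stable)  # tensor([1000.4076])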


