Skip to content

Commit

Permalink
[Mixtral] Reverted output dtype of ttnn.eq() (#9509)
Browse files: browse the repository at this point in the history
#5337: [Mixtral] Reverted output dtype of ttnn.eq()
  • Loading branch information
mtairum authored Jun 18, 2024
1 parent 142c8fe commit 5e24606
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion models/demos/t3000/mixtral8x7b/tt/mixtral_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def forward(self, inputs):
gate_logits_1SB8 = ttnn.add(gate_logits_1SB8, self.top8_mask_11B_64)
ttl_topk_values, ttl_topk_indices = ttnn.experimental.operations.primary.topk(gate_logits_1SB8, 32)
ttl_topk_values = ttnn.add(ttl_topk_values, self.top2_mask_11BB)
- mask_B2 = ttnn.eq(self.expert_mask_11BB, ttl_topk_indices, dtype=ttnn.bfloat16)
+ mask_B2 = ttnn.eq(self.expert_mask_11BB, ttl_topk_indices)
weights_1SB1 = ttnn.sum(ttnn.softmax(ttl_topk_values, dim=-1) * mask_B2, dim=3)

# MLP and masking
Expand Down

0 comments on commit 5e24606

Please sign in to comment.