From 5e246062d96bacc27f8ad5f334c6ac458b0a1261 Mon Sep 17 00:00:00 2001 From: Miguel Tairum <150826086+mtairum@users.noreply.github.com> Date: Tue, 18 Jun 2024 16:51:59 +0100 Subject: [PATCH] [Mixtral] Reverted output dtype of ttnn.eq() (#9509) #5337: [Mixtral] Reverted output dtype of ttnn.eq() --- models/demos/t3000/mixtral8x7b/tt/mixtral_moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/demos/t3000/mixtral8x7b/tt/mixtral_moe.py b/models/demos/t3000/mixtral8x7b/tt/mixtral_moe.py index dd127745fcf..6664ad227e2 100644 --- a/models/demos/t3000/mixtral8x7b/tt/mixtral_moe.py +++ b/models/demos/t3000/mixtral8x7b/tt/mixtral_moe.py @@ -99,7 +99,7 @@ def forward(self, inputs): gate_logits_1SB8 = ttnn.add(gate_logits_1SB8, self.top8_mask_11B_64) ttl_topk_values, ttl_topk_indices = ttnn.experimental.operations.primary.topk(gate_logits_1SB8, 32) ttl_topk_values = ttnn.add(ttl_topk_values, self.top2_mask_11BB) - mask_B2 = ttnn.eq(self.expert_mask_11BB, ttl_topk_indices, dtype=ttnn.bfloat16) + mask_B2 = ttnn.eq(self.expert_mask_11BB, ttl_topk_indices) weights_1SB1 = ttnn.sum(ttnn.softmax(ttl_topk_values, dim=-1) * mask_B2, dim=3) # MLP and masking