From 9efcb57c9dc6d61fba4afd007c3423223fc8eb38 Mon Sep 17 00:00:00 2001
From: mtairum
Date: Tue, 29 Oct 2024 17:14:20 +0000
Subject: [PATCH] #13368: fix mlp reshard

---
 models/demos/llama3/tt/llama_mlp.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/models/demos/llama3/tt/llama_mlp.py b/models/demos/llama3/tt/llama_mlp.py
index cccbaf7864ef..f06e2ff63f1f 100644
--- a/models/demos/llama3/tt/llama_mlp.py
+++ b/models/demos/llama3/tt/llama_mlp.py
@@ -107,8 +107,7 @@ def forward(self, x: ttnn.Tensor, mode) -> ttnn.Tensor:
             mode == "decode"
         ):  # TODO Add a check for a match between FF1/FF3 and FF2 memory configs. If they match avoid doing the reshard
             # Reshard w2_in to a different core_grid configuration. Avoid using ttnn.reshard() due to incompatibility with trace mode
-            w2_in = ttnn.sharded_to_interleaved(w2_in, ttnn.L1_MEMORY_CONFIG)
-            w2_in = ttnn.interleaved_to_sharded(w2_in, self.model_config["SHARDED_MLP2_INPUT_MEMCFG"])
+            w2_in = ttnn.reshard(w2_in, self.model_config["SHARDED_MLP2_INPUT_MEMCFG"])
             ttnn.deallocate(w3_out)
             ttnn.deallocate(w1_out)
 
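
Note for reviewers (illustrative, not part of the patch): the sketch below contrasts the two resharding paths this change switches between. The helper name and arguments are hypothetical; the ttnn calls are the ones from the hunk above, and the example assumes a tensor that is already sharded in L1 plus a target sharded memory config such as model_config["SHARDED_MLP2_INPUT_MEMCFG"].

    import ttnn

    def reshard_activation(tensor, target_memcfg, direct=True):
        """Move a sharded tensor onto a different core-grid sharding (hypothetical helper)."""
        if direct:
            # Single-op path adopted by this patch: reshard directly on device.
            return ttnn.reshard(tensor, target_memcfg)
        # Two-op workaround previously used for trace-mode compatibility:
        # gather shards into interleaved L1, then re-shard to the new core grid.
        tensor = ttnn.sharded_to_interleaved(tensor, ttnn.L1_MEMORY_CONFIG)
        return ttnn.interleaved_to_sharded(tensor, target_memcfg)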