diff --git a/models/demos/llama3/tt/llama_mlp.py b/models/demos/llama3/tt/llama_mlp.py
index cccbaf7864ef..f06e2ff63f1f 100644
--- a/models/demos/llama3/tt/llama_mlp.py
+++ b/models/demos/llama3/tt/llama_mlp.py
@@ -107,8 +107,7 @@ def forward(self, x: ttnn.Tensor, mode) -> ttnn.Tensor:
             mode == "decode"
         ):  # TODO Add a check for a match between FF1/FF3 and FF2 memory configs. If they match avoid doing the reshard
             # Reshard w2_in to a different core_grid configuration. Avoid using ttnn.reshard() due to incompatibility with trace mode
-            w2_in = ttnn.sharded_to_interleaved(w2_in, ttnn.L1_MEMORY_CONFIG)
-            w2_in = ttnn.interleaved_to_sharded(w2_in, self.model_config["SHARDED_MLP2_INPUT_MEMCFG"])
+            w2_in = ttnn.reshard(w2_in, self.model_config["SHARDED_MLP2_INPUT_MEMCFG"])
             ttnn.deallocate(w3_out)
             ttnn.deallocate(w1_out)
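
For context, the two code paths differ only in how the FF2 input tensor is moved onto the core grid expected by the second matmul: the old path round-trips through an L1 interleaved layout, while the new path reshards directly between the two sharded memory configs. A minimal sketch of the difference is below; the function name and the `use_reshard` flag are illustrative only and not part of the diff.

```python
import ttnn


def reshard_w2_input(w2_in: ttnn.Tensor, sharded_memcfg, use_reshard: bool = True) -> ttnn.Tensor:
    """Move w2_in to the sharded memory config expected by the FF2 matmul.

    `use_reshard` is a hypothetical flag for illustration: True follows the new
    single-op path, False follows the old two-step workaround.
    """
    if use_reshard:
        # New path: reshard directly from one sharded layout to the other.
        return ttnn.reshard(w2_in, sharded_memcfg)
    # Old workaround: spill to interleaved L1, then shard to the target config.
    w2_in = ttnn.sharded_to_interleaved(w2_in, ttnn.L1_MEMORY_CONFIG)
    return ttnn.interleaved_to_sharded(w2_in, sharded_memcfg)
```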