diff --git a/models/demos/llama3/tt/llama_model.py b/models/demos/llama3/tt/llama_model.py index 5b1b5a49b3bd..04cf2c8d77be 100644 --- a/models/demos/llama3/tt/llama_model.py +++ b/models/demos/llama3/tt/llama_model.py @@ -84,7 +84,8 @@ def forward( get_last_token=-1, ): # No-op if callers already provide the right memory config - x = ttnn.to_memory_config(x, self.model_config["DECODE_RESIDUAL_MEMCFG"]) + if mode == "decode": + x = ttnn.to_memory_config(x, self.model_config["DECODE_RESIDUAL_MEMCFG"]) for layer in self.layers: x = layer(x, current_pos, rot_mat, transformation_mats, user_id, mode, page_table)