From 27249e23bdabf584c77a71ed9aa48664a6bc6570 Mon Sep 17 00:00:00 2001
From: Salar
Date: Tue, 24 Dec 2024 01:20:04 +0000
Subject: [PATCH] Fix minor bug with token input processing for galaxy

Signed-off-by: Salar
---
 models/demos/llama3/tt/llama_model.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/models/demos/llama3/tt/llama_model.py b/models/demos/llama3/tt/llama_model.py
index db3cde85b33..14c9bdb0cf9 100644
--- a/models/demos/llama3/tt/llama_model.py
+++ b/models/demos/llama3/tt/llama_model.py
@@ -110,9 +110,8 @@ def prepare_inputs_prefill(self, tokens, start_pos=0, page_table=None, chunk_pag
         tokens = tokens.reshape(1, 1, 1, -1)
         S = tokens.shape[-1]
-        dims = (None, -1) if self.args.is_galaxy else (None, None)
+        dims = (None, None)  # replicate
         mesh_mapper = ttnn.ShardTensor2dMesh(self.mesh_device, dims=dims, mesh_shape=self.args.cluster_shape)
-
         tokens = ttnn.from_torch(
             tokens,
             device=self.mesh_device,