diff --git a/models/demos/llama3/tt/llama_model.py b/models/demos/llama3/tt/llama_model.py index db3cde85b338..14c9bdb0cf99 100644 --- a/models/demos/llama3/tt/llama_model.py +++ b/models/demos/llama3/tt/llama_model.py @@ -110,9 +110,8 @@ def prepare_inputs_prefill(self, tokens, start_pos=0, page_table=None, chunk_pag tokens = tokens.reshape(1, 1, 1, -1) S = tokens.shape[-1] - dims = (None, -1) if self.args.is_galaxy else (None, None) + dims = (None, None) # replicate mesh_mapper = ttnn.ShardTensor2dMesh(self.mesh_device, dims=dims, mesh_shape=self.args.cluster_shape) - tokens = ttnn.from_torch( tokens, device=self.mesh_device,