Skip to content

Commit

Permalink
Fix minor bug with token input processing for galaxy
Browse files Browse the repository at this point in the history
Signed-off-by: Salar <[email protected]>
  • Loading branch information
skhorasganiTT committed Dec 24, 2024
1 parent d5e4b23 commit 38bc926
Showing 1 changed file with 1 addition and 2 deletions.
3 changes: 1 addition & 2 deletions models/demos/llama3/tt/llama_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,8 @@ def prepare_inputs_prefill(self, tokens, start_pos=0, page_table=None, chunk_pag

tokens = tokens.reshape(1, 1, 1, -1)
S = tokens.shape[-1]
dims = (None, -1) if self.args.is_galaxy else (None, None)
dims = (None, None) # replicate
mesh_mapper = ttnn.ShardTensor2dMesh(self.mesh_device, dims=dims, mesh_shape=self.args.cluster_shape)

tokens = ttnn.from_torch(
tokens,
device=self.mesh_device,
Expand Down

0 comments on commit 38bc926

Please sign in to comment.