Commit
#0: Fix llama3 embedding after recent changes
mtairum committed Dec 6, 2024
1 parent 317d346 commit 355f4d6
Showing 2 changed files with 1 addition and 2 deletions.
models/demos/llama3/tests/test_llama_embedding.py (2 changes: 1 addition & 1 deletion)
@@ -74,7 +74,7 @@ def test_llama_embedding(max_seq_len, batch_size, mesh_device, use_program_cache
layout=ttnn.ROW_MAJOR_LAYOUT,
)
tt_output = tt_emb(tt_input)
- tt_output_torch = ttnn.to_torch(tt_output, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=-1))[0].view(
+ tt_output_torch = ttnn.to_torch(tt_output, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=-1)).view(
reference_output.shape
)
logger.info(f"tt_output_torch: {tt_output_torch.shape}")
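The test-side change mirrors the removal of the 4D reshape in llama_embedding.py below: once the embedding output keeps its natural rank, the composed tensor can be viewed straight into the reference shape, so the extra [0] index is dropped. A minimal plain-PyTorch sketch of the shape arithmetic (batch, seq_len, and hidden sizes here are hypothetical, and the ttnn device/mesh composition is not modeled):

```python
import torch

batch, seq_len, hidden = 1, 32, 4096  # hypothetical sizes for illustration
reference_output = torch.randn(batch, seq_len, hidden)

# Old behaviour: forward reshaped the output to 4D, so the test indexed [0] before view().
old_tt_output = reference_output.reshape(batch, 1, seq_len, hidden)
assert old_tt_output[0].view(reference_output.shape).shape == reference_output.shape

# New behaviour: the output stays at its original rank and views directly to the reference shape.
new_tt_output = reference_output.clone()
assert new_tt_output.view(reference_output.shape).shape == reference_output.shape
```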
models/demos/llama3/tt/llama_embedding.py (1 change: 0 additions & 1 deletion)
@@ -36,5 +36,4 @@ def __init__(

def forward(self, x: ttnn.Tensor) -> ttnn.Tensor:
x = ttnn.embedding(x, self.weights, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG)
- x = ttnn.reshape(x, [x.shape[0], 1, x.shape[1], x.shape[2]])
return x
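For reference, a minimal torch.nn.Embedding analogue of the simplified forward (an illustration only, not the ttnn implementation; vocabulary and hidden sizes are hypothetical): the lookup result is returned as-is, without the rank-4 reshape removed in this commit.

```python
import torch

vocab_size, hidden = 128_256, 4096  # hypothetical Llama3-like sizes
emb = torch.nn.Embedding(vocab_size, hidden)

x = torch.randint(0, vocab_size, (1, 32))  # [batch, seq_len] token ids
out = emb(x)                               # [batch, seq_len, hidden]
assert out.shape == (1, 32, hidden)        # no reshape to [batch, 1, seq_len, hidden]
```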
