Skip to content

Commit

Permalink
#14273: Update tests for sharded residual
Browse files (browse the repository at this point in the history)
  • Loading branch information
yieldthought committed Nov 17, 2024
1 parent 5151767 commit 8a79a05
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
2 changes: 1 addition & 1 deletion models/demos/llama3/tests/test_llama_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def test_llama_model_inference(mesh_device, weights, layers, use_program_cache,

decode_input = model_args.prepare_inputs_ttnn_decode(
tt_decode_input,
ttnn.DRAM_MEMORY_CONFIG,
model_args.model_config["DEC_SKIP_OUTPUT_MEMCFG"],
)
current_pos_tensor = ttnn.from_torch(
torch.tensor([current_pos] * batch),
Expand Down
4 changes: 3 additions & 1 deletion models/demos/llama3/tests/test_llama_rms_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,9 @@ def test_llama_rms_norm_inference(mesh_device, use_program_cache, reset_seeds, e
dtype=dtype,
layout=ttnn.TILE_LAYOUT,
mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=-1),
memory_config=ttnn.L1_MEMORY_CONFIG if mode == "decode" else ttnn.DRAM_MEMORY_CONFIG,
memory_config=model_args.get_model_config()["DEC_SKIP_OUTPUT_MEMCFG"]
if mode == "decode"
else ttnn.DRAM_MEMORY_CONFIG,
)

tt_output = tt_model(tt_input, mode=mode)
Expand Down

0 comments on commit 8a79a05

Please sign in to comment.