diff --git a/models/demos/llama3/tests/test_llama_model.py b/models/demos/llama3/tests/test_llama_model.py
index 24270d7d637d..d359a9f6afc3 100644
--- a/models/demos/llama3/tests/test_llama_model.py
+++ b/models/demos/llama3/tests/test_llama_model.py
@@ -176,7 +176,7 @@ def test_llama_model_inference(mesh_device, weights, layers, use_program_cache,
 
         decode_input = model_args.prepare_inputs_ttnn_decode(
             tt_decode_input,
-            ttnn.DRAM_MEMORY_CONFIG,
+            model_args.model_config["DEC_SKIP_OUTPUT_MEMCFG"],
         )
         current_pos_tensor = ttnn.from_torch(
             torch.tensor([current_pos] * batch),
diff --git a/models/demos/llama3/tests/test_llama_rms_norm.py b/models/demos/llama3/tests/test_llama_rms_norm.py
index 4d63f4fc1c2f..52a630096499 100644
--- a/models/demos/llama3/tests/test_llama_rms_norm.py
+++ b/models/demos/llama3/tests/test_llama_rms_norm.py
@@ -73,7 +73,9 @@ def test_llama_rms_norm_inference(mesh_device, use_program_cache, reset_seeds, e
         dtype=dtype,
         layout=ttnn.TILE_LAYOUT,
         mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=-1),
-        memory_config=ttnn.L1_MEMORY_CONFIG if mode == "decode" else ttnn.DRAM_MEMORY_CONFIG,
+        memory_config=model_args.get_model_config()["DEC_SKIP_OUTPUT_MEMCFG"]
+        if mode == "decode"
+        else ttnn.DRAM_MEMORY_CONFIG,
     )
 
     tt_output = tt_model(tt_input, mode=mode)
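Both hunks replace a hard-coded `ttnn` memory config with a lookup from the model-config table, keyed by `DEC_SKIP_OUTPUT_MEMCFG`, so decode-mode tests pick up whatever config the model defines. A minimal, self-contained sketch of that selection pattern follows; the stubbed `ttnn` module, `ModelArgs` class, and `pick_memory_config` helper are illustrative placeholders, not the real tt-metal APIs.

```python
# Sketch only: stub out ttnn and model_args so the mode-dependent
# memory-config selection can run standalone.
class _TtnnStub:
    DRAM_MEMORY_CONFIG = "DRAM_MEMORY_CONFIG"  # stands in for ttnn.DRAM_MEMORY_CONFIG
    L1_MEMORY_CONFIG = "L1_MEMORY_CONFIG"      # stands in for ttnn.L1_MEMORY_CONFIG


ttnn = _TtnnStub()


class ModelArgs:
    """Hypothetical stand-in for the tests' model_args fixture."""

    def get_model_config(self):
        # In the real tests this table holds actual ttnn memory configs.
        return {"DEC_SKIP_OUTPUT_MEMCFG": ttnn.L1_MEMORY_CONFIG}


def pick_memory_config(model_args, mode):
    # Decode pulls the config from the model-config table; any other
    # mode keeps the tensor in DRAM, mirroring the diff above.
    if mode == "decode":
        return model_args.get_model_config()["DEC_SKIP_OUTPUT_MEMCFG"]
    return ttnn.DRAM_MEMORY_CONFIG


assert pick_memory_config(ModelArgs(), "decode") == ttnn.L1_MEMORY_CONFIG
assert pick_memory_config(ModelArgs(), "prefill") == ttnn.DRAM_MEMORY_CONFIG
```

Centralizing the config behind one key means a model can change where decode outputs live (L1, DRAM, sharded, etc.) without touching every test call site.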