diff --git a/models/demos/llama3/lt b/models/demos/llama3/lt
index 4e285c77d379..c75817ea3c5d 100755
--- a/models/demos/llama3/lt
+++ b/models/demos/llama3/lt
@@ -845,6 +845,11 @@ def parse_output_line(line, previous_line, current_status):
     speed_match = re.search(r"@ (\d+\.\d+) tok/s/user", line)
     if speed_match:
         speed = float(speed_match.group(1))
+    else:
+        # Check for end_to_end_inference time from perf test
+        latency_match = re.search(r"end_to_end_inference: (\d+\.\d+)s", line)
+        if latency_match:
+            speed = 1000 * float(latency_match.group(1))  # convert to ms
 
     # Check for PCC information
     pcc = None
diff --git a/models/demos/llama3/tests/test_llama_perf.py b/models/demos/llama3/tests/test_llama_perf.py
index 0e4c35bf41f0..a6999269ca0b 100644
--- a/models/demos/llama3/tests/test_llama_perf.py
+++ b/models/demos/llama3/tests/test_llama_perf.py
@@ -36,6 +36,15 @@
         (1024, 20),
     ),
 )
+@pytest.mark.parametrize(
+    "mesh_device",
+    [
+        {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get(
+            os.environ.get("FAKE_DEVICE"), len(ttnn.get_device_ids())
+        )
+    ],
+    indirect=True,
+)
 def test_llama_model_perf(mesh_device, kv_cache_len, expected_compile_time, use_program_cache, reset_seeds, ensure_gc):
     dtype = ttnn.bfloat8_b
 
@@ -162,15 +171,7 @@ def run_inference(tt_model, tt_embd, embd, encoded_prompts, generation_start_pos
         dtype=ttnn.uint32,
     )
 
-    # Generate first input on host
-    pt_decode_input = embd(encoded_prompts_tensor[:, 0]).view(batch, seqlen, -1)
     # Send first input to device
-    tt_decode_input = pt_decode_input
-    decode_input = tt_model.args.prepare_inputs_ttnn_decode(
-        tt_decode_input,
-        tt_model.args.model_config["DEC_SKIP_OUTPUT"],
-    )
-
     current_pos = ttnn.from_torch(
         torch.tensor([generation_start_pos] * batch),
         device=mesh_device,
@@ -183,6 +184,7 @@
         profiler.start(f"model_run_for_inference_{i}")
 
         decode_input = ttnn.unsqueeze_to_4D(tt_embd(tt_out_tok))
+        decode_input = ttnn.to_memory_config(decode_input, tt_model.args.model_config["DEC_SKIP_OUTPUT_MEMCFG"])
         tt_out = tt_model(decode_input, current_pos, rot_mat=current_rot_mat)
         tt_out_rm = ttnn.untilize(tt_out, use_multicore=True)
         ttnn.deallocate(tt_out)