Commit 126bda2
#0: Improve support for perf tests
yieldthought committed Nov 20, 2024
1 parent 5e7141c commit 126bda2
Showing 2 changed files with 15 additions and 8 deletions.
models/demos/llama3/lt (5 additions, 0 deletions)
@@ -845,6 +845,11 @@ def parse_output_line(line, previous_line, current_status):
     speed_match = re.search(r"@ (\d+\.\d+) tok/s/user", line)
     if speed_match:
         speed = float(speed_match.group(1))
+    else:
+        # Check for end_to_end_inference time from perf test
+        latency_match = re.search(r"end_to_end_inference: (\d+\.\d+)s", line)
+        if latency_match:
+            speed = 1000 * float(latency_match.group(1))  # convert to ms
 
     # Check for PCC information
     pcc = None
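For context, a minimal standalone sketch of the new fallback branch (the sample log lines below are assumptions, not taken from the commit):

```python
import re

def parse_speed(line):
    """Sketch of the parsing above: prefer tok/s/user throughput; otherwise
    fall back to the end_to_end_inference latency reported by perf tests."""
    speed = None
    speed_match = re.search(r"@ (\d+\.\d+) tok/s/user", line)
    if speed_match:
        speed = float(speed_match.group(1))
    else:
        latency_match = re.search(r"end_to_end_inference: (\d+\.\d+)s", line)
        if latency_match:
            speed = 1000 * float(latency_match.group(1))  # seconds -> milliseconds
    return speed

print(parse_speed("end_to_end_inference: 0.3512s"))  # ~351.2 (latency in ms)
print(parse_speed("... @ 28.5 tok/s/user"))          # 28.5 (throughput)
```

Note that the same `speed` field now carries either a throughput in tok/s/user or a latency in milliseconds, depending on which pattern matched.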
models/demos/llama3/tests/test_llama_perf.py (10 additions, 8 deletions)
@@ -36,6 +36,15 @@
         (1024, 20),
     ),
 )
+@pytest.mark.parametrize(
+    "mesh_device",
+    [
+        {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get(
+            os.environ.get("FAKE_DEVICE"), len(ttnn.get_device_ids())
+        )
+    ],
+    indirect=True,
+)
 def test_llama_model_perf(mesh_device, kv_cache_len, expected_compile_time, use_program_cache, reset_seeds, ensure_gc):
     dtype = ttnn.bfloat8_b
 
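The new parametrization resolves a mesh shape as follows: `FAKE_DEVICE` names a known configuration, and anything else falls back to the total number of attached devices. A minimal standalone sketch (the device count is stubbed in place of `len(ttnn.get_device_ids())`, which is an assumption here):

```python
import os

MESH_SHAPES = {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}

def resolve_mesh_shape(device_count=8):
    # FAKE_DEVICE picks a fixed (rows, cols) mesh shape; unset or unknown
    # values fall back to a plain device count, mirroring the expression above.
    return MESH_SHAPES.get(os.environ.get("FAKE_DEVICE"), device_count)

os.environ["FAKE_DEVICE"] = "T3K"
print(resolve_mesh_shape())  # (1, 8)

os.environ.pop("FAKE_DEVICE")
print(resolve_mesh_shape())  # 8 -- the fallback is a count, not a tuple
```

The fallback deliberately differs in type from the dict values; presumably the indirect `mesh_device` fixture accepts either a shape tuple or a device count.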
@@ -162,15 +171,7 @@ def run_inference(tt_model, tt_embd, embd, encoded_prompts, generation_start_pos
         dtype=ttnn.uint32,
     )
 
-    # Generate first input on host
-    pt_decode_input = embd(encoded_prompts_tensor[:, 0]).view(batch, seqlen, -1)
-    # Send first input to device
-    tt_decode_input = pt_decode_input
-    decode_input = tt_model.args.prepare_inputs_ttnn_decode(
-        tt_decode_input,
-        tt_model.args.model_config["DEC_SKIP_OUTPUT"],
-    )
 
     current_pos = ttnn.from_torch(
         torch.tensor([generation_start_pos] * batch),
         device=mesh_device,
@@ -183,6 +184,7 @@ def run_inference(tt_model, tt_embd, embd, encoded_prompts, generation_start_pos
         profiler.start(f"model_run_for_inference_{i}")
 
         decode_input = ttnn.unsqueeze_to_4D(tt_embd(tt_out_tok))
+        decode_input = ttnn.to_memory_config(decode_input, tt_model.args.model_config["DEC_SKIP_OUTPUT_MEMCFG"])
         tt_out = tt_model(decode_input, current_pos, rot_mat=current_rot_mat)
         tt_out_rm = ttnn.untilize(tt_out, use_multicore=True)
         ttnn.deallocate(tt_out)
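Taken together, the last two hunks replace the host-side first-input path with a fully device-side one: each step embeds the previous output token with `tt_embd` and then moves the result into the memory config the decoder expects. A hedged sketch of the resulting inner-loop step (requires a Tenstorrent device plus the test's setup; `tt_model`, `tt_embd`, `tt_out_tok`, `current_pos`, and `current_rot_mat` come from the surrounding test, and the helper name `decode_step` is ours, not the commit's):

```python
import ttnn

def decode_step(tt_model, tt_embd, tt_out_tok, current_pos, current_rot_mat):
    # Embed the previous output token on device and lift it to a 4D tensor.
    decode_input = ttnn.unsqueeze_to_4D(tt_embd(tt_out_tok))
    # New in this commit: move the input into the memory config the decoder
    # expects, replacing the deleted host-side prepare_inputs_ttnn_decode call.
    decode_input = ttnn.to_memory_config(
        decode_input, tt_model.args.model_config["DEC_SKIP_OUTPUT_MEMCFG"]
    )
    tt_out = tt_model(decode_input, current_pos, rot_mat=current_rot_mat)
    tt_out_rm = ttnn.untilize(tt_out, use_multicore=True)  # tile -> row-major
    ttnn.deallocate(tt_out)  # free the tiled output once untilized
    return tt_out_rm
```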
